import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import spaces
import os
import gc

# Make sure the directory Gradio uses to cache the examples exists
os.makedirs('.gradio/cached_examples/17', exist_ok=True)

def get_model_name(language):
    """Map language choice to the corresponding model."""
    model_mapping = {
        "English": "microsoft/Phi-3-mini-4k-instruct",
        "Arabic": "ALLaM-AI/ALLaM-7B-Instruct-preview"
    }
    return model_mapping.get(language, "ALLaM-AI/ALLaM-7B-Instruct-preview")  # Default to Arabic model

def load_model(model_name):
    """Load the model and tokenizer and wrap them in a text-generation pipeline."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device,
        torch_dtype="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=False,
        max_new_tokens=500,
        do_sample=True,  # Enable sampling for more creative outputs
        top_k=50,        # Sample only from the 50 most likely tokens
        top_p=0.95,      # Nucleus sampling threshold
    )
    # Drop the local names; the pipeline keeps its own references to the
    # model and tokenizer, so this only releases the extra handles.
    del model
    del tokenizer

    return generator


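# @spaces.GPU comes from the `spaces` package; on Hugging Face ZeroGPU Spaces
# it attaches a GPU for the duration of each call to the decorated function.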
@spaces.GPU
def generate_kids_story(character, setting, language):
    """Generate a short kids' story about `character` in `setting`, in the chosen language."""
    model_name = get_model_name(language)
    generator = load_model(model_name)

    # Define prompt for the AI model
    if language == "English":
        prompt = (f"Write a short story for kids about a character named {character} who goes on an adventure in {setting}. "
                  "Make it fun, engaging, and suitable for children.")
    else:
        prompt = (f"اكتب قصة قصيرة للأطفال عن شخصية اسمها {character} التي تذهب في مغامرة في {setting}. "
                  "اجعلها ممتعة وجذابة ومناسبة للأطفال.")

    messages = [{"role": "user", "content": prompt}]
    output = generator(messages)
    
    # Release the pipeline, then reclaim host and GPU memory between requests
    del generator
    gc.collect()
    torch.cuda.empty_cache()
    
    return output[0]["generated_text"]


css_style = """
    body {
        background-image: url('https://cdna.artstation.com/p/assets/images/images/074/776/904/large/pietro-chiovaro-r1-castle-chp.jpg?1712916847');
        background-size: cover;
        background-position: center;
        color: #fff;  /* General text color */
        font-family: 'Arial', sans-serif;
    }"""
# Create Gradio interface
demo = gr.Interface(
    fn=generate_kids_story,
    inputs=[
        gr.Textbox(placeholder="Enter a character name (e.g., Benny the Bunny)...", label="Character Name"),
        gr.Textbox(placeholder="Enter a setting (e.g., a magical forest)...", label="Setting"),
        gr.Dropdown(
            choices=["English", "Arabic"],
            label="Choose Language",
            value="English"  # Default to English
        )
    ],
    outputs=gr.Textbox(label="Kids' Story"),
    title="📖 AI Kids' Story Generator - English & Arabic 📖",
    description="Enter a character name and a setting, and AI will generate a fun short story for kids in English or Arabic.",
    examples=[
        ["Benny the Bunny", "a magical forest", "English"],
        ["علي البطل", "غابة سحرية", "Arabic"],
        ["Lila the Ladybug", "a garden full of flowers", "English"],
        ["ليلى الجنية", "حديقة مليئة بالأزهار", "Arabic"]
    ],
    css=css_style,
)

# Launch the Gradio app
demo.launch()
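
# On a Hugging Face Space this file is executed automatically; when run locally,
# `python app.py` serves the app at http://127.0.0.1:7860 by default.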