Update app.py
Browse files
app.py
CHANGED
@@ -5,23 +5,51 @@ from PIL import Image
|
|
5 |
from io import BytesIO
|
6 |
import base64
|
7 |
import os
|
|
|
8 |
|
9 |
# Initialize the Google Generative AI client with the API key from environment variables
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
def generate_item(tag):
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
text_response = client.models.generate_content(
|
16 |
model='gemini-2.5-flash-preview-04-17',
|
17 |
contents=[prompt]
|
18 |
)
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
image_response = client.models.generate_images(
|
23 |
model='imagen-3.0-generate-002',
|
24 |
-
prompt=
|
25 |
config=types.GenerateImagesConfig(
|
26 |
number_of_images=1,
|
27 |
aspect_ratio="9:16",
|
@@ -34,8 +62,8 @@ def generate_item(tag):
|
|
34 |
generated_image = image_response.generated_images[0]
|
35 |
image = Image.open(BytesIO(generated_image.image.image_bytes))
|
36 |
else:
|
37 |
-
# Fallback to a placeholder image
|
38 |
-
image = Image.new('RGB', (300, 533), color='gray') #
|
39 |
|
40 |
# Convert the image to base64
|
41 |
buffered = BytesIO()
|
@@ -54,6 +82,8 @@ def start_feed(tag):
|
|
54 |
Returns:
|
55 |
tuple: (current_tag, feed_items, html_content)
|
56 |
"""
|
|
|
|
|
57 |
item = generate_item(tag)
|
58 |
feed_items = [item]
|
59 |
html_content = generate_html(feed_items)
|
@@ -77,7 +107,7 @@ def load_more(current_tag, feed_items):
|
|
77 |
|
78 |
def generate_html(feed_items):
|
79 |
"""
|
80 |
-
Generate an HTML string to display the feed items.
|
81 |
|
82 |
Args:
|
83 |
feed_items (list): List of dictionaries containing 'text' and 'image_base64'.
|
@@ -85,41 +115,103 @@ def generate_html(feed_items):
|
|
85 |
Returns:
|
86 |
str: HTML string representing the feed.
|
87 |
"""
|
88 |
-
html_str =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
for item in feed_items:
|
90 |
html_str += f"""
|
91 |
-
<div style="
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
</div>
|
95 |
"""
|
96 |
-
html_str +=
|
97 |
return html_str
|
98 |
|
99 |
# Define the Gradio interface
|
100 |
-
with gr.Blocks(
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
# Output display
|
120 |
-
feed_html = gr.HTML(
|
121 |
|
122 |
-
# State variables
|
123 |
current_tag = gr.State(value="")
|
124 |
feed_items = gr.State(value=[])
|
125 |
|
|
|
5 |
from io import BytesIO
|
6 |
import base64
|
7 |
import os
|
8 |
+
import json
|
9 |
|
10 |
# Initialize the Google Generative AI client with the API key from environment variables
|
11 |
+
try:
|
12 |
+
api_key = os.environ['GEMINI_API_KEY']
|
13 |
+
except KeyError:
|
14 |
+
raise ValueError("Please set the GEMINI_API_KEY environment variable.")
|
15 |
+
client = genai.Client(api_key=api_key)
|
16 |
|
17 |
def generate_item(tag):
|
18 |
+
"""
|
19 |
+
Generate a single feed item consisting of text from Gemini LLM and an image from Imagen.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
tag (str): The tag to base the content on.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
dict: A dictionary with 'text' (str) and 'image_base64' (str).
|
26 |
+
"""
|
27 |
+
# Generate text using Gemini LLM with JSON output
|
28 |
+
prompt = f"""
|
29 |
+
Generate a short, engaging TikTok-style caption about {tag}.
|
30 |
+
Return the response as a JSON object with a single key 'caption' containing the caption text.
|
31 |
+
Example: {{"caption": "Craving this yummy treat! 😍 #foodie"}}
|
32 |
+
Do not include additional commentary or options.
|
33 |
+
"""
|
34 |
text_response = client.models.generate_content(
|
35 |
model='gemini-2.5-flash-preview-04-17',
|
36 |
contents=[prompt]
|
37 |
)
|
38 |
+
# Parse JSON response to extract the caption
|
39 |
+
try:
|
40 |
+
response_json = json.loads(text_response.text.strip())
|
41 |
+
text = response_json['caption']
|
42 |
+
except (json.JSONDecodeError, KeyError):
|
43 |
+
text = f"Wow, {tag} is amazing! 😍 #{tag}" # Fallback caption
|
44 |
+
|
45 |
+
# Generate an image based on the tag, avoiding text
|
46 |
+
image_prompt = f"""
|
47 |
+
A vivid, high-quality visual scene representing {tag}, designed for a TikTok video.
|
48 |
+
The image should be colorful and engaging, with no text or letters included.
|
49 |
+
"""
|
50 |
image_response = client.models.generate_images(
|
51 |
model='imagen-3.0-generate-002',
|
52 |
+
prompt=image_prompt,
|
53 |
config=types.GenerateImagesConfig(
|
54 |
number_of_images=1,
|
55 |
aspect_ratio="9:16",
|
|
|
62 |
generated_image = image_response.generated_images[0]
|
63 |
image = Image.open(BytesIO(generated_image.image.image_bytes))
|
64 |
else:
|
65 |
+
# Fallback to a placeholder image
|
66 |
+
image = Image.new('RGB', (300, 533), color='gray') # 9:16 aspect ratio
|
67 |
|
68 |
# Convert the image to base64
|
69 |
buffered = BytesIO()
|
|
|
82 |
Returns:
|
83 |
tuple: (current_tag, feed_items, html_content)
|
84 |
"""
|
85 |
+
if not tag.strip():
|
86 |
+
tag = "trending" # Default tag if empty
|
87 |
item = generate_item(tag)
|
88 |
feed_items = [item]
|
89 |
html_content = generate_html(feed_items)
|
|
|
107 |
|
108 |
def generate_html(feed_items):
|
109 |
"""
|
110 |
+
Generate an HTML string to display the feed items in a TikTok-like vertical layout.
|
111 |
|
112 |
Args:
|
113 |
feed_items (list): List of dictionaries containing 'text' and 'image_base64'.
|
|
|
115 |
Returns:
|
116 |
str: HTML string representing the feed.
|
117 |
"""
|
118 |
+
html_str = """
|
119 |
+
<div style="
|
120 |
+
display: flex;
|
121 |
+
flex-direction: column;
|
122 |
+
align-items: center;
|
123 |
+
max-width: 360px;
|
124 |
+
margin: 0 auto;
|
125 |
+
background-color: #000;
|
126 |
+
height: 640px;
|
127 |
+
overflow-y: auto;
|
128 |
+
scrollbar-width: none;
|
129 |
+
-ms-overflow-style: none;
|
130 |
+
border: 1px solid #333;
|
131 |
+
border-radius: 10px;
|
132 |
+
">
|
133 |
+
"""
|
134 |
+
# Hide scrollbar for a cleaner look
|
135 |
+
html_str += """
|
136 |
+
<style>
|
137 |
+
div::-webkit-scrollbar {
|
138 |
+
display: none;
|
139 |
+
}
|
140 |
+
</style>
|
141 |
+
"""
|
142 |
for item in feed_items:
|
143 |
html_str += f"""
|
144 |
+
<div style="
|
145 |
+
width: 100%;
|
146 |
+
height: 640px;
|
147 |
+
position: relative;
|
148 |
+
display: flex;
|
149 |
+
flex-direction: column;
|
150 |
+
justify-content: flex-end;
|
151 |
+
overflow: hidden;
|
152 |
+
">
|
153 |
+
<img src="data:image/png;base64,{item['image_base64']}" style="
|
154 |
+
width: 100%;
|
155 |
+
height: 100%;
|
156 |
+
object-fit: cover;
|
157 |
+
position: absolute;
|
158 |
+
top: 0;
|
159 |
+
left: 0;
|
160 |
+
z-index: 1;
|
161 |
+
">
|
162 |
+
<div style="
|
163 |
+
position: relative;
|
164 |
+
z-index: 2;
|
165 |
+
background: linear-gradient(to top, rgba(0,0,0,0.7), transparent);
|
166 |
+
padding: 20px;
|
167 |
+
color: white;
|
168 |
+
font-family: Arial, sans-serif;
|
169 |
+
font-size: 18px;
|
170 |
+
font-weight: bold;
|
171 |
+
text-shadow: 1px 1px 2px rgba(0,0,0,0.5);
|
172 |
+
">
|
173 |
+
{item['text']}
|
174 |
+
</div>
|
175 |
</div>
|
176 |
"""
|
177 |
+
html_str += "</div>"
|
178 |
return html_str
|
179 |
|
180 |
# Define the Gradio interface
|
181 |
+
with gr.Blocks(
|
182 |
+
css="""
|
183 |
+
body { background-color: #000; color: #fff; font-family: Arial, sans-serif; }
|
184 |
+
.gradio-container { max-width: 400px; margin: 0 auto; padding: 10px; }
|
185 |
+
input, select, button { border-radius: 5px; background-color: #222; color: #fff; border: 1px solid #444; }
|
186 |
+
button { background-color: #ff2d55; border: none; }
|
187 |
+
button:hover { background-color: #e0264b; }
|
188 |
+
.gr-button { width: 100%; margin-top: 10px; }
|
189 |
+
.gr-form { background-color: #111; padding: 15px; border-radius: 10px; }
|
190 |
+
""",
|
191 |
+
title="TikTok-Style Infinite Feed"
|
192 |
+
) as demo:
|
193 |
+
# Input section
|
194 |
+
with gr.Column(elem_classes="gr-form"):
|
195 |
+
gr.Markdown("### Create Your TikTok Feed")
|
196 |
+
with gr.Row():
|
197 |
+
suggested_tags = gr.Dropdown(
|
198 |
+
choices=["food", "travel", "fashion", "tech"],
|
199 |
+
label="Pick a Tag",
|
200 |
+
value="food"
|
201 |
+
)
|
202 |
+
tag_input = gr.Textbox(
|
203 |
+
label="Or Enter a Custom Tag",
|
204 |
+
value="food",
|
205 |
+
placeholder="e.g., sushi, adventure"
|
206 |
+
)
|
207 |
+
with gr.Row():
|
208 |
+
start_button = gr.Button("Start Feed")
|
209 |
+
load_more_button = gr.Button("Load More")
|
210 |
|
211 |
# Output display
|
212 |
+
feed_html = gr.HTML()
|
213 |
|
214 |
+
# State variables
|
215 |
current_tag = gr.State(value="")
|
216 |
feed_items = gr.State(value=[])
|
217 |
|