Lohia, Aditya committed
Commit 4b91514 · Parent(s): 1ccd3bb

all files
- .gitattributes +1 -0
- app.py +364 -0
- assets/sample-images/01.png +3 -0
- assets/sample-images/02.png +3 -0
- assets/sample-images/03.png +3 -0
- assets/sample-images/04.png +3 -0
- assets/sample-images/06-1.png +3 -0
- assets/sample-images/08.png +3 -0
- assets/sample-images/1.png +3 -0
- assets/sample-images/10.png +3 -0
- assets/sample-images/2.png +3 -0
- assets/sample-images/3.png +3 -0
- assets/sample-images/4.png +3 -0
- assets/sample-images/barchart.png +3 -0
- gateway.py +105 -0
- requirements.txt +4 -0
- style.css +10 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
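(This rule is what `git lfs track "*.png"` would append to .gitattributes; it is needed because this commit adds the PNG sample images below.)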
app.py
ADDED
@@ -0,0 +1,364 @@
import os
import logging
import base64
import io

import gradio as gr
from PIL import Image
from typing import Iterator

from gateway import request_generation

# Setup logging
logging.basicConfig(level=logging.INFO)

# CONSTANTS
# Get max new tokens from the environment; default to 2048 if unset.
# os.getenv returns a string, so cast explicitly.
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", 2048))

# Get max number of images allowed in a single prompt
MAX_NUM_IMAGES_ENV = os.getenv("MAX_NUM_IMAGES")
if not MAX_NUM_IMAGES_ENV:
    raise EnvironmentError("MAX_NUM_IMAGES is not set. Please set it to 1 or more.")
MAX_NUM_IMAGES: int = int(MAX_NUM_IMAGES_ENV)

# Validate environment variables
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
    raise EnvironmentError("API_ENDPOINT is not set.")

MODEL_NAME: str = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise EnvironmentError("MODEL_NAME is not set.")

# Get API key
API_KEY = os.getenv("API_KEY")
if not API_KEY:  # simple check that an API key was provided
    raise EnvironmentError("API_KEY is not set.")

# Create the auth header once instead of rebuilding it per request
HEADER = {"x-api-key": API_KEY}

def validate_media(message: dict, chat_history: list = None) -> bool:
    """Validate the image files attached to the new message.

    Args:
        message (dict): multimodal message from the user, with "text" and "files" keys
        chat_history (list): entire chat history of the session (unused)

    Returns:
        bool: True if all files are images and there are at most MAX_NUM_IMAGES of them,
            False otherwise
    """
    image_count = len(message["files"])

    if image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images at a time.")
        return False

    # If there are files, check that they are all images
    if not all(
        file.lower().endswith((".png", ".jpg", ".jpeg")) for file in message["files"]
    ):
        gr.Warning("Only images are allowed. Formats available: PNG, JPG, JPEG")
        return False

    return True

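For reference, Gradio's MultimodalTextbox hands run() a dict of this shape, so a quick check of the validator could look like this (a sketch; the file path is illustrative and the environment above is assumed to be configured):

sample = {"text": "Describe this chart.", "files": ["assets/sample-images/barchart.png"]}
assert validate_media(sample)  # one PNG attachment passes both checks above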
def encode_pil_to_base64(pil_image: Image.Image, format: str) -> str:
    """Encode a PIL image as a base64 data URL.

    Args:
        pil_image (Image.Image): PIL image object
        format (str): format to save the image in ("JPEG" or "PNG")

    Returns:
        str: base64-encoded data URL of the image, or None on failure
    """
    buffered = io.BytesIO()

    # JPEG has no alpha channel, so flatten transparent modes to RGB first
    if format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
        pil_image = pil_image.convert("RGB")

    # Define save arguments, including quality for JPEG
    save_kwargs = {"format": format}
    if format == "JPEG":
        save_kwargs["quality"] = 85  # adjust quality as needed (0-100)

    try:
        pil_image.save(buffered, **save_kwargs)
    except Exception as e:
        logging.error(f"Error saving image to buffer with format {format}: {e}")
        return None

    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Determine the MIME type based on the format
    mime_format_part = format.lower()
    if mime_format_part == "jpeg":
        mime_type = "image/jpeg"
    elif mime_format_part == "png":
        mime_type = "image/png"
    else:
        gr.Error(f"Unsupported image format: {format}")
        return None

    return f"data:{mime_type};base64,{img_str}"

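A minimal round-trip check of the encoder (a sketch, not part of the app):

# Encode a tiny image, then decode the data URL back into a PIL image.
img = Image.new("RGB", (4, 4), color="red")
data_url = encode_pil_to_base64(img, format="PNG")
prefix, b64_payload = data_url.split(",", 1)
assert prefix == "data:image/png;base64"
decoded = Image.open(io.BytesIO(base64.b64decode(b64_payload)))
assert decoded.size == (4, 4)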
def process_images(message: list) -> list[dict]:
    """Convert the image files in a message into OpenAI-style image_url content parts.

    Args:
        message (list): list of file paths attached to the message

    Returns:
        list[dict]: list of dictionaries containing image content
    """
    content = []

    # Iterate through the files in the message
    for path in message:
        pil_image = Image.open(path)
        # Get the image format
        image_format = pil_image.format.upper()
        if image_format == "JPG":
            image_format = "JPEG"

        if image_format in ["JPEG", "PNG"]:
            # Convert the image to a base64 data URL; skip it if encoding failed
            base64_image_data = encode_pil_to_base64(pil_image, format=image_format)
            if base64_image_data:
                content.append(
                    {"type": "image_url", "image_url": {"url": base64_image_data}}
                )

    return content

def process_new_user_message(message: dict) -> list[dict]:
    """Process the new user message into a list of text and image content parts.

    Args:
        message (dict): message dictionary containing text and files

    Returns:
        list[dict]: list of dictionaries containing text and image content
    """
    # Create the content list for the message
    messages = []

    if message["text"]:
        # Append the text part to the content list
        messages.append({"type": "text", "text": message["text"]})

        if not message["files"]:
            # If there are no files, return the text part only
            return messages

        # If there are files, process the images and append them
        image_content = process_images(message["files"])
        messages.extend(image_content)
        return messages

    # If there is no text part, warn the user to insert a prompt and return nothing
    gr.Warning("Please insert a prompt.")
    return []

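The returned list follows the OpenAI-style multimodal content schema; for a message with one text part and one image, it looks roughly like this (the data URL is truncated for illustration):

# [
#     {"type": "text", "text": "Read the text in the image."},
#     {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
# ]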
def run(
    message: dict,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming responses and emit them to the UI.

    Args:
        message (dict): multimodal input message from the user
        chat_history (list): entire chat history of the session (managed by Gradio)
        system_prompt (str): system prompt
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the
            number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next-token probabilities.
            Defaults to 0.6.
        frequency_penalty (float, optional): penalizes tokens proportionally to how often they
            have already appeared. 0.0 means no penalty. Defaults to 0.0.
        presence_penalty (float, optional): penalizes tokens that have appeared at least once.
            0.0 means no penalty. Defaults to 0.0.

    Yields:
        Iterator[str]: streaming responses to the UI
    """
    if not validate_media(message):
        # If the attachments are not valid, emit an empty string and stop
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append(
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
        )

    # Append the new user message if processing returned any content
    content = process_new_user_message(message)
    if content:
        messages.append({"role": "user", "content": content})
    else:
        # If the content is empty, emit an empty string and stop
        yield ""
        return

    # Accumulate streamed chunks and yield the text generated so far
    outputs = []
    for text in request_generation(
        header=HEADER,
        messages=messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        yield "".join(outputs)

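Note the accumulate-then-yield pattern above: Gradio's ChatInterface treats each yielded value as the full response so far, not a delta, so every chunk is appended and the joined string is re-emitted.

# chunk stream: "Hel", "lo", "!"  ->  successive yields: "Hel", "Hello", "Hello!"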
examples = [
    ["Plan a three-day trip to Washington DC for the Cherry Blossom Festival."],
    ["How many hours does it take a man to eat a Helicopter?"],
    [
        {
            "text": "Write the matplotlib code to generate the same bar chart.",
            "files": ["assets/sample-images/barchart.png"],
        }
    ],
    [
        {
            "text": "Describe the atmosphere of the scene.",
            "files": ["assets/sample-images/06-1.png"],
        }
    ],
    [
        {
            "text": "Write a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Describe the creatures that would live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image.",
            "files": ["assets/sample-images/1.png"],
        }
    ],
    [
        {
            "text": "When is this ticket dated and how much did it cost?",
            "files": ["assets/sample-images/2.png"],
        }
    ],
    [
        {
            "text": "Transcribe the text in the image into markdown.",
            "files": ["assets/sample-images/3.png"],
        }
    ],
    [
        {
            "text": "Evaluate this integral.",
            "files": ["assets/sample-images/4.png"],
        }
    ],
    [
        {
            "text": "Caption this image.",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "What does the sign say?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare and contrast the two images.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List all the objects in the image and their colors.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
]

demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        file_count="single" if MAX_NUM_IMAGES == 1 else "multiple",
        autofocus=True,
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            # value="You are a highly capable AI assistant. Provide accurate, concise, and fact-based responses that are directly relevant to the user's query. Avoid speculation, ensure logical consistency, and maintain clarity in longer outputs.",
            value="",
            lines=3,
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=MAX_NEW_TOKENS,  # keep the default within the configured maximum
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Frequency penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
        gr.Slider(
            label="Presence penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
    ],
    stop_btn=False,
    title="Llama-4 Scout Instruct",
    description="This Space is an Alpha release that demonstrates the [Llama-4-Scout](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) model running on AMD MI300 infrastructure. Use of this Space is governed by the Meta Llama 4 [License](https://www.llama.com/llama4/license/). Feel free to play with it!",
    fill_height=True,
    run_examples_on_click=False,
    examples=examples,
    cache_examples=False,
)

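ChatInterface calls fn(message, history, *additional_inputs), so the widget order above must match run()'s signature: system_prompt, max_new_tokens, temperature, frequency_penalty, presence_penalty.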
if __name__ == "__main__":
    # QUEUE and CONCURRENCY_LIMIT must be set in the environment, like the variables above
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()
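For local testing, an environment along these lines would be needed before launching (all values are placeholders, not from this repo):

# export API_ENDPOINT="https://gateway.example.com/api/"   # placeholder endpoint
# export MODEL_NAME="llama-4-scout-17b-16e"                # placeholder model id
# export API_KEY="..."                                     # left elided on purpose
# export MAX_NUM_IMAGES=5 QUEUE=20 CONCURRENCY_LIMIT=4     # illustrative limits
# python app.py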
assets/sample-images/01.png ADDED (Git LFS)
assets/sample-images/02.png ADDED (Git LFS)
assets/sample-images/03.png ADDED (Git LFS)
assets/sample-images/04.png ADDED (Git LFS)
assets/sample-images/06-1.png ADDED (Git LFS)
assets/sample-images/08.png ADDED (Git LFS)
assets/sample-images/1.png ADDED (Git LFS)
assets/sample-images/10.png ADDED (Git LFS)
assets/sample-images/2.png ADDED (Git LFS)
assets/sample-images/3.png ADDED (Git LFS)
assets/sample-images/4.png ADDED (Git LFS)
assets/sample-images/barchart.png ADDED (Git LFS)
gateway.py
ADDED
@@ -0,0 +1,105 @@
import json
import logging

import requests
import urllib3

# The gateway is reached over HTTPS without certificate verification, so silence the warning
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Setup logging
logging.basicConfig(level=logging.INFO)


def request_generation(
    header: dict,
    messages: list,
    cloud_gateway_api: str,
    model_name: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.3,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
):
    """
    Request streaming generation from the cloud gateway API. Uses the requests module with
    stream=True to surface token-by-token generation from the LLM.

    Args:
        header: authorization header for the API.
        messages: list of chat messages (system and user turns) to send.
        cloud_gateway_api (str): API endpoint to send the request to.
        model_name (str): name of the model to serve the request.
        max_new_tokens: maximum number of tokens to generate, ignoring the number of tokens in
            the prompt.
        temperature: the value used to modulate the next-token probabilities.
        frequency_penalty: penalizes tokens proportionally to how often they have already
            appeared. 0.0 means no penalty.
        presence_penalty: penalizes tokens that have appeared at least once. 0.0 means no
            penalty.

    Yields:
        str: content chunks of the streamed model response.
    """

    payload = {
        "model": model_name,
        "messages": messages,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stream": True,  # Enable streaming
        "serving_runtime": "vllm",
    }

    try:
        # Create the conversation first; the gateway returns a conversation ID
        response = requests.post(
            cloud_gateway_api + "chat/conversation",
            headers=header,
            json=payload,
            verify=False,
        )
        logging.debug("Conversation response: %s", response.text)
        response.raise_for_status()

        # Append the conversation ID with the key X-Conversation-ID to the header
        header["X-Conversation-ID"] = response.json()["conversationId"]

        # Then stream the generated tokens for that conversation
        with requests.get(
            cloud_gateway_api + "conversation/stream",
            headers=header,
            verify=False,
            stream=True,
        ) as response:
            for chunk in response.iter_lines():
                if chunk:
                    # Convert the chunk from bytes to a string
                    chunk_str = chunk.decode("utf-8")

                    # Remove the `data: ` prefix (sometimes doubled) from the chunk
                    for _ in range(2):
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]

                    # Stop at the end-of-stream sentinel
                    if chunk_str.strip() == "[DONE]":
                        break

                    # Parse the chunk as JSON and extract the delta content
                    try:
                        chunk_json = json.loads(chunk_str)

                        if "choices" in chunk_json and chunk_json["choices"]:
                            content = chunk_json["choices"][0]["delta"].get("content", "")
                        else:
                            content = ""

                        # Yield the generated content as it streams in
                        if content:
                            yield content
                    except json.JSONDecodeError:
                        # Skip chunks that fail to decode
                        continue
    except requests.RequestException as e:
        logging.error(f"Failed to generate response: {e}")
        yield "Server not responding. Please try again later."
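A minimal standalone consumer of this generator might look like this (endpoint, key, and model name are placeholders):

if __name__ == "__main__":
    for piece in request_generation(
        header={"x-api-key": "YOUR_KEY"},  # placeholder key
        messages=[{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}],
        cloud_gateway_api="https://gateway.example.com/api/",  # placeholder endpoint
        model_name="llama-4-scout-17b-16e",  # placeholder model id
    ):
        print(piece, end="", flush=True)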
requirements.txt
ADDED
@@ -0,0 +1,4 @@
numpy
pillow
fastapi
websockets
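Note: gradio and requests are imported by app.py and gateway.py but are not pinned here; presumably the Hugging Face Spaces Gradio SDK image provides them.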
style.css
ADDED
@@ -0,0 +1,10 @@
h1 {
  text-align: center;
  display: block;
}

.contain {
  max-width: 900px;
  margin: auto;
  padding-top: 1.5rem;
}
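As committed, app.py never references style.css; for these rules to take effect, the stylesheet would presumably need to be passed to the interface, e.g. gr.ChatInterface(..., css=open("style.css").read()).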
|