Lohia, Aditya committed
Commit 4b91514 · 1 Parent(s): 1ccd3bb
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
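With this pattern added, newly committed *.png files (including the sample images below) are stored via Git LFS rather than directly in the Git history.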
app.py ADDED
@@ -0,0 +1,364 @@
+ import os
+ import logging
+ import base64
+ import io
+ import gradio as gr
+ from PIL import Image
+ from typing import Iterator
+
+ from gateway import request_generation
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+
+ # CONSTANTS
+ # Maximum number of new tokens to generate; defaults to 2048 if the variable is unset
+ MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", 2048))
+
+ # Maximum number of images that can be attached to a single prompt
+ _max_num_images = os.getenv("MAX_NUM_IMAGES")
+ if not _max_num_images:
+     raise EnvironmentError("MAX_NUM_IMAGES is not set. Please set it to 1 or more.")
+ MAX_NUM_IMAGES: int = int(_max_num_images)
+
+ # Validate environment variables
+ CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
+ if not CLOUD_GATEWAY_API:
+     raise EnvironmentError("API_ENDPOINT is not set.")
+
+ MODEL_NAME: str = os.getenv("MODEL_NAME")
+ if not MODEL_NAME:
+     raise EnvironmentError("MODEL_NAME is not set.")
+
+ # Get API Key
+ API_KEY = os.getenv("API_KEY")
+ if not API_KEY:  # simple presence check for the API key
+     raise EnvironmentError("API_KEY is not set.")
+
+ # Create the header once, avoid declaring it multiple times
+ HEADER = {"x-api-key": API_KEY}
+
+
+ def validate_media(message: dict) -> bool:
+     """Validate the image files attached to the new message.
+
+     Args:
+         message (dict): multimodal user message with "text" and "files" keys
+
+     Returns:
+         bool: True if every attachment is a supported image and the count is at most
+             MAX_NUM_IMAGES, False otherwise
+     """
+     image_count = len(message["files"])
+
+     if image_count > MAX_NUM_IMAGES:
+         gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images at a time.")
+         return False
+
+     # If there are files, check that all of them are images
+     if not all(
+         file.lower().endswith((".png", ".jpg", ".jpeg")) for file in message["files"]
+     ):
+         gr.Warning("Only images are allowed. Formats available: PNG, JPG, JPEG")
+         return False
+
+     return True
+
+
+ def encode_pil_to_base64(pil_image: Image.Image, format: str) -> str:
+     """Encode a PIL image as a base64 data URL.
+
+     Args:
+         pil_image (Image.Image): PIL image object
+         format (str): format to save the image in ("JPEG" or "PNG")
+
+     Returns:
+         str: base64-encoded data URL of the image, or None if encoding fails
+     """
+     buffered = io.BytesIO()
+
+     # JPEG has no alpha channel, so flatten transparent modes to RGB first
+     if format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
+         pil_image = pil_image.convert("RGB")
+
+     # Define save arguments, including quality for JPEG
+     save_kwargs = {"format": format}
+     if format == "JPEG":
+         save_kwargs["quality"] = 85  # adjust quality as needed (0-100)
+
+     try:
+         pil_image.save(buffered, **save_kwargs)
+     except Exception as e:
+         logging.error(f"Error saving image to buffer with format {format}: {e}")
+         return None
+
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+     # Determine the MIME type based on the format
+     mime_format_part = format.lower()
+     if mime_format_part == "jpeg":
+         mime_type = "image/jpeg"
+     elif mime_format_part == "png":
+         mime_type = "image/png"
+     else:
+         raise gr.Error(f"Unsupported image format: {format}")
+
+     return f"data:{mime_type};base64,{img_str}"
+
+
+ def process_images(files: list) -> list[dict]:
+     """Convert the image files attached to the message into base64 content parts.
+
+     Args:
+         files (list): paths of the files attached to the message
+
+     Returns:
+         list[dict]: list of image_url content dictionaries
+     """
+     content = []
+
+     # Iterate through the files in the message
+     for path in files:
+         pil_image = Image.open(path)
+         # Get the image format; PIL normally reports "JPEG", but guard against "JPG"
+         image_format = pil_image.format.upper()
+         if image_format == "JPG":
+             image_format = "JPEG"
+
+         if image_format in ["JPEG", "PNG"]:
+             # Convert the image to a base64 data URL
+             base64_image_data = encode_pil_to_base64(pil_image, format=image_format)
+             content.append(
+                 {"type": "image_url", "image_url": {"url": base64_image_data}}
+             )
+
+     return content
+
+
+ def process_new_user_message(message: dict) -> list[dict]:
+     """Process the new user message into a list of text and image content parts.
+
+     Args:
+         message (dict): message dictionary containing text and files
+
+     Returns:
+         list[dict]: list of dictionaries containing text and image content
+     """
+     if not message["text"]:
+         # Without a text prompt there is nothing to send; warn and return nothing
+         gr.Warning("Please insert a prompt.")
+         return []
+
+     # Start the content list with the text part
+     messages = [{"type": "text", "text": message["text"]}]
+
+     if message["files"]:
+         # If there are files, append the processed image content
+         messages.extend(process_images(message["files"]))
+
+     return messages
+
+
+ def run(
+     message: dict,
+     chat_history: list,
+     system_prompt: str,
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     frequency_penalty: float = 0.0,
+     presence_penalty: float = 0.0,
+ ) -> Iterator[str]:
+     """Send a request to the backend, fetch the streaming responses and emit them to the UI.
+
+     Args:
+         message (dict): multimodal input message from the user
+         chat_history (list): entire chat history of the session
+         system_prompt (str): system prompt
+         max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the
+             number of tokens in the prompt. Defaults to 1024.
+         temperature (float, optional): the value used to modulate the next-token
+             probabilities. Defaults to 0.6.
+         frequency_penalty (float, optional): penalizes new tokens based on how often they
+             have appeared in the text so far. Defaults to 0.0.
+         presence_penalty (float, optional): penalizes new tokens that have already appeared
+             in the text at least once. Defaults to 0.0.
+
+     Yields:
+         Iterator[str]: streaming responses to the UI
+     """
+     if not validate_media(message):
+         # If the attachments are not valid, emit an empty string and stop
+         yield ""
+         return
+
+     messages = []
+     if system_prompt:
+         messages.append(
+             {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
+         )
+
+     # Append the new user message if processing produced any content
+     content = process_new_user_message(message)
+     if content:
+         messages.append({"role": "user", "content": content})
+     else:
+         # If the content is empty, emit an empty string and stop
+         yield ""
+         return
+
+     # Accumulate streamed chunks and yield the running response text
+     outputs = []
+     for text in request_generation(
+         header=HEADER,
+         messages=messages,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         presence_penalty=presence_penalty,
+         frequency_penalty=frequency_penalty,
+         cloud_gateway_api=CLOUD_GATEWAY_API,
+         model_name=MODEL_NAME,
+     ):
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ examples = [
+     ["Plan a three-day trip to Washington DC for the Cherry Blossom Festival."],
+     ["How many hours does it take a man to eat a helicopter?"],
+     [
+         {
+             "text": "Write the matplotlib code to generate the same bar chart.",
+             "files": ["assets/sample-images/barchart.png"],
+         }
+     ],
+     [
+         {
+             "text": "Describe the atmosphere of the scene.",
+             "files": ["assets/sample-images/06-1.png"],
+         }
+     ],
+     [
+         {
+             "text": "Write a short story about what might have happened in this house.",
+             "files": ["assets/sample-images/08.png"],
+         }
+     ],
+     [
+         {
+             "text": "Describe the creatures that would live in this world.",
+             "files": ["assets/sample-images/10.png"],
+         }
+     ],
+     [
+         {
+             "text": "Read the text in the image.",
+             "files": ["assets/sample-images/1.png"],
+         }
+     ],
+     [
+         {
+             "text": "When is this ticket dated and how much did it cost?",
+             "files": ["assets/sample-images/2.png"],
+         }
+     ],
+     [
+         {
+             "text": "Read the text in the image into markdown.",
+             "files": ["assets/sample-images/3.png"],
+         }
+     ],
+     [
+         {
+             "text": "Evaluate this integral.",
+             "files": ["assets/sample-images/4.png"],
+         }
+     ],
+     [
+         {
+             "text": "Caption this image.",
+             "files": ["assets/sample-images/01.png"],
+         }
+     ],
+     [
+         {
+             "text": "What does the sign say?",
+             "files": ["assets/sample-images/02.png"],
+         }
+     ],
+     [
+         {
+             "text": "Compare and contrast the two images.",
+             "files": ["assets/sample-images/03.png"],
+         }
+     ],
+     [
+         {
+             "text": "List all the objects in the image and their colors.",
+             "files": ["assets/sample-images/04.png"],
+         }
+     ],
+ ]
+
+
+ demo = gr.ChatInterface(
+     fn=run,
+     type="messages",
+     chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
+     textbox=gr.MultimodalTextbox(
+         file_types=["image"],
+         file_count="single" if MAX_NUM_IMAGES == 1 else "multiple",
+         autofocus=True,
+     ),
+     multimodal=True,
+     additional_inputs=[
+         gr.Textbox(
+             label="System prompt",
+             # value="You are a highly capable AI assistant. Provide accurate, concise, and fact-based responses that are directly relevant to the user's query. Avoid speculation, ensure logical consistency, and maintain clarity in longer outputs.",
+             value="",
+             lines=3,
+         ),
+         gr.Slider(
+             label="Max New Tokens",
+             minimum=1,
+             maximum=MAX_NEW_TOKENS,
+             step=1,
+             value=MAX_NEW_TOKENS,  # default to the configured maximum
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.3,
+         ),
+         gr.Slider(
+             label="Frequency penalty",
+             minimum=-2.0,
+             maximum=2.0,
+             step=0.1,
+             value=0.0,
+         ),
+         gr.Slider(
+             label="Presence penalty",
+             minimum=-2.0,
+             maximum=2.0,
+             step=0.1,
+             value=0.0,
+         ),
+     ],
+     stop_btn=False,
+     title="Llama-4 Scout Instruct",
+     description="This Space is an alpha release demonstrating the [Llama-4-Scout](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E) model running on AMD MI300 infrastructure. The Space is governed by the Meta Llama 4 [License](https://www.llama.com/llama4/license/). Feel free to play with it!",
+     fill_height=True,
+     run_examples_on_click=False,
+     examples=examples,
+     cache_examples=False,
+ )
+
+
+ if __name__ == "__main__":
+     demo.queue(
+         # The fallback values here are assumed defaults for when the variables are unset
+         max_size=int(os.getenv("QUEUE", 20)),
+         default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT", 2)),
+     ).launch()
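
For local testing, the sketch below shows one way to supply the required configuration before importing the app. It is a minimal sketch: every value is a placeholder, and API_ENDPOINT is assumed to need a trailing slash because gateway.py appends "chat/conversation" and "conversation/stream" directly to it.

import os

# Placeholder values for illustration only; the real endpoint, key, and model name are private.
os.environ.setdefault("API_ENDPOINT", "https://<gateway-host>/")  # trailing slash assumed
os.environ.setdefault("API_KEY", "<your-api-key>")
os.environ.setdefault("MODEL_NAME", "<served-model-name>")
os.environ.setdefault("MAX_NUM_IMAGES", "4")
os.environ.setdefault("MAX_NEW_TOKENS", "2048")
os.environ.setdefault("QUEUE", "20")
os.environ.setdefault("CONCURRENCY_LIMIT", "2")

import app  # app.py validates these variables at import time

app.demo.launch()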
assets/sample-images/01.png ADDED

Git LFS Details

  • SHA256: 0b849f4cc108d58d0de9ec4707426ed1fe8fe276d90f72d56feed624e830c2b5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
assets/sample-images/02.png ADDED

Git LFS Details

  • SHA256: 983d5cd47f4a88924c8619c5a9ecbaa374e766847627d30ba1d4dc9e3c556255
  • Pointer size: 131 Bytes
  • Size of remote file: 621 kB
assets/sample-images/03.png ADDED

Git LFS Details

  • SHA256: ba2380dc16996f688760dd9f62ecfbc8b6abcb785cf667ede031c6c843cf8cfd
  • Pointer size: 132 Bytes
  • Size of remote file: 2.52 MB
assets/sample-images/04.png ADDED

Git LFS Details

  • SHA256: 8c096c6ac54439e68a5e77ed2991989970f5e522bca36136f438391d313b02d0
  • Pointer size: 132 Bytes
  • Size of remote file: 2.5 MB
assets/sample-images/06-1.png ADDED

Git LFS Details

  • SHA256: 4756c1c49975e926390b87c08b9672a51b09bc7c41c33f20a9fd82f999c26ac4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.4 MB
assets/sample-images/08.png ADDED

Git LFS Details

  • SHA256: 637f90b5941cb8eb3cc1eff20092bdb07d0fcfbd6e5bf881b3951d421089594e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.63 MB
assets/sample-images/1.png ADDED

Git LFS Details

  • SHA256: b498566d553d7f24d2a08fe65a7c52acca801fde4f155f8a3ba0b91e924044e9
  • Pointer size: 133 Bytes
  • Size of remote file: 12.4 MB
assets/sample-images/10.png ADDED

Git LFS Details

  • SHA256: e9ea637b8a7d50696ea85e608d6d9fe485958cf7aad52504ac27798c7e8b3d8f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.7 MB
assets/sample-images/2.png ADDED

Git LFS Details

  • SHA256: df8a2ae6d36b0bda173b290123dc92385c7b60dbe63157f0a94dc865f2f79dd8
  • Pointer size: 132 Bytes
  • Size of remote file: 5.72 MB
assets/sample-images/3.png ADDED

Git LFS Details

  • SHA256: 03b341cb6773365b852f9614e4493aefe14e208066f507499598ce498e02c0b2
  • Pointer size: 133 Bytes
  • Size of remote file: 20.5 MB
assets/sample-images/4.png ADDED

Git LFS Details

  • SHA256: f40af3f85ab1524d5604f186ebbf102905c1b14ae631b82a092ec397f54eae7a
  • Pointer size: 129 Bytes
  • Size of remote file: 6.05 kB
assets/sample-images/barchart.png ADDED

Git LFS Details

  • SHA256: 4c83100dc7880913cecc96efcf7557c0f748c3b9d49e6f219e06f90e0a448847
  • Pointer size: 130 Bytes
  • Size of remote file: 74.2 kB
gateway.py ADDED
@@ -0,0 +1,105 @@
+ import json
+ import logging
+ import requests
+ import urllib3
+
+ # The requests below use verify=False, so silence the insecure-request warnings
+ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+
+
+ def request_generation(
+     header: dict,
+     messages: list,
+     cloud_gateway_api: str,
+     model_name: str,
+     max_new_tokens: int = 1024,
+     temperature: float = 0.3,
+     frequency_penalty: float = 0.0,
+     presence_penalty: float = 0.0,
+ ):
+     """
+     Request streaming generation from the cloud gateway API. Uses the requests module with
+     stream=True to receive token-by-token output from the LLM.
+
+     Args:
+         header (dict): authorization header for the API.
+         messages (list): list of role/content message dictionaries to send.
+         cloud_gateway_api (str): API endpoint to send the request to.
+         model_name (str): name of the model served behind the gateway.
+         max_new_tokens (int): maximum number of tokens to generate, ignoring the number of
+             tokens in the prompt.
+         temperature (float): the value used to modulate the next-token probabilities.
+         frequency_penalty (float): penalizes new tokens based on how often they have appeared
+             in the text so far.
+         presence_penalty (float): penalizes new tokens that have already appeared in the text
+             at least once.
+
+     Yields:
+         str: content chunks streamed back from the model.
+     """
+
+     payload = {
+         "model": model_name,
+         "messages": messages,
+         "max_tokens": max_new_tokens,
+         "temperature": temperature,
+         "frequency_penalty": frequency_penalty,
+         "presence_penalty": presence_penalty,
+         "stream": True,  # Enable streaming
+         "serving_runtime": "vllm",
+     }
+
+     try:
+         # First create a conversation, then stream its responses
+         response = requests.post(
+             cloud_gateway_api + "chat/conversation",
+             headers=header,
+             json=payload,
+             verify=False,
+         )
+         logging.debug(response.text)  # log the raw handshake response for debugging
+         response.raise_for_status()
+
+         # Append the conversation ID with the key X-Conversation-ID to the header
+         header["X-Conversation-ID"] = response.json()["conversationId"]
+
+         with requests.get(
+             cloud_gateway_api + "conversation/stream",
+             headers=header,
+             verify=False,
+             stream=True,
+         ) as response:
+             for chunk in response.iter_lines():
+                 if chunk:
+                     # Convert the chunk from bytes to a string
+                     chunk_str = chunk.decode("utf-8")
+
+                     # Remove the "data: " prefix if present (the loop handles a doubled prefix)
+                     for _ in range(2):
+                         if chunk_str.startswith("data: "):
+                             chunk_str = chunk_str[len("data: ") :]
+
+                     # Stop at the end-of-stream sentinel
+                     if chunk_str.strip() == "[DONE]":
+                         break
+
+                     # Parse the chunk into a JSON object
+                     try:
+                         chunk_json = json.loads(chunk_str)
+
+                         # Extract the "content" field from the choices
+                         if "choices" in chunk_json and chunk_json["choices"]:
+                             content = chunk_json["choices"][0]["delta"].get(
+                                 "content", ""
+                             )
+                         else:
+                             content = ""
+
+                         # Yield the generated content as it streams in
+                         if content:
+                             yield content
+                     except json.JSONDecodeError:
+                         # Skip chunks that cannot be decoded as JSON
+                         continue
+     except requests.RequestException as e:
+         logging.error(f"Failed to generate response: {e}")
+         yield "Server not responding. Please try again later."
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ numpy
+ pillow
+ fastapi
+ websockets
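Note that app.py imports gradio and gateway.py imports requests; neither is pinned here, presumably because both are provided by the Hugging Face Spaces base image.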
style.css ADDED
@@ -0,0 +1,10 @@
+ h1 {
+   text-align: center;
+   display: block;
+ }
+
+ .contain {
+   max-width: 900px;
+   margin: auto;
+   padding-top: 1.5rem;
+ }