import base64
import io
import logging
import os
import re
from typing import Iterator

import gradio as gr
from PIL import Image

from gateway import request_generation

logging.basicConfig(level=logging.INFO)

MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", "2048"))

_max_num_images = os.getenv("MAX_NUM_IMAGES")
if not _max_num_images:
    raise EnvironmentError("MAX_NUM_IMAGES is not set. Please set it to 1 or more.")
MAX_NUM_IMAGES: int = int(_max_num_images)

CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
    raise EnvironmentError("API_ENDPOINT is not set.")

MODEL_NAME: str = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise EnvironmentError("MODEL_NAME is not set.")

API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise EnvironmentError("API_KEY is not set.")

HEADER = {"x-api-key": API_KEY}
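
# Example environment configuration (illustrative values only):
#   API_ENDPOINT=https://<gateway-host>
#   API_KEY=<secret>
#   MODEL_NAME=Llama-4-Maverick-17B-128E-Instruct
#   MAX_NEW_TOKENS=2048
#   MAX_NUM_IMAGES=5
#   QUEUE=20
#   CONCURRENCY_LIMIT=2
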
def validate_media(message: dict, chat_history: list | None = None) -> bool:
    """Validate the number and format of image files in the new message.

    Args:
        message (dict): multimodal input from the user, with "text" and "files" keys
        chat_history (list, optional): entire chat history of the session (currently unused)

    Returns:
        bool: True if the message contains at most MAX_NUM_IMAGES images in an
            allowed format, False otherwise
    """
    # Count uploaded files plus any <image> URL tags embedded in the text.
    image_count = len(message["files"])
    image_count += message["text"].count("<image>")

    if image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images at a time.")
        return False

    if not all(
        file.lower().endswith((".png", ".jpg", ".jpeg")) for file in message["files"]
    ):
        gr.Warning("Only images are allowed. Available formats: PNG, JPG, JPEG.")
        return False
    return True
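
# For example, a message with three uploaded files and one <image>URL</image>
# tag in its text counts as four images against MAX_NUM_IMAGES.
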
def encode_pil_to_base64(pil_image: Image.Image, format: str) -> str | None:
    """Encode a PIL image as a base64 data URL.

    Args:
        pil_image (Image.Image): PIL image object
        format (str): format to save the image, either "JPEG" or "PNG"

    Returns:
        str | None: base64-encoded data URL of the image, or None on failure
    """
    buffered = io.BytesIO()

    # JPEG has no alpha channel, so flatten transparent/palette images first.
    if format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
        pil_image = pil_image.convert("RGB")

    save_kwargs = {"format": format}
    if format == "JPEG":
        save_kwargs["quality"] = 85

    try:
        pil_image.save(buffered, **save_kwargs)
    except Exception as e:
        logging.error(f"Error saving image to buffer with format {format}: {e}")
        return None

    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    mime_format_part = format.lower()
    if mime_format_part == "jpeg":
        mime_type = "image/jpeg"
    elif mime_format_part == "png":
        mime_type = "image/png"
    else:
        gr.Warning(f"Unsupported image format: {format}")
        return None

    return f"data:{mime_type};base64,{img_str}"
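
# For instance, encoding a PNG yields a data URL shaped like
# "data:image/png;base64,iVBORw0KGgo...", which the image_url entries below carry.
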
def process_images(file_paths: list) -> list[dict]:
    """Convert uploaded image files into image_url content entries.

    Args:
        file_paths (list): paths of the image files attached to the message

    Returns:
        list[dict]: list of image_url content dictionaries
    """
    content = []

    for path in file_paths:
        pil_image = Image.open(path)

        # PIL normally reports "JPEG" for .jpg files, but normalize just in case.
        image_format = (pil_image.format or "").upper()
        if image_format == "JPG":
            image_format = "JPEG"

        if image_format in ["JPEG", "PNG"]:
            base64_image_data = encode_pil_to_base64(pil_image, format=image_format)
            if base64_image_data:
                content.append(
                    {"type": "image_url", "image_url": {"url": base64_image_data}}
                )

    return content
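
# Two valid uploads therefore yield two entries of the form
# {"type": "image_url", "image_url": {"url": "data:image/...;base64,..."}}.
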
def extract_image_urls_from_tags(message: str) -> list[str]:
    """Extract image URLs from the <image> tags in the message text.

    Args:
        message (str): message text containing <image>URL</image> tags

    Returns:
        list[str]: list of image URLs extracted from the <image> tags
    """
    image_urls = re.findall(r"<image>(.*?)</image>", message, re.IGNORECASE | re.DOTALL)
    return [url.strip() for url in image_urls]
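
# Example: extract_image_urls_from_tags("See <image> https://example.com/cat.png </image>")
# returns ["https://example.com/cat.png"] after stripping surrounding whitespace.
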
def process_new_user_message(message: dict) -> list[dict]:
    """Process the new user message into a list of text and image content parts.

    Args:
        message (dict): message dictionary containing "text" and "files" keys

    Returns:
        list[dict]: list of dictionaries containing text and image content
    """
    messages = []

    if not message["text"]:
        gr.Warning("Please insert a prompt.")
        return []

    # Strip <image>...</image> tags so only the plain prompt text remains.
    prompt = re.sub(
        r"<image>.*?</image>", "", message["text"], flags=re.DOTALL | re.IGNORECASE
    ).strip()

    if not prompt:
        gr.Warning("Please insert a prompt.")
        return []

    messages.append({"type": "text", "text": prompt})

    # Attach remote images referenced via <image>URL</image> tags, keeping only
    # well-formed HTTP(S) URLs.
    image_urls = extract_image_urls_from_tags(message["text"])
    for url in image_urls:
        if not url or not url.lower().startswith(("http://", "https://")):
            continue
        messages.append({"type": "image_url", "image_url": {"url": url}})

    # Attach uploaded image files, then return the combined content; returning
    # unconditionally here also covers text-only messages.
    if message["files"]:
        messages.extend(process_images(message["files"]))

    return messages
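
# Illustrative result for a prompt plus one uploaded PNG:
# [{"type": "text", "text": "Caption this image."},
#  {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]
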
def run(
    message: dict,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming responses, and emit them to the UI.

    Args:
        message (dict): multimodal input from the user
        chat_history (list): entire chat history of the session
        system_prompt (str): system prompt
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the
            number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next-token
            probabilities. Defaults to 0.6.
        frequency_penalty (float, optional): penalizes new tokens based on how often they
            have appeared in the text so far. 0.0 means no penalty. Defaults to 0.0.
        presence_penalty (float, optional): penalizes new tokens based on whether they
            have appeared in the text so far. 0.0 means no penalty. Defaults to 0.0.

    Yields:
        Iterator[str]: streaming responses to the UI
    """
    if not validate_media(message):
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append(
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
        )

    content = process_new_user_message(message)
    if content:
        messages.append({"role": "user", "content": content})
    else:
        yield ""
        return

    # Accumulate streamed chunks and re-emit the growing response each time,
    # so the UI always shows the full text generated so far.
    outputs = []
    for text in request_generation(
        header=HEADER,
        messages=messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        yield "".join(outputs)
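
# For example, if the gateway streams the chunks "Hel" and "lo", the UI receives
# "Hel" and then "Hello".
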
examples = [
    ["Plan a three-day trip to Washington DC for the Cherry Blossom Festival."],
    ["How many hours does it take a man to eat a helicopter?"],
    [
        {
            "text": "Write the matplotlib code to generate the same bar chart.",
            "files": ["assets/sample-images/barchart.png"],
        }
    ],
    [
        {
            "text": "Describe the atmosphere of the scene.",
            "files": ["assets/sample-images/06-1.png"],
        }
    ],
    [
        {
            "text": "Write a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Describe the creatures that would live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image.",
            "files": ["assets/sample-images/1.png"],
        }
    ],
    [
        {
            "text": "When is this ticket dated and how much did it cost?",
            "files": ["assets/sample-images/2.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image into markdown.",
            "files": ["assets/sample-images/3.png"],
        }
    ],
    [
        {
            "text": "Evaluate this integral.",
            "files": ["assets/sample-images/4.png"],
        }
    ],
    [
        {
            "text": "Caption this image.",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "What does the sign say?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare and contrast the two images.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List all the objects in the image and their colors.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
]

description = """
This Space is an Alpha release demonstrating the [Llama-4-Maverick](https://huggingface.co./meta-llama/Llama-4-Maverick-17B-128E-Instruct) model running on AMD MI300 infrastructure. Use of this Space is governed by the Meta Llama 4 [License](https://www.llama.com/llama4/license/). Feel free to play with it!
"""

demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        file_count="single" if MAX_NUM_IMAGES == 1 else "multiple",
        autofocus=True,
        placeholder="Type a message, drop a PNG/JPEG, or use <image>URL</image>...",
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            value="",
            lines=3,
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            # Default to the configured ceiling so the value never exceeds the maximum.
            value=MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Frequency penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
        gr.Slider(
            label="Presence penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
    ],
    stop_btn=False,
    title="Llama-4 Maverick Instruct",
    description=description,
    fill_height=True,
    run_examples_on_click=False,
    examples=examples,
    css_paths="style.css",
    cache_examples=False,
)

if __name__ == "__main__":
    # Fail fast, matching the env-var checks at module load time.
    queue_size = os.getenv("QUEUE")
    concurrency_limit = os.getenv("CONCURRENCY_LIMIT")
    if not queue_size or not concurrency_limit:
        raise EnvironmentError("QUEUE and CONCURRENCY_LIMIT are not set.")
    demo.queue(
        max_size=int(queue_size),
        default_concurrency_limit=int(concurrency_limit),
    ).launch()