import os
import logging
import re
import gradio as gr
import base64
import io
from PIL import Image
from typing import Iterator
from gateway import request_generation
# Setup logging
logging.basicConfig(level=logging.INFO)
# CONSTANTS
# Get max new tokens from environment variable, if it is not set, default to 2048
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", 2048))
# Get the maximum number of images allowed in a single prompt
_max_num_images = os.getenv("MAX_NUM_IMAGES")
if not _max_num_images:
    raise EnvironmentError("MAX_NUM_IMAGES is not set. Please set it to 1 or more.")
MAX_NUM_IMAGES: int = int(_max_num_images)
# Validate environment variables
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
raise EnvironmentError("API_ENDPOINT is not set.")
MODEL_NAME: str = os.getenv("MODEL_NAME")
if not MODEL_NAME:
raise EnvironmentError("MODEL_NAME is not set.")
# Get API Key
API_KEY = os.getenv("API_KEY")
if not API_KEY:  # presence check only; this does not verify the key with the backend
    raise EnvironmentError("API_KEY is not set.")
# Build the auth header once and reuse it for every request
HEADER = {"x-api-key": API_KEY}
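# For reference, a hypothetical local run might export the variables this module reads.
# All values below are placeholders, not real endpoints or keys:
#   export API_ENDPOINT="https://my-gateway.example.com"
#   export MODEL_NAME="my-model-name"
#   export API_KEY="..."
#   export MAX_NUM_IMAGES="4"
#   export MAX_NEW_TOKENS="2048"   # optional, defaults to 2048
#   export QUEUE="10"              # used by demo.queue() at the bottom of this file
#   export CONCURRENCY_LIMIT="2"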
def validate_media(message: dict, chat_history: list = None) -> bool:
    """Validate the number and type of image files in the new message.
    Args:
        message (dict): multimodal input from the user, with "text" and "files" keys
        chat_history (list, optional): chat history of the session (currently unused)
    Returns:
        bool: True if the message contains at most MAX_NUM_IMAGES images and all
            uploaded files are PNG/JPEG, False otherwise
    """
    image_count = len(message["files"])
# Check if there are <image> tags in the prompt and add count
image_count += message["text"].count("<image>")
if image_count > MAX_NUM_IMAGES:
gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images at a time.")
return False
# If there are files, check if they are images
if not all(
file.lower().endswith((".png", ".jpg", ".jpeg")) for file in message["files"]
):
gr.Warning("Only images are allowed. Format available: PNG, JPG, JPEG")
return False
return True
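# A rough sketch of the input validate_media expects: the dict shape produced by
# gr.MultimodalTextbox, inferred from the key accesses above. Kept as comments so
# nothing runs at import time; the path and URL are hypothetical placeholders.
#   example_message = {
#       "text": "Compare these plots. <image>https://example.com/extra-plot.png</image>",
#       "files": ["assets/sample-images/barchart.png"],
#   }
#   validate_media(example_message)  # True when 2 <= MAX_NUM_IMAGES, else False plus a gr.Warning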
def encode_pil_to_base64(pil_image: Image.Image, format: str) -> str:
    """Encode a PIL image as a base64 data URI.
    Args:
        pil_image (Image.Image): PIL image object
        format (str): format to save the image, either "JPEG" or "PNG"
    Returns:
        str: data URI containing the base64-encoded image, or None on failure
    """
buffered = io.BytesIO()
# Handle potential transparency issues for JPEG or JPG
if format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
# Convert to RGB
pil_image = pil_image.convert("RGB")
# Define save arguments, including quality for JPEG
save_kwargs = {"format": format}
if format == "JPEG":
save_kwargs["quality"] = 85 # Adjust quality as needed (0-100)
    try:
        pil_image.save(buffered, **save_kwargs)
    except Exception as e:
        logging.error(f"Error saving image to buffer with format {format}: {e}")
        return None
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
# Determine the MIME type based on the format
mime_format_part = format.lower()
if mime_format_part == "jpeg":
mime_type = "image/jpeg"
elif mime_format_part == "png":
mime_type = "image/png"
    else:
        # gr.Error is an exception and would have to be raised; use gr.Warning so the
        # message still reaches the UI while None is returned to the caller
        gr.Warning(f"Unsupported image format: {format}")
        return None
return f"data:{mime_type};base64,{img_str}"
def process_images(files: list) -> list[dict]:
    """Convert uploaded image files into base64 "image_url" content blocks.
    Args:
        files (list): paths of the image files attached to the message
    Returns:
        list[dict]: list of "image_url" content dictionaries
    """
    content = []
    # Iterate through the uploaded files
    for path in files:
        pil_image = Image.open(path)
# Get the image format
image_format = pil_image.format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format in ["JPEG", "PNG"]:
# Converting image to base64
base64_image_data = encode_pil_to_base64(pil_image, format=image_format)
content.append(
{"type": "image_url", "image_url": {"url": base64_image_data}}
)
return content
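# For illustration, each entry appended above has the OpenAI-style multimodal shape:
#   {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
# Files in formats other than PNG/JPEG are silently skipped here; validate_media should
# have rejected them before this point.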
def extract_image_urls_from_tags(message: str) -> list[str]:
"""Extract image URLs from the <image> tags in the message text.
Args:
message (str): message text containing <image> tags
Returns:
list[str]: list of image URLs extracted from the <image> tags
"""
# Extract all <image> tags from the message text using regex
image_urls = re.findall(r"<image>(.*?)</image>", message, re.IGNORECASE | re.DOTALL)
# Basic cleanup: strip whitespace from found URLs
image_urls = [url.strip() for url in image_urls]
return image_urls
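# Rough sketch of the tag convention this helper parses (comments only; the URL is a
# hypothetical placeholder):
#   extract_image_urls_from_tags("Describe this. <image>https://example.com/cat.png</image>")
#   # -> ["https://example.com/cat.png"]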
def process_new_user_message(message: dict) -> list[dict]:
"""Process the new user message and return a list of dictionaries containing text and image content.
Args:
message (dict): message dictionary containing text and files
Returns:
list[dict]: list of dictionaries containing text and image content
"""
    # Build the list of content blocks (text + images) for this user turn
    messages = []
if message["text"]:
# Remove the <image> tags from the message text
prompt = re.sub(
r"<image>.*?</image>", "", message["text"], flags=re.DOTALL | re.IGNORECASE
).strip()
# If the message text is empty after removing <image> tags, return an empty list
if not prompt:
gr.Warning("Please insert a prompt.")
return []
# If the message text is not empty, append it to the content list
messages.append({"type": "text", "text": prompt})
        # Process image URLs embedded in <image> tags
image_urls = extract_image_urls_from_tags(message["text"])
for url in image_urls:
if not url or not url.lower().startswith(("http://", "https://")):
continue
# Append the image URL to the content list
messages.append({"type": "image_url", "image_url": {"url": url}})
if message["files"]:
# If there are files, process the images
image_content = process_images(message["files"])
# Append the image content to the messages list
messages.extend(image_content)
return messages
else:
        # No text was provided: warn the user to insert a prompt and return an empty list
gr.Warning("Please insert a prompt.")
return []
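# For illustration, a message with text, one <image> URL tag, and one uploaded PNG yields a
# content list roughly like the following (values are hypothetical placeholders):
#   [
#       {"type": "text", "text": "Compare these plots."},
#       {"type": "image_url", "image_url": {"url": "https://example.com/extra-plot.png"}},
#       {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#   ]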
def run(
    message: dict,
chat_history: list,
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
) -> Iterator[str]:
"""Send a request to backend, fetch the streaming responses and emit to the UI.
Args:
message (str): input message from the user
chat_history (list[tuple[str, str]]): entire chat history of the session
system_prompt (str): system prompt
max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the number of tokens in the
prompt. Defaults to 1024.
temperature (float, optional): the value used to module the next token probabilities. Defaults to 0.6.
top_p (float, optional): if set to float<1, only the smallest set of most probable tokens with probabilities
that add up to top_p or higher are kept for generation. Defaults to 0.9.
top_k (int, optional): the number of highest probability vocabulary tokens to keep for top-k-filtering.
Defaults to 50.
repetition_penalty (float, optional): the parameter for repetition penalty. 1.0 means no penalty.
Defaults to 1.2.
Yields:
Iterator[str]: Streaming responses to the UI
"""
if not validate_media(message):
# If the number of image files is not valid, return an empty string
yield ""
return
messages = []
if system_prompt:
messages.append(
{"role": "system", "content": [{"type": "text", "text": system_prompt}]}
)
    # Build the user content; process_new_user_message returns an empty list on invalid input
content = process_new_user_message(message)
if content:
# Append the new user message to the messages list
messages.append({"role": "user", "content": content})
else:
        # If the content is empty, yield an empty string and stop
yield ""
return
    # Stream partial outputs from the gateway and emit the accumulated text to the UI
    outputs = []
for text in request_generation(
header=HEADER,
messages=messages,
max_new_tokens=max_new_tokens,
temperature=temperature,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
cloud_gateway_api=CLOUD_GATEWAY_API,
model_name=MODEL_NAME,
):
outputs.append(text)
yield "".join(outputs)
examples = [
["Plan a three-day trip to Washington DC for Cherry Blossom Festival."],
["How many hours does it take a man to eat a Helicopter?"],
[
{
"text": "Write the matplotlib code to generate the same bar chart.",
"files": ["assets/sample-images/barchart.png"],
}
],
[
{
"text": "Describe the atmosphere of the scene.",
"files": ["assets/sample-images/06-1.png"],
}
],
[
{
"text": "Write a short story about what might have happened in this house.",
"files": ["assets/sample-images/08.png"],
}
],
[
{
"text": "Describe the creatures that would live in this world.",
"files": ["assets/sample-images/10.png"],
}
],
[
{
"text": "Read text in the image.",
"files": ["assets/sample-images/1.png"],
}
],
[
{
"text": "When is this ticket dated and how much did it cost?",
"files": ["assets/sample-images/2.png"],
}
],
[
{
"text": "Read the text in the image into markdown.",
"files": ["assets/sample-images/3.png"],
}
],
[
{
"text": "Evaluate this integral.",
"files": ["assets/sample-images/4.png"],
}
],
[
{
"text": "Caption this image",
"files": ["assets/sample-images/01.png"],
}
],
[
{
"text": "What's the sign says?",
"files": ["assets/sample-images/02.png"],
}
],
[
{
"text": "Compare and contrast the two images.",
"files": ["assets/sample-images/03.png"],
}
],
[
{
"text": "List all the objects in the image and their colors.",
"files": ["assets/sample-images/04.png"],
}
],
]
description = """
This Space is an Alpha release that demonstrates the [Llama-4-Maverick](https://huggingface.co./meta-llama/Llama-4-Maverick-17B-128E-Instruct) model running on AMD MI300 infrastructure. The Space is governed by the Meta Llama 4 [License](https://www.llama.com/llama4/license/). Feel free to play with it!
"""
demo = gr.ChatInterface(
fn=run,
type="messages",
chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
textbox=gr.MultimodalTextbox(
file_types=["image"],
file_count="single" if MAX_NUM_IMAGES == 1 else "multiple",
autofocus=True,
placeholder="Type message, drop PNG/JPEG or use <image>URL</image>...",
),
multimodal=True,
additional_inputs=[
gr.Textbox(
label="System prompt",
# value="You are a highly capable AI assistant. Provide accurate, concise, and fact-based responses that are directly relevant to the user's query. Avoid speculation, ensure logical consistency, and maintain clarity in longer outputs.",
value="",
lines=3,
),
gr.Slider(
label="Max New Tokens",
minimum=1,
maximum=MAX_NEW_TOKENS,
step=1,
            value=min(2048, MAX_NEW_TOKENS),
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.3,
),
gr.Slider(
label="Frequency penalty",
minimum=-2.0,
maximum=2.0,
step=0.1,
value=0.0,
),
gr.Slider(
label="Presence penalty",
minimum=-2.0,
maximum=2.0,
step=0.1,
value=0.0,
),
],
stop_btn=False,
title="Llama-4 Maverick Instruct",
description=description,
fill_height=True,
run_examples_on_click=False,
examples=examples,
css_paths="style.css",
cache_examples=False,
)
if __name__ == "__main__":
    # Fail fast with a clear message if the queue settings are missing
    queue_size = os.getenv("QUEUE")
    concurrency_limit = os.getenv("CONCURRENCY_LIMIT")
    if not queue_size or not concurrency_limit:
        raise EnvironmentError("QUEUE and CONCURRENCY_LIMIT must be set.")
    demo.queue(
        max_size=int(queue_size),
        default_concurrency_limit=int(concurrency_limit),
    ).launch()