import base64
import io
import logging
import os
import re
from typing import Iterator

import gradio as gr
from PIL import Image

from gateway import request_generation

# Setup logging
logging.basicConfig(level=logging.INFO)

# CONSTANTS
# Maximum number of new tokens to generate; defaults to 2048 if unset.
MAX_NEW_TOKENS: int = int(os.getenv("MAX_NEW_TOKENS", 2048))

# Maximum number of images allowed in a single prompt.
_max_num_images = os.getenv("MAX_NUM_IMAGES")
if not _max_num_images:
    raise EnvironmentError("MAX_NUM_IMAGES is not set. Please set it to 1 or more.")
MAX_NUM_IMAGES: int = int(_max_num_images)

# Validate environment variables
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
    raise EnvironmentError("API_ENDPOINT is not set.")

MODEL_NAME: str = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise EnvironmentError("MODEL_NAME is not set.")

# Get API Key
API_KEY = os.getenv("API_KEY")
if not API_KEY:  # simple presence check; the gateway validates the key itself
    raise EnvironmentError("API_KEY is not set.")

# Build the request header once instead of recreating it on every request.
HEADER = {"x-api-key": API_KEY}


def validate_media(message: dict, chat_history: list = None) -> bool:
    """Validate the number and type of image files in the new message.

    Args:
        message (dict): multimodal message from the user with "text" and "files" keys
        chat_history (list): entire chat history of the session (currently unused)

    Returns:
        bool: True if the attached images are valid and within MAX_NUM_IMAGES, False otherwise
    """
    image_count = len(message["files"])

    # Count <image> tags embedded in the prompt text as well
    image_count += message["text"].count("<image>")

    if image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images at a time.")
        return False

    # If there are files, check that they are images
    if not all(
        file.lower().endswith((".png", ".jpg", ".jpeg")) for file in message["files"]
    ):
        gr.Warning("Only images are allowed. Formats available: PNG, JPG, JPEG")
        return False

    return True


def encode_pil_to_base64(pil_image: Image.Image, format: str) -> str:
    """Encode a PIL image to a base64 data URL.

    Args:
        pil_image (Image.Image): PIL image object
        format (str): format to save the image, either "JPEG" or "PNG"

    Returns:
        str: base64-encoded data URL of the image, or None if the image could not be saved
    """
    buffered = io.BytesIO()

    # JPEG does not support transparency, so convert those modes to RGB first
    if format == "JPEG" and pil_image.mode in ("RGBA", "LA", "P"):
        pil_image = pil_image.convert("RGB")

    # Define save arguments, including quality for JPEG
    save_kwargs = {"format": format}
    if format == "JPEG":
        save_kwargs["quality"] = 85  # Adjust quality as needed (0-100)

    try:
        pil_image.save(buffered, **save_kwargs)
    except Exception as e:
        logging.error(f"Error saving image to buffer with format {format}: {e}")
        return None

    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # Determine the MIME type based on the format
    mime_format_part = format.lower()
    if mime_format_part == "jpeg":
        mime_type = "image/jpeg"
    elif mime_format_part == "png":
        mime_type = "image/png"
    else:
        # gr.Error must be raised for Gradio to surface it in the UI
        raise gr.Error(f"Unsupported image format: {format}")

    return f"data:{mime_type};base64,{img_str}"
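
# --- Hedged usage sketch (illustration only, not called by the app) ---
# Shows the data-URL shape expected from encode_pil_to_base64. The sample
# image size and color below are arbitrary placeholders, not values taken
# from the original application.
def _example_encode_pil_to_base64() -> str:
    sample = Image.new("RGB", (4, 4), color="red")
    data_url = encode_pil_to_base64(sample, format="PNG")
    # Expected prefix of the returned string: "data:image/png;base64,"
    assert data_url.startswith("data:image/png;base64,")
    return data_url
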
def process_images(files: list) -> list[dict]:
    """Convert uploaded image files into image_url content parts.

    Args:
        files (list): list of image file paths attached to the message

    Returns:
        list[dict]: list of dictionaries containing image content
    """
    content = []

    # Iterate through the files in the message
    for path in files:
        pil_image = Image.open(path)

        # Get the image format
        image_format = pil_image.format.upper()
        if image_format == "JPG":
            image_format = "JPEG"

        if image_format in ["JPEG", "PNG"]:
            # Convert the image to a base64 data URL
            base64_image_data = encode_pil_to_base64(pil_image, format=image_format)
            if base64_image_data:
                content.append(
                    {"type": "image_url", "image_url": {"url": base64_image_data}}
                )

    return content


def extract_image_urls_from_tags(message: str) -> list[str]:
    """Extract image URLs from <image> tags in the message text.

    Args:
        message (str): message text containing <image> tags

    Returns:
        list[str]: list of image URLs extracted from the tags
    """
    # Extract the contents of all <image>...</image> tags using regex
    image_urls = re.findall(
        r"<image>(.*?)</image>", message, re.IGNORECASE | re.DOTALL
    )

    # Basic cleanup: strip whitespace from found URLs
    image_urls = [url.strip() for url in image_urls]
    return image_urls


def process_new_user_message(message: dict) -> list[dict]:
    """Process the new user message into a list of text and image content parts.

    Args:
        message (dict): message dictionary containing text and files

    Returns:
        list[dict]: list of dictionaries containing text and image content
    """
    # Build the content list for the user message
    messages = []

    if message["text"]:
        # Remove the <image> tags from the message text
        prompt = re.sub(
            r"<image>.*?</image>", "", message["text"], flags=re.DOTALL | re.IGNORECASE
        ).strip()

        # If the message text is empty after removing tags, warn and return an empty list
        if not prompt:
            gr.Warning("Please insert a prompt.")
            return []

        # Otherwise append the text part to the content list
        messages.append({"type": "text", "text": prompt})

        # Process image URLs found within <image> tags
        image_urls = extract_image_urls_from_tags(message["text"])
        for url in image_urls:
            if not url or not url.lower().startswith(("http://", "https://")):
                continue
            # Append the image URL to the content list
            messages.append({"type": "image_url", "image_url": {"url": url}})

        if message["files"]:
            # If there are files, process the images
            image_content = process_images(message["files"])
            # Append the image content to the content list
            messages.extend(image_content)

        return messages
    else:
        # If there is no text part, ask the user for a prompt and return nothing
        gr.Warning("Please insert a prompt.")
        return []
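
# --- Hedged illustration (not invoked by the app) ---
# Sketches the content-part list produced by process_new_user_message for a
# prompt mixing plain text with an <image> URL tag. The URL below is a
# placeholder, not part of the original application.
def _example_process_new_user_message() -> list[dict]:
    demo_message = {
        "text": "Describe this photo. <image>https://example.com/photo.png</image>",
        "files": [],
    }
    # Expected result:
    # [
    #     {"type": "text", "text": "Describe this photo."},
    #     {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
    # ]
    return process_new_user_message(demo_message)
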
def run(
    message: dict,
    chat_history: list,
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to the backend, fetch the streaming responses and emit them to the UI.

    Args:
        message (dict): multimodal input message from the user
        chat_history (list): entire chat history of the session
        system_prompt (str): system prompt
        max_new_tokens (int, optional): maximum number of tokens to generate, ignoring the
            number of tokens in the prompt. Defaults to 1024.
        temperature (float, optional): the value used to modulate the next token probabilities.
            Defaults to 0.6.
        frequency_penalty (float, optional): penalizes tokens in proportion to how often they
            have already appeared in the output. Defaults to 0.0.
        presence_penalty (float, optional): penalizes tokens that have appeared at least once,
            encouraging the model to introduce new topics. Defaults to 0.0.

    Yields:
        Iterator[str]: streaming responses to the UI
    """
    if not validate_media(message):
        # If the attached images are not valid, return an empty string
        yield ""
        return

    messages = []
    if system_prompt:
        messages.append(
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]}
        )

    # Append the new user message only if it produced any content parts
    content = process_new_user_message(message)
    if content:
        messages.append({"role": "user", "content": content})
    else:
        # If the content is empty, return an empty string
        yield ""
        return

    # Stream responses from the model and accumulate them for the UI
    outputs = []
    for text in request_generation(
        header=HEADER,
        messages=messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        yield "".join(outputs)


examples = [
    ["Plan a three-day trip to Washington DC for the Cherry Blossom Festival."],
    ["How many hours does it take a man to eat a Helicopter?"],
    [
        {
            "text": "Write the matplotlib code to generate the same bar chart.",
            "files": ["assets/sample-images/barchart.png"],
        }
    ],
    [
        {
            "text": "Describe the atmosphere of the scene.",
            "files": ["assets/sample-images/06-1.png"],
        }
    ],
    [
        {
            "text": "Write a short story about what might have happened in this house.",
            "files": ["assets/sample-images/08.png"],
        }
    ],
    [
        {
            "text": "Describe the creatures that would live in this world.",
            "files": ["assets/sample-images/10.png"],
        }
    ],
    [
        {
            "text": "Read text in the image.",
            "files": ["assets/sample-images/1.png"],
        }
    ],
    [
        {
            "text": "When is this ticket dated and how much did it cost?",
            "files": ["assets/sample-images/2.png"],
        }
    ],
    [
        {
            "text": "Read the text in the image into markdown.",
            "files": ["assets/sample-images/3.png"],
        }
    ],
    [
        {
            "text": "Evaluate this integral.",
            "files": ["assets/sample-images/4.png"],
        }
    ],
    [
        {
            "text": "Caption this image.",
            "files": ["assets/sample-images/01.png"],
        }
    ],
    [
        {
            "text": "What does the sign say?",
            "files": ["assets/sample-images/02.png"],
        }
    ],
    [
        {
            "text": "Compare and contrast the two images.",
            "files": ["assets/sample-images/03.png"],
        }
    ],
    [
        {
            "text": "List all the objects in the image and their colors.",
            "files": ["assets/sample-images/04.png"],
        }
    ],
]

description = """
This Space is an Alpha release that demonstrates the [Llama-4-Maverick](https://huggingface.co./meta-llama/Llama-4-Maverick-17B-128E-Instruct) model running on AMD MI300 infrastructure. The Space is provided under the Meta Llama 4 [License](https://www.llama.com/llama4/license/). Feel free to play with it!
"""
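
# --- Hedged payload illustration (not used at runtime) ---
# Example of the `messages` structure that run() assembles and forwards to
# request_generation: a system turn plus a user turn whose content mixes text
# and image_url parts. The prompt text and URL below are placeholders.
_EXAMPLE_GATEWAY_MESSAGES = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this photo."},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
        ],
    },
]
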
demo = gr.ChatInterface(
    fn=run,
    type="messages",
    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        file_count="single" if MAX_NUM_IMAGES == 1 else "multiple",
        autofocus=True,
        placeholder="Type a message, drop a PNG/JPEG, or use an image URL...",
    ),
    multimodal=True,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            # value="You are a highly capable AI assistant. Provide accurate, concise, and
            # fact-based responses that are directly relevant to the user's query. Avoid
            # speculation, ensure logical consistency, and maintain clarity in longer outputs.",
            value="",
            lines=3,
        ),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_NEW_TOKENS,
            step=1,
            value=2048,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Frequency penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
        gr.Slider(
            label="Presence penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
    ],
    stop_btn=False,
    title="Llama-4 Maverick Instruct",
    description=description,
    fill_height=True,
    run_examples_on_click=False,
    examples=examples,
    css_paths="style.css",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()
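
# --- Hedged configuration reference (comments only) ---
# Environment variables read by this module; the example values below are
# illustrative placeholders, not defaults shipped with the original app.
#   API_ENDPOINT=https://your-gateway.example.com   # required: gateway URL
#   API_KEY=...                                     # required: gateway API key
#   MODEL_NAME=Llama-4-Maverick-17B-128E-Instruct   # required: served model name
#   MAX_NUM_IMAGES=5                                # required: max images per prompt
#   MAX_NEW_TOKENS=2048                             # optional: defaults to 2048
#   QUEUE=20                                        # required: Gradio queue max_size
#   CONCURRENCY_LIMIT=4                             # required: default concurrency limit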