import streamlit as st
import torch
import pandas as pd
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM
# Device settings
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
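# float16 roughly halves memory use and speeds up GPU inference; CPU inference
# stays in float32, where half precision is poorly supported.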
# Load model with caching so Streamlit reruns don't reload the weights
@st.cache_resource
def load_model():
    CHECKPOINT = "microsoft/Florence-2-base-ft"
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT, trust_remote_code=True
    ).to(device, dtype=torch_dtype)
    processor = AutoProcessor.from_pretrained(CHECKPOINT, trust_remote_code=True)
    return model, processor

# Load the model and processor
try:
    model, processor = load_model()
except Exception as e:
    st.error(f"Model loading failed: {e}")
    st.stop()

# UI title
st.title("Florence-2 Multi-Modal Model Playground")

# Task selector
task = st.selectbox("Select Task", ["Object Detection (OD)", "Phrase Grounding (PG)", "Image Captioning (IC)"])

# Phrase input for PG
phrase = ""
if task == "Phrase Grounding (PG)":
    phrase = st.text_input("Enter phrase for grounding (e.g., 'A red car')", "")

# Image uploader
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# If a file was uploaded
if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Error loading image: {e}")
        st.stop()
    st.image(image, caption="Uploaded Image", use_container_width=True)
    # Task-specific prompt
    if task == "Object Detection (OD)":
        task_prompt = "<OD>"
    elif task == "Phrase Grounding (PG)":
        task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
    else:
        task_prompt = "<CAPTION>"
    # Require a phrase before running grounding
    if task == "Phrase Grounding (PG)" and not phrase.strip():
        st.warning("Please enter a phrase to ground.")
        st.stop()

    # Preprocess inputs
    try:
        inputs = processor(text=task_prompt + phrase, images=image, return_tensors="pt").to(device, torch_dtype)
    except Exception as e:
        st.error(f"Error during preprocessing: {e}")
        st.stop()
    # Generate output
    with torch.no_grad():
        try:
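            # Beam search with sampling disabled keeps outputs deterministic;
            # 512 new tokens leaves headroom for dense detection outputs.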
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
                num_beams=3,
                do_sample=False,
            )
        except Exception as e:
            st.error(f"Error during generation: {e}")
            st.stop()
    # Decode and post-process (special tokens are kept because Florence-2's
    # parser needs the task and location tokens)
    try:
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
    except Exception as e:
        st.error(f"Post-processing failed: {e}")
        st.stop()
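    # Illustrative shape of parsed_answer (per the Florence-2 model card;
    # exact fields can vary by checkpoint): detection-style tasks map the
    # task token to pixel-space boxes and labels,
    #   {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": ["car", ...]}}
    # while "<CAPTION>" maps to a plain caption string.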
    # Display results
    if task in ["Object Detection (OD)", "Phrase Grounding (PG)"]:
        key = "<OD>" if task == "Object Detection (OD)" else "<CAPTION_TO_PHRASE_GROUNDING>"
        detections = parsed_answer.get(key, {"bboxes": [], "labels": []})
        bboxes = detections.get("bboxes", [])
        labels = detections.get("labels", [])
        draw = ImageDraw.Draw(image)
        data = []
        for bbox, label in zip(bboxes, labels):
            x_min, y_min, x_max, y_max = map(int, bbox)
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
            draw.text((x_min, max(0, y_min - 10)), label, fill="red")
            data.append([x_min, y_min, x_max - x_min, y_max - y_min, label])
        st.image(image, caption="Detected Objects", use_container_width=True)
        df = pd.DataFrame(data, columns=["x", "y", "w", "h", "object"])
        st.dataframe(df)
    else:
        caption = parsed_answer.get("<CAPTION>", "No caption generated.")
        st.subheader("Generated Caption:")
        st.success(caption)
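
# To try the app locally (assuming this script is saved as app.py):
#   streamlit run app.py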