File size: 3,904 Bytes
981c23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import torch
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForCausalLM

# Device settings: use the first CUDA GPU in half precision when CUDA is
# available; otherwise stay on the CPU in full precision.
device, torch_dtype = (
    ("cuda:0", torch.float16)
    if torch.cuda.is_available()
    else ("cpu", torch.float32)
)

# Load model with caching
@st.cache_resource
def load_model():
    """Load the Florence-2 checkpoint and its processor.

    The model is moved to the module-level ``device`` and cast to
    ``torch_dtype``. The function is wrapped in ``st.cache_resource`` so the
    loaded objects are cached instead of being rebuilt on every call.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint_id = "microsoft/Florence-2-base-ft"
    # NOTE: trust_remote_code=True allows transformers to execute model code
    # shipped with the checkpoint repository; keep the checkpoint pinned to a
    # trusted source.
    florence = AutoModelForCausalLM.from_pretrained(
        checkpoint_id,
        trust_remote_code=True,
    )
    florence = florence.to(device, dtype=torch_dtype)
    florence_processor = AutoProcessor.from_pretrained(
        checkpoint_id,
        trust_remote_code=True,
    )
    return florence, florence_processor

# Obtain the model and processor; if loading fails, show the error in the
# page and halt the script, since nothing below can run without them.
try:
    model, processor = load_model()
except Exception as load_error:
    st.error(f"Model loading failed: {load_error}")
    st.stop()

# Page heading
st.title("Florence-2 Multi-Modal Model Playground")

# Task selector: the chosen label drives the prompt selection further down
task_choices = [
    "Object Detection (OD)",
    "Phrase Grounding (PG)",
    "Image Captioning (IC)",
]
task = st.selectbox("Select Task", task_choices)

# Only phrase grounding takes free text; every other task uses an empty phrase
if task == "Phrase Grounding (PG)":
    phrase = st.text_input("Enter phrase for grounding (e.g., 'A red car')", "")
else:
    phrase = ""

# Image uploader (JPEG / PNG only)
uploaded_file = st.file_uploader(
    "Upload an image",
    type=["jpg", "jpeg", "png"],
)

# If a file was uploaded: decode it, run the selected Florence-2 task on it and
# render the result (annotated boxes plus a table for OD/PG, text for captions).
# Every failing stage reports its error in the page and halts the script.
if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except Exception as e:
        st.error(f"Error loading image: {e}")
        st.stop()

    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Task-specific prompt token passed to the processor and reused as the key
    # of the post-processed result
    if task == "Object Detection (OD)":
        task_prompt = "<OD>"
    elif task == "Phrase Grounding (PG)":
        task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
    else:
        task_prompt = "<CAPTION>"

    # Phrase grounding is the only task that consumes the user's phrase. With a
    # blank phrase there is nothing to locate, so ask for input instead of
    # spending a generation pass on an empty query.
    if task == "Phrase Grounding (PG)" and not phrase.strip():
        st.warning("Enter a phrase to ground before running Phrase Grounding.")
        st.stop()

    # Preprocess inputs (phrase is "" for every task except phrase grounding)
    try:
        inputs = processor(text=task_prompt + phrase, images=image, return_tensors="pt").to(device, torch_dtype)
    except Exception as e:
        st.error(f"Error during preprocessing: {e}")
        st.stop()

    # Generate output; gradients are not needed for inference
    with torch.no_grad():
        try:
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=512,
                num_beams=3,
                do_sample=False
            )
        except Exception as e:
            st.error(f"Error during generation: {e}")
            st.stop()

    # Decode and post-process. Special tokens are kept because the processor's
    # post-processing step parses the raw generated text for the given task.
    try:
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )
    except Exception as e:
        st.error(f"Post-processing failed: {e}")
        st.stop()

    # Display results
    if task in ["Object Detection (OD)", "Phrase Grounding (PG)"]:
        # The parsed result is keyed by the task prompt token
        detections = parsed_answer.get(task_prompt, {"bboxes": [], "labels": []})
        bboxes = detections.get("bboxes", [])
        labels = detections.get("labels", [])

        # Draw on a copy so the uploaded image object stays unmodified
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        data = []

        for bbox, label in zip(bboxes, labels):
            x_min, y_min, x_max, y_max = map(int, bbox)
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=3)
            # Place the label just above the box, clamped to the top edge
            draw.text((x_min, max(0, y_min - 10)), label, fill="red")
            # Table rows use x / y / width / height rather than corner pairs
            data.append([x_min, y_min, x_max - x_min, y_max - y_min, label])

        st.image(annotated, caption="Detected Objects", use_container_width=True)
        df = pd.DataFrame(data, columns=["x", "y", "w", "h", "object"])
        st.dataframe(df)

    else:
        caption = parsed_answer.get("<CAPTION>", "No caption generated.")
        st.subheader("Generated Caption:")
        st.success(caption)