File size: 5,467 Bytes
c1f96a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e4b55b
 
 
 
 
c1f96a9
 
 
 
 
 
 
5d69ad6
c1f96a9
 
 
 
 
 
 
 
 
 
 
 
8e4b55b
 
 
c1f96a9
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
File: vlm.py
Description: Vision language model utility functions.

Heavily inspired (i.e. copied) from
    https://huggingface.co./spaces/HuggingFaceTB/SmolVLM2/blob/main/app.py

Author: Didier Guillevic
Date: 2025-04-02
"""

from transformers import AutoProcessor, AutoModelForImageTextToText
from transformers import TextIteratorStreamer
from threading import Thread
import re
import time
import torch
import spaces
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

from io import BytesIO

#
# Load the model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
#

model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id, 
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
).to(device)

#
# Build messages
#
def build_messages(input_dict: dict, history: list[tuple]):
    """Build messages given message & history from a **multimodal** chat interface.
    Args:
        input_dict: dictionary with keys: 'text', 'files'
        history: list of tuples with (message, response)
    
    Returns:
        list of messages (to be sent to the model)
    """
    text = input_dict["text"]
    images = []
    user_content = []
    media_queue = []
    if history == []:
        text = input_dict["text"].strip() 
        
        for file in input_dict.get("files", []):
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                media_queue.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                media_queue.append({"type": "video", "path": file})

        if "<image>" in text or "<video>" in text:
            parts = re.split(r'(<image>|<video>)', text)  
            for part in parts:
                if part == "<image>" and media_queue:
                    user_content.append(media_queue.pop(0)) 
                elif part == "<video>" and media_queue:
                    user_content.append(media_queue.pop(0))  
                elif part.strip():  
                    user_content.append({"type": "text", "text": part.strip()})
        else:
            user_content.append({"type": "text", "text": text})
            
            for media in media_queue:
                user_content.append(media)

        resulting_messages = [{"role": "user", "content": user_content}]

    elif len(history) > 0:
        resulting_messages = []
        user_content = []
        media_queue = []
        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple): 
                file_name = hist["content"][0]
            if file_name.endswith((".png", ".jpg", ".jpeg")):
                media_queue.append({"type": "image", "path": file_name})
            elif file_name.endswith(".mp4"):
                media_queue.append({"type": "video", "path": file_name})


        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], str): 
                text = hist["content"]
                parts = re.split(r'(<image>|<video>)', text)  
                
                for part in parts:
                    if part == "<image>" and media_queue:
                        user_content.append(media_queue.pop(0)) 
                    elif part == "<video>" and media_queue:
                        user_content.append(media_queue.pop(0))  
                    elif part.strip(): 
                        user_content.append({"type": "text", "text": part.strip()})
            
            elif hist["role"] == "assistant":
                resulting_messages.append({
                    "role": "user",
                    "content": user_content
                })
                resulting_messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist["content"]}]
                })
                user_content = [] 


    if text == "" and not images:
        gr.Error("Please input a query and optionally image(s).")

    if text == "" and images:
        gr.Error("Please input a text query along the images(s).")
    
    return resulting_messages

#
# Streaming response
#
@spaces.GPU
@torch.inference_mode()
def stream_response(
        messages: list[dict],
        max_new_tokens: int=1_024,
        temperature: float=0.15
    ):
    """Stream the model's response to the chat interface.
    
    Args:
        messages: list of messages to send to the model
    """
    # Generate model's response
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    
    # Generate
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=0.9,
        do_sample=True
    )

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message