Didier committed
Commit c1f96a9 · verified · 1 Parent(s): 794ccef

Create vlm.py

Files changed (1)
vlm.py +161 -0
vlm.py ADDED
@@ -0,0 +1,161 @@
+ """
+ File: vlm.py
+ Description: Vision language model utility functions.
+
+ Heavily inspired by (i.e. copied from)
+ https://huggingface.co/spaces/HuggingFaceTB/SmolVLM2/blob/main/app.py
+
+ Author: Didier Guillevic
+ Date: 2025-04-02
+ """
+
+ from transformers import AutoProcessor, AutoModelForImageTextToText
+ from transformers import TextIteratorStreamer
+ from threading import Thread
+ import re
+ import time
+ import torch
+ import spaces
+ import gradio as gr  # needed for gr.Error in build_messages
+ import subprocess
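+ # Presumably intended for Hugging Face Spaces: install flash-attn at start-up rather
+ # than at image build time, with FLASH_ATTENTION_SKIP_CUDA_BUILD skipping the CUDA
+ # compile step during the install.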
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ from io import BytesIO
+
+ #
+ # Load the model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
+ #
+
+ model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = AutoModelForImageTextToText.from_pretrained(
+     model_id,
+     _attn_implementation="flash_attention_2",
+     torch_dtype=torch.bfloat16
+ ).to(device)
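+ # Note: "flash_attention_2" requires a CUDA GPU with the flash-attn package available;
+ # on a CPU-only machine one could omit _attn_implementation and let transformers pick
+ # its default attention implementation.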
+
+ #
+ # Build messages
+ #
+ def build_messages(input_dict: dict, history: list[dict]):
+     """Build messages given message & history from a **multimodal** chat interface.
+     Args:
+         input_dict: dictionary with keys: 'text', 'files'
+         history: list of chat messages, each a dict with keys 'role' and 'content'
+
+     Returns:
+         list of messages (to be sent to the model)
+     """
+     text = input_dict["text"]
+     user_content = []
+     media_queue = []
+     if history == []:
+         text = input_dict["text"].strip()
+
+         for file in input_dict.get("files", []):
+             if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                 media_queue.append({"type": "image", "path": file})
+             elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                 media_queue.append({"type": "video", "path": file})
+
+         if "<image>" in text or "<video>" in text:
+             parts = re.split(r'(<image>|<video>)', text)
+             for part in parts:
+                 if part == "<image>" and media_queue:
+                     user_content.append(media_queue.pop(0))
+                 elif part == "<video>" and media_queue:
+                     user_content.append(media_queue.pop(0))
+                 elif part.strip():
+                     user_content.append({"type": "text", "text": part.strip()})
+         else:
+             user_content.append({"type": "text", "text": text})
+
+         for media in media_queue:
+             user_content.append(media)
+
+         resulting_messages = [{"role": "user", "content": user_content}]
+
+     elif len(history) > 0:
+         resulting_messages = []
+         user_content = []
+         media_queue = []
+         for hist in history:
+             if hist["role"] == "user" and isinstance(hist["content"], tuple):
+                 file_name = hist["content"][0]
+                 if file_name.endswith((".png", ".jpg", ".jpeg")):
+                     media_queue.append({"type": "image", "path": file_name})
+                 elif file_name.endswith(".mp4"):
+                     media_queue.append({"type": "video", "path": file_name})
+
+         for hist in history:
+             if hist["role"] == "user" and isinstance(hist["content"], str):
+                 text = hist["content"]
+                 parts = re.split(r'(<image>|<video>)', text)
+
+                 for part in parts:
+                     if part == "<image>" and media_queue:
+                         user_content.append(media_queue.pop(0))
+                     elif part == "<video>" and media_queue:
+                         user_content.append(media_queue.pop(0))
+                     elif part.strip():
+                         user_content.append({"type": "text", "text": part.strip()})
+
+             elif hist["role"] == "assistant":
+                 resulting_messages.append({
+                     "role": "user",
+                     "content": user_content
+                 })
+                 resulting_messages.append({
+                     "role": "assistant",
+                     "content": [{"type": "text", "text": hist["content"]}]
+                 })
+                 user_content = []
+
+         # Append the current user turn (its text and any newly attached files),
+         # which the history loop above does not cover.
+         for file in input_dict.get("files", []):
+             if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+                 user_content.append({"type": "image", "path": file})
+             elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+                 user_content.append({"type": "video", "path": file})
+         if input_dict["text"].strip():
+             user_content.append({"type": "text", "text": input_dict["text"].strip()})
+         resulting_messages.append({"role": "user", "content": user_content})
+
+     # Validate the current turn's input (the 'text' and 'files' of input_dict).
+     if input_dict["text"].strip() == "" and not input_dict.get("files"):
+         raise gr.Error("Please input a query and optionally image(s).")
+
+     if input_dict["text"].strip() == "" and input_dict.get("files"):
+         raise gr.Error("Please input a text query along with the image(s).")
+
+     return resulting_messages
+
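+ # For illustration (hypothetical file name, not from this commit): with no history,
+ #     build_messages({"text": "Describe <image>", "files": ["photo.png"]}, history=[])
+ # would return
+ #     [{"role": "user",
+ #       "content": [{"type": "text", "text": "Describe"},
+ #                   {"type": "image", "path": "photo.png"}]}]
+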
+ #
+ # Streaming response
+ #
+ @spaces.GPU
+ @torch.inference_mode()
+ def stream_response(messages: list[dict]):
+     """Stream the model's response to the chat interface.
+
+     Args:
+         messages: list of messages to send to the model
+     """
+     # Build the model inputs from the chat messages
+     inputs = processor.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     ).to(model.device, dtype=torch.bfloat16)
+
+     # Generate
+     streamer = TextIteratorStreamer(
+         processor, skip_prompt=True, skip_special_tokens=True)
+     generation_args = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=2_048,
+         do_sample=True
+     )
+
+     thread = Thread(target=model.generate, kwargs=generation_args)
+     thread.start()
+
+     partial_message = ""
+     for new_text in streamer:
+         partial_message += new_text
+         yield partial_message
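
For context, a minimal sketch of how these helpers could be wired into a multimodal Gradio chat app. The app.py file name and the respond wrapper are illustrative assumptions, not part of this commit:

# app.py (hypothetical companion file)
import gradio as gr
import vlm

def respond(message: dict, history: list):
    # Build the chat messages from the current input and the history,
    # then stream the model's reply back to the interface.
    messages = vlm.build_messages(message, history)
    yield from vlm.stream_response(messages)

demo = gr.ChatInterface(
    fn=respond,
    multimodal=True,   # accept text plus image/video files
    type="messages",   # history entries are {"role": ..., "content": ...} dicts
)

if __name__ == "__main__":
    demo.launch()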