aimeri committed
Commit 8ddcd3c · 1 Parent(s): e4a9a7a

Refactor process_input and create_demo functions in app.py to enhance chat history management and improve text response handling, including the addition of user and assistant avatars. Update input clearing logic for better user experience with multimodal inputs.
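For context, the input-clearing logic relies on Gradio's event chaining: the submit event runs the handler first, and a chained .then() step then returns empty values for the input components. A minimal, self-contained sketch of that pattern (the echo handler and component layout below are illustrative, not the app's exact wiring; the avatar PNGs are the ones added in this commit and must exist on disk):

import gradio as gr

def respond(message, history):
    # Illustrative handler: append a trivial reply; the real app calls process_input.
    history = (history or []) + [[message, f"Echo: {message}"]]
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(avatar_images=["user.png", "assistant.png"])
    text_input = gr.Textbox(label="Message")
    send = gr.Button("Send")

    # Run the handler, then clear the textbox in a chained follow-up step.
    send.click(fn=respond, inputs=[text_input, chatbot], outputs=chatbot).then(
        fn=lambda: "",  # the returned value becomes the textbox's new (empty) content
        outputs=text_input,
    )

demo.launch()

The multimodal submit path in the diff below uses the same chained-callback idea, returning (None, None, None, "") to reset the image, audio, video, and text inputs together.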


Add binary files using Git LFS

Update .gitattributes to include PNG files in Git LFS and add new binary images for user and assistant avatars.

Refactor model initialization and process_input function in app.py to improve GPU memory management and enhance error handling during multimodal input processing. Introduce a dedicated get_model function to manage model loading and memory clearing, ensuring efficient resource usage.
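The memory clearing described here follows the standard PyTorch pattern: collect unreachable Python objects and release cached CUDA allocations before (and after) heavy work such as loading 7B-scale weights. A minimal sketch of that idea, separate from the Qwen-specific loader (the reporting helper is illustrative only, not part of the app):

import gc
import torch

def clear_gpu_memory():
    # Drop unreachable Python objects, then release cached CUDA blocks
    # back to the driver so a fresh model load has headroom.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def report_gpu_memory(tag):
    # Illustrative helper: print allocated/reserved CUDA memory in GiB.
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 2**30
        reserved = torch.cuda.memory_reserved() / 2**30
        print(f"[{tag}] allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

clear_gpu_memory()
report_gpu_memory("before load")
# ... load the large model here, e.g. from_pretrained(..., low_cpu_mem_usage=True) ...
report_gpu_memory("after load")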

Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +181 -132
  3. assistant.png +3 -0
  4. user.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -5,19 +5,28 @@ from qwen_omni_utils import process_mm_info
 import soundfile as sf
 import tempfile
 import spaces
+import gc

 # Initialize the model and processor
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16

-model = Qwen2_5OmniModel.from_pretrained(
-    "Qwen/Qwen2.5-Omni-7B",
-    torch_dtype=torch_dtype,
-    device_map="auto",
-    enable_audio_output=True,
-    # attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
-)
+def get_model():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    model = Qwen2_5OmniModel.from_pretrained(
+        "Qwen/Qwen2.5-Omni-7B",
+        torch_dtype=torch_dtype,
+        device_map="auto",
+        enable_audio_output=True,
+        low_cpu_mem_usage=True,
+        # attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+    )
+    return model

+model = get_model()
 processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

 # System prompt
@@ -34,131 +43,166 @@ VOICE_OPTIONS = {

 @spaces.GPU
 def process_input(image, audio, video, text, chat_history, voice_type, enable_audio_output):
-    # Combine multimodal inputs
-    user_input = {
-        "text": text,
-        "image": image if image is not None else None,
-        "audio": audio if audio is not None else None,
-        "video": video if video is not None else None
-    }
-
-    # Prepare conversation history for model processing
-    conversation = [SYSTEM_PROMPT]
-
-    # Add previous chat history
-    if isinstance(chat_history, list):
-        for item in chat_history:
-            if isinstance(item, list) and len(item) == 2:
-                user_msg, bot_msg = item
-                if bot_msg is not None:  # Only add complete message pairs
-                    # Convert display format back to processable format
-                    processed_msg = user_msg
-                    if "[Image]" in user_msg:
-                        processed_msg = {"type": "text", "text": user_msg.replace("[Image]", "").strip()}
-                    if "[Audio]" in user_msg:
-                        processed_msg = {"type": "text", "text": user_msg.replace("[Audio]", "").strip()}
-                    if "[Video]" in user_msg:
-                        processed_msg = {"type": "text", "text": user_msg.replace("[Video]", "").strip()}
-
-                    conversation.append({"role": "user", "content": processed_msg})
-                    conversation.append({"role": "assistant", "content": bot_msg})
-    else:
-        # Initialize chat history if it's not a list
-        chat_history = []
-
-    # Add current user input
-    conversation.append({"role": "user", "content": user_input_to_content(user_input)})
-
-    # Prepare for inference
-    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     try:
-        audios, images, videos = process_mm_info(conversation, use_audio_in_video=True)
-    except AssertionError:
-        # If video doesn't have audio, try without audio processing
-        audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
-
-    inputs = processor(
-        text=text,
-        audios=audios,
-        images=images,
-        videos=videos,
-        return_tensors="pt",
-        padding=True
-    )
-    inputs = inputs.to(model.device).to(model.dtype)
-
-    # Generate response with streaming
-    if enable_audio_output:
-        voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
-        text_ids, audio = model.generate(
-            **inputs,
-            use_audio_in_video=False,  # Set to False to avoid audio processing issues
-            return_audio=True,
-            spk=voice_type_value,
-            max_new_tokens=512,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-            streamer=TextStreamer(processor, skip_prompt=True)
-        )
+        # Clear GPU memory before processing
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()

-        # Save audio to temporary file
-        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
-            sf.write(
-                tmp_file.name,
-                audio.reshape(-1).detach().cpu().numpy(),
-                samplerate=24000,
-            )
-            audio_path = tmp_file.name
-    else:
-        text_ids = model.generate(
-            **inputs,
-            use_audio_in_video=False,  # Set to False to avoid audio processing issues
-            return_audio=False,
-            max_new_tokens=512,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-            streamer=TextStreamer(processor, skip_prompt=True)
+        # Combine multimodal inputs
+        user_input = {
+            "text": text,
+            "image": image if image is not None else None,
+            "audio": audio if audio is not None else None,
+            "video": video if video is not None else None
+        }
+
+        # Prepare conversation history for model processing
+        conversation = [SYSTEM_PROMPT]
+
+        # Add previous chat history
+        if isinstance(chat_history, list):
+            for item in chat_history:
+                if isinstance(item, list) and len(item) == 2:
+                    user_msg, bot_msg = item
+                    if bot_msg is not None:  # Only add complete message pairs
+                        # Convert display format back to processable format
+                        processed_msg = user_msg
+                        if "[Image]" in user_msg:
+                            processed_msg = {"type": "text", "text": user_msg.replace("[Image]", "").strip()}
+                        if "[Audio]" in user_msg:
+                            processed_msg = {"type": "text", "text": user_msg.replace("[Audio]", "").strip()}
+                        if "[Video]" in user_msg:
+                            processed_msg = {"type": "text", "text": user_msg.replace("[Video]", "").strip()}
+
+                        conversation.append({"role": "user", "content": processed_msg})
+                        conversation.append({"role": "assistant", "content": bot_msg})
+
+        # Add current user input
+        conversation.append({"role": "user", "content": user_input_to_content(user_input)})
+
+        # Prepare for inference
+        model_input = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+        try:
+            audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)  # Default to no audio in video
+        except Exception as e:
+            print(f"Error processing multimedia: {str(e)}")
+            audios, images, videos = [], [], []  # Fallback to empty lists
+
+        inputs = processor(
+            text=model_input,
+            audios=audios,
+            images=images,
+            videos=videos,
+            return_tensors="pt",
+            padding=True
         )
-        audio_path = None
-
-    # Decode text response
-    text_response = processor.batch_decode(
-        text_ids,
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=False
-    )[0]
-
-    # Clean up text response by removing system/user messages
-    text_response = text_response.strip()
-    text_response = text_response.split("assistant")[-1].strip()
-    if text_response.startswith(":"):
-        text_response = text_response[1:].strip()
-
-    # Format user message for chat history display
-    user_message_for_display = str(text) if text is not None else ""
-    if image is not None:
-        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Image]"
-    if audio is not None:
-        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Audio]"
-    if video is not None:
-        user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Video]"
-
-    # If empty, provide a default message
-    if not user_message_for_display.strip():
-        user_message_for_display = "Multimodal input"
-
-    # Update chat history with properly formatted entries
-    if not isinstance(chat_history, list):
-        chat_history = []
-    chat_history.append([user_message_for_display, text_response])
-
-    # Prepare output
-    if enable_audio_output and audio_path:
-        return chat_history, text_response, audio_path
-    else:
-        return chat_history, text_response, None
+
+        # Move inputs to device and convert dtype
+        inputs = {k: v.to(device=model.device, dtype=model.dtype) if isinstance(v, torch.Tensor) else v
+                  for k, v in inputs.items()}
+
+        # Generate response with streaming
+        try:
+            if enable_audio_output:
+                voice_type_value = VOICE_OPTIONS.get(voice_type, "Chelsie")
+                text_ids, audio = model.generate(
+                    **inputs,
+                    use_audio_in_video=False,  # Set to False to avoid audio processing issues
+                    return_audio=True,
+                    spk=voice_type_value,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.9,
+                    streamer=TextStreamer(processor, skip_prompt=True)
+                )
+
+                # Save audio to temporary file
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                    sf.write(
+                        tmp_file.name,
+                        audio.reshape(-1).detach().cpu().numpy(),
+                        samplerate=24000,
+                    )
+                    audio_path = tmp_file.name
+            else:
+                text_ids = model.generate(
+                    **inputs,
+                    use_audio_in_video=False,  # Set to False to avoid audio processing issues
+                    return_audio=False,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.9,
+                    streamer=TextStreamer(processor, skip_prompt=True)
+                )
+                audio_path = None
+
+            # Decode text response
+            text_response = processor.batch_decode(
+                text_ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False
+            )[0]
+
+            # Clean up text response by removing system/user messages and special tokens
+            text_response = text_response.strip()
+            # Remove everything before the last assistant's message
+            if "<|im_start|>assistant" in text_response:
+                text_response = text_response.split("<|im_start|>assistant")[-1]
+            # Remove any remaining special tokens
+            text_response = text_response.replace("<|im_end|>", "").replace("<|im_start|>", "")
+            if text_response.startswith(":"):
+                text_response = text_response[1:].strip()
+
+            # Format user message for chat history display
+            user_message_for_display = str(text) if text is not None else ""
+            if image is not None:
+                user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Image]"
+            if audio is not None:
+                user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Audio]"
+            if video is not None:
+                user_message_for_display = (user_message_for_display + " " if user_message_for_display.strip() else "") + "[Video]"
+
+            # If empty, provide a default message
+            if not user_message_for_display.strip():
+                user_message_for_display = "Multimodal input"
+
+            # Update chat history with properly formatted entries
+            if not isinstance(chat_history, list):
+                chat_history = []
+
+            # Find the last incomplete message pair if it exists
+            if chat_history and isinstance(chat_history[-1], list) and len(chat_history[-1]) == 2 and chat_history[-1][1] is None:
+                chat_history[-1][1] = text_response
+            else:
+                chat_history.append([user_message_for_display, text_response])
+
+            # Clear GPU memory after processing
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                gc.collect()
+
+            # Prepare output
+            if enable_audio_output and audio_path:
+                return chat_history, text_response, audio_path
+            else:
+                return chat_history, text_response, None
+
+        except Exception as e:
+            print(f"Error during generation: {str(e)}")
+            error_msg = "I apologize, but I encountered an error processing your request. Please try again."
+            chat_history.append([user_message_for_display, error_msg])
+            return chat_history, error_msg, None
+
+    except Exception as e:
+        print(f"Error in process_input: {str(e)}")
+        if not isinstance(chat_history, list):
+            chat_history = []
+        error_msg = "I apologize, but I encountered an error processing your request. Please try again."
+        chat_history.append([str(text) if text is not None else "Error", error_msg])
+        return chat_history, error_msg, None

 def user_input_to_content(user_input):
     if isinstance(user_input, str):
@@ -193,8 +237,7 @@ def create_demo():
     chatbot = gr.Chatbot(
         height=600,
         show_label=False,
-        avatar_images=["👤", "🤖"],
-        bubble_full_width=False,
+        avatar_images=["user.png", "assistant.png"]
     )
     with gr.Accordion("Advanced Options", open=False):
         voice_type = gr.Dropdown(
@@ -277,7 +320,7 @@ def create_demo():

     # Text input handling
     text_submit.click(
-        fn=lambda text: [[str(text) if text is not None else "", None]],
+        fn=lambda text: [[text if text is not None else "", None]],
         inputs=text_input,
         outputs=[chatbot],
         queue=False
@@ -285,6 +328,9 @@ def create_demo():
         fn=process_input,
         inputs=[placeholder_image, placeholder_audio, placeholder_video, text_input, chatbot, voice_type, enable_audio_output],
         outputs=[chatbot, text_output, audio_output]
+    ).then(
+        fn=lambda: "",  # Clear input after submission
+        outputs=text_input
     )

     # Multimodal input handling
@@ -313,6 +359,9 @@ def create_demo():
         inputs=[image_input, audio_input, video_input, additional_text,
                 chatbot, voice_type, enable_audio_output],
         outputs=[chatbot, text_output, audio_output]
+    ).then(
+        fn=lambda: (None, None, None, ""),  # Clear inputs after submission
+        outputs=[image_input, audio_input, video_input, additional_text]
     )

     # Clear chat
assistant.png ADDED

Git LFS Details

  • SHA256: b1ae5b5af8ab5b2ade90759a4f22383c796b7dbc7cc60ef9569a5e793edcf280
  • Pointer size: 131 Bytes
  • Size of remote file: 687 kB
user.png ADDED

Git LFS Details

  • SHA256: 88c7c82d7c8682dc1a140c3173034b62c9e2fc7c1cbe48d31d6fa427346dc89f
  • Pointer size: 131 Bytes
  • Size of remote file: 577 kB