chenjoya committed
Commit d1a4ede (verified) · 1 Parent(s): 292389d

Update demo/infer.py

Files changed (1): demo/infer.py (+29 -26)
demo/infer.py CHANGED
@@ -1,4 +1,4 @@
-import functools, torch, os, tqdm
+import functools, torch
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 apply_liger_kernel_to_qwen2_vl()
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
@@ -62,6 +62,7 @@ class LiveCCDemoInfer:
         repetition_penalty: float = 1.05,
         streaming_eos_base_threshold: float = None,
         streaming_eos_threshold_step: float = None,
+        hf_spaces: bool = False,
         **kwargs,
     ):
         """
@@ -83,6 +84,7 @@ class LiveCCDemoInfer:
             state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
             state['last_video_pts_index'] = -1
         video_pts = state['video_pts']
+        video_timestamp = min(video_timestamp, video_pts[-1])
         if last_timestamp + self.frame_time_interval > video_pts[-1]:
             state['video_end'] = True
             return
@@ -140,7 +142,7 @@ class LiveCCDemoInfer:
             return_tensors="pt",
             return_attention_mask=False
         )
-        inputs.to('cuda')
+        inputs.to(self.model.device)
         if past_ids is not None:
             inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
         if streaming_eos_base_threshold is not None:
@@ -153,9 +155,11 @@ class LiveCCDemoInfer:
             repetition_penalty=repetition_penalty,
             logits_processor=logits_processor,
         )
-        state['past_key_values'] = outputs.past_key_values
-        state['past_ids'] = outputs.sequences[:, :-1]
-        yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
+        state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
+        state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
+        response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
+        print(response)
+        yield (start_timestamp, stop_timestamp), response, state
 
     @torch.inference_mode()
     def video_qa(
@@ -165,7 +169,7 @@ class LiveCCDemoInfer:
         state: dict,
         do_sample: bool = False,
         repetition_penalty: float = 1.05,
-        hf_space: bool = False,
+        hf_spaces: bool = False,
         **kwargs,
     ):
         """
@@ -178,25 +182,24 @@ class LiveCCDemoInfer:
             last_history: list, last processed history
         """
         video_path = state.get('video_path', None)
-        if video_path:
-            message = {
-                "role": "user",
-                "content": [
-                    {"type": "video", "video": video_path},
-                    {"type": "text", "text": query},
-                ],
-            }
-
+        conversation = []
+        if hf_spaces:
+            for past_message in history:
+                content = [{"type": "text", "text": past_message['content']}]
+                if video_path: # only use once
+                    content.insert(0, {"type": "video", "video": video_path})
+                    video_path = None
+                conversation.append({"role": past_message["role"], "content": content})
         else:
-            message = {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": query},
-                ],
-            }
-        image_inputs, video_inputs = process_vision_info([message])
-        texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
+            pass # use past_key_values
         past_ids = state.get('past_ids', None)
+        content = [{"type": "text", "text": message}]
+        if past_ids is None and video_path: # only use once
+            content.insert(0, {"type": "video", "video": video_path})
+        conversation.append({"role": "user", "content": content})
+        print(conversation)
+        image_inputs, video_inputs = process_vision_info(conversation)
+        texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
         if past_ids is not None:
             texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
         inputs = self.processor(
@@ -204,6 +207,7 @@ class LiveCCDemoInfer:
             images=image_inputs,
             videos=video_inputs,
             return_tensors="pt",
+            return_attention_mask=False
         )
         inputs.to(self.model.device)
         if past_ids is not None:
@@ -214,9 +218,8 @@ class LiveCCDemoInfer:
             repetition_penalty=repetition_penalty,
            max_new_tokens=512,
         )
-        state['past_key_values'] = outputs.past_key_values
-        state['past_ids'] = outputs.sequences[:, :-1]
+        state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
+        state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
         response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
         print(response)
-        state.pop('past_key_values')
         return response, state
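
For readers following the change, here is a minimal, hypothetical sketch of how the updated `video_qa` path might be driven once `hf_spaces=True`: the conversation is rebuilt from a Gradio-style `history` list on every turn instead of reusing `past_key_values`/`past_ids`. The import path, the constructor argument, and the exact keyword names (`message`, `history`, `state`) are assumptions inferred from this diff, not something the commit itself defines.

```python
# Hypothetical usage sketch (not part of this commit). Assumes demo/infer.py
# exposes LiveCCDemoInfer and that video_qa() accepts message/history/state
# keywords as suggested by the diff above.
from demo.infer import LiveCCDemoInfer

infer = LiveCCDemoInfer('chenjoya/LiveCC-7B-Instruct')   # assumed constructor argument
state = {'video_path': 'demo/sources/example.mp4'}       # assumed state layout
history = []                                             # [{"role": ..., "content": ...}, ...]

for user_message in ["What is happening in this video?", "What happens next?"]:
    # With hf_spaces=True the method rebuilds the conversation from `history`
    # and returns a state whose past_key_values / past_ids are set to None.
    response, state = infer.video_qa(
        message=user_message,
        history=history,
        state=state,
        hf_spaces=True,
    )
    history.append({"role": "user", "content": user_message})
    history.append({"role": "assistant", "content": response})
```

In this mode every turn re-encodes the full conversation (and the video on the first turn), trading throughput for the stateless request handling a Hugging Face Space front end typically needs; without `hf_spaces`, the cached `past_key_values`/`past_ids` path is used as before.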