Spaces: Running on Zero
Update demo/infer.py
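This commit adds an hf_spaces flag (default False) to both inference entry points. When set, the demo stops carrying past_key_values/past_ids across calls and instead rebuilds the full conversation from the chat history on every turn. The streaming path additionally clamps video_timestamp to the last frame timestamp, decodes and prints the response inside the generator, and yields (start_timestamp, stop_timestamp), response, state; video_qa now passes return_attention_mask=False to the processor and keeps past_key_values in state rather than popping it.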
demo/infer.py CHANGED (+29 -26)
@@ -1,4 +1,4 @@
-import functools, torch
+import functools, torch
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
 apply_liger_kernel_to_qwen2_vl()
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
@@ -62,6 +62,7 @@ class LiveCCDemoInfer:
         repetition_penalty: float = 1.05,
         streaming_eos_base_threshold: float = None,
         streaming_eos_threshold_step: float = None,
+        hf_spaces: bool = False,
         **kwargs,
     ):
         """
@@ -83,6 +84,7 @@ class LiveCCDemoInfer:
             state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
             state['last_video_pts_index'] = -1
         video_pts = state['video_pts']
+        video_timestamp = min(video_timestamp, video_pts[-1])
         if last_timestamp + self.frame_time_interval > video_pts[-1]:
             state['video_end'] = True
             return
@@ -140,7 +142,7 @@ class LiveCCDemoInfer:
             return_tensors="pt",
             return_attention_mask=False
         )
-        inputs.to(
+        inputs.to(self.model.device)
         if past_ids is not None:
             inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
         if streaming_eos_base_threshold is not None:
@@ -153,9 +155,11 @@ class LiveCCDemoInfer:
             repetition_penalty=repetition_penalty,
             logits_processor=logits_processor,
         )
-        state['past_key_values'] = outputs.past_key_values
-        state['past_ids'] = outputs.sequences[:, :-1]
-
+        state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
+        state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
+        response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
+        print(response)
+        yield (start_timestamp, stop_timestamp), response, state
 
     @torch.inference_mode()
     def video_qa(
@@ -165,7 +169,7 @@ class LiveCCDemoInfer:
         state: dict,
         do_sample: bool = False,
         repetition_penalty: float = 1.05,
-
+        hf_spaces: bool = False,
         **kwargs,
     ):
         """
@@ -178,25 +182,24 @@ class LiveCCDemoInfer:
         last_history: list, last processed history
         """
         video_path = state.get('video_path', None)
-
-
-
-
-
-                {"type": "
-
-
-
+        conversation = []
+        if hf_spaces:
+            for past_message in history:
+                content = [{"type": "text", "text": past_message['content']}]
+                if video_path: # only use once
+                    content.insert(0, {"type": "video", "video": video_path})
+                    video_path = None
+                conversation.append({"role": past_message["role"], "content": content})
         else:
-            message = {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": query},
-                ],
-            }
-        image_inputs, video_inputs = process_vision_info([message])
-        texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
+            pass # use past_key_values
         past_ids = state.get('past_ids', None)
+        content = [{"type": "text", "text": message}]
+        if past_ids is None and video_path: # only use once
+            content.insert(0, {"type": "video", "video": video_path})
+        conversation.append({"role": "user", "content": content})
+        print(conversation)
+        image_inputs, video_inputs = process_vision_info(conversation)
+        texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
         if past_ids is not None:
             texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
         inputs = self.processor(
@@ -204,6 +207,7 @@ class LiveCCDemoInfer:
             images=image_inputs,
             videos=video_inputs,
             return_tensors="pt",
+            return_attention_mask=False
         )
         inputs.to(self.model.device)
         if past_ids is not None:
@@ -214,9 +218,8 @@ class LiveCCDemoInfer:
             repetition_penalty=repetition_penalty,
            max_new_tokens=512,
         )
-        state['past_key_values'] = outputs.past_key_values
-        state['past_ids'] = outputs.sequences[:, :-1]
+        state['past_key_values'] = outputs.past_key_values if not hf_spaces else None
+        state['past_ids'] = outputs.sequences[:, :-1] if not hf_spaces else None
         response = self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True)
         print(response)
-        state.pop('past_key_values')
         return response, state
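For reference, a minimal sketch of how the hf_spaces path might be driven from the Space's UI loop. The default construction of LiveCCDemoInfer and the message/history parameter names ahead of state are assumptions (only state, do_sample, repetition_penalty, hf_spaces, and **kwargs are visible in this diff); the caller owns the history, since no KV cache is kept in state:

    # Hypothetical driver for the hf_spaces code path: the caller keeps the chat
    # history itself, because past_key_values/past_ids are no longer cached.
    from demo.infer import LiveCCDemoInfer

    infer = LiveCCDemoInfer()                    # assumed default construction
    state = {'video_path': 'video.mp4'}          # video is attached to the first user turn only
    history = []                                 # entries shaped like {'role': ..., 'content': ...}

    for question in ('What is happening?', 'What changed since then?'):
        response, state = infer.video_qa(
            message=question,                    # parameter names assumed from the method body
            history=history,
            state=state,
            hf_spaces=True,                      # rebuild the conversation instead of reusing the KV cache
        )
        history += [
            {'role': 'user', 'content': question},
            {'role': 'assistant', 'content': response},
        ]

On the first call the history is empty, past_ids is never stored, and video_path is present, so the video is inserted ahead of the text content; on later calls the "only use once" branches attach the video to the earliest past message instead, matching the rebuilt-conversation logic above.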