chenjoya committed
Commit 82c2aee · 1 Parent(s): c1fab3e
Files changed (2):
  1. app.py +143 -4
  2. demo/infer.py +212 -0
app.py CHANGED
@@ -1,7 +1,146 @@
+import os
 import gradio as gr
 
-def greet(name):
-    return "Hello " + name + "!!"
+from demo.infer import LiveCCDemoInfer
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+class GradioBackend:
+    waiting_video_response = 'Waiting for video input...'
+    not_found_video_response = 'Video does not exist...'
+    mode2api = {
+        'Real-Time Commentary': 'live_cc',
+        'Conversation': 'video_qa'
+    }
+    def __init__(self, model_path: str = 'chenjoya/LiveCC-7B-Instruct'):
+        self.infer = LiveCCDemoInfer(model_path)
+        from kokoro import KPipeline
+        self.audio_pipeline = KPipeline(lang_code='a')
+
+    def __call__(self, query: str = None, state: dict = {}, mode: str = 'Real-Time Commentary', **kwargs):
+        # dispatch to LiveCCDemoInfer.live_cc or .video_qa according to the selected mode
+        return getattr(self.infer, self.mode2api[mode])(query=query, state=state, **kwargs)
+
+gradio_backend = GradioBackend()
+
+with gr.Blocks() as demo:
+    gr.Markdown("## LiveCC Real-Time Commentary and Conversation - Gradio Demo")
+    gr.Markdown("#### [LiveCC: Learning Video LLM with Streaming Speech Transcription at Scale](https://showlab.github.io/livecc/)")
+    gr_state = gr.State({}, render=False)  # control all useful state, including kv cache
+    gr_video_state = gr.JSON({}, visible=False)  # only record video state, belongs to gr_state but lightweight
+    gr_static_trigger = gr.Number(value=0, visible=False)  # control start streaming or stop
+    gr_dynamic_trigger = gr.Number(value=0, visible=False)  # for continuous refresh
+    with gr.Row():
+        with gr.Column():
+            gr_video = gr.Video(
+                label="video",
+                elem_id="gr_video",
+                visible=True,
+                sources=['upload'],
+                autoplay=True,
+                include_audio=False,
+                width=720,
+                height=480
+            )
+            gr_examples = gr.Examples(
+                examples=[
+                    'demo/sources/howto_fix_laptop_mute_1080p.mp4',
+                ],
+                inputs=[gr_video],
+            )
+            gr_clean_button = gr.Button("Clean (Press me before changing video)", elem_id="gr_button")
+
+        with gr.Column():
+            with gr.Row():
+                gr_radio_mode = gr.Radio(label="Select Mode", choices=["Real-Time Commentary", "Conversation"], elem_id="gr_radio_mode", value='Real-Time Commentary', interactive=True)
+
+            def gr_chatinterface_fn(message, history, state, mode):
+                response, state = gradio_backend(query=message, state=state, mode=mode)
+                return response, state
+            def gr_chatinterface_chatbot_clear_fn():
+                return {}, {}, 0, 0
+            gr_chatinterface = gr.ChatInterface(
+                fn=gr_chatinterface_fn,
+                type="messages",
+                additional_inputs=[gr_state, gr_radio_mode],
+                additional_outputs=[gr_state],
+            )
+            gr_chatinterface.chatbot.clear(fn=gr_chatinterface_chatbot_clear_fn, outputs=[gr_video_state, gr_state, gr_static_trigger, gr_dynamic_trigger])
+            gr_clean_button.click(fn=lambda: [[], *gr_chatinterface_chatbot_clear_fn()], outputs=[gr_chatinterface.chatbot, gr_video_state, gr_state, gr_static_trigger, gr_dynamic_trigger])
+
+            def gr_for_streaming(history: list[gr.ChatMessage], video_state: dict, state: dict, mode: str, static_trigger: int, dynamic_trigger: int):
+                # if static_trigger == 0:
+                #     return gr_chatinterface_chatbot_clear_fn()
+                # if video_state['video_path'] != state.get('video_path', None):
+                #     return gr_chatinterface_chatbot_clear_fn()
+                state.update(video_state)
+                query, assistant_waiting_message = None, None
+                for message in history[::-1]:
+                    if message['role'] == 'user':
+                        if message['metadata'] is None or message['metadata'].get('status', '') == '':
+                            query = message['content']
+                            if message['metadata'] is None:
+                                message['metadata'] = {}
+                            message['metadata']['status'] = 'pending'
+                        continue
+                    if query is not None:  # put others as done
+                        message['metadata']['status'] = 'done'
+                    elif message['content'] == GradioBackend.waiting_video_response:
+                        assistant_waiting_message = message
+
+                for (start_timestamp, stop_timestamp), response, state in gradio_backend(query=query, state=state, mode=mode):
+                    if start_timestamp >= 0:
+                        response_with_timestamp = f'{start_timestamp:.1f}s-{stop_timestamp:.1f}s: {response}'
+                        if assistant_waiting_message is None:
+                            history.append(gr.ChatMessage(role="assistant", content=response_with_timestamp))
+                        else:
+                            assistant_waiting_message['content'] = response_with_timestamp
+                            assistant_waiting_message = None
+                    yield history, state, dynamic_trigger
+                # flipping the trigger re-fires gr_dynamic_trigger.change, keeping the streaming loop alive
+                yield history, state, 1 - dynamic_trigger
+
+            js_video_timestamp_fetcher = """
+                (state, video_state) => {
+                    const videoEl = document.querySelector("#gr_video video");
+                    return { video_path: videoEl.currentSrc, video_timestamp: videoEl.currentTime };
+                }
+            """
+            gr_video.change(fn=lambda: [1, 1], outputs=[gr_static_trigger, gr_dynamic_trigger])
+
+            def gr_get_video_state(video_state):
+                print(video_state)
+                if 'file=' in video_state['video_path']:
+                    video_state['video_path'] = video_state['video_path'].split('file=')[1]
+                return video_state
+            gr_dynamic_trigger.change(
+                fn=gr_get_video_state,
+                inputs=[gr_video_state],
+                outputs=[gr_video_state],
+                js=js_video_timestamp_fetcher
+            ).then(
+                fn=gr_for_streaming,
+                inputs=[gr_chatinterface.chatbot, gr_video_state, gr_state, gr_radio_mode, gr_static_trigger, gr_dynamic_trigger],
+                outputs=[gr_chatinterface.chatbot, gr_state, gr_dynamic_trigger],
+            )
+
+demo.queue(max_size=5, default_concurrency_limit=5)
+demo.launch(share=True)
+
+
+# --- for streaming ---
+
+# gr_tts = gr.Audio(visible=False, elem_id="gr_tts", streaming=True, autoplay=True)
+# def tts():
+#     while True:
+#         contents = ''
+#         while not gradio_backend.contents.empty():
+#             content = gradio_backend.contents.get()
+#             contents += ' ' + content.rstrip(' ...')
+#         contents = contents.strip()
+#         if contents:
+#             generator = gradio_backend.audio_pipeline(contents, voice='af_heart', speed=1.2)
+#             for _, _, audio_torch in generator:
+#                 audio_np = audio_torch.cpu().numpy()
+#                 max_val = np.max(np.abs(audio_np))
+#                 if max_val > 0:
+#                     audio_np = audio_np / max_val
+#                 audio_int16 = (audio_np * 32767).astype(np.int16)
+#                 yield (24000, audio_int16)
+# gr_video.change(fn=tts, outputs=[gr_tts])
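
For reference, a minimal sketch of the streaming contract the UI above relies on, driven outside Gradio. The video path comes from gr_examples; the playback timestamps are illustrative stand-ins for what js_video_timestamp_fetcher reports from the browser, and note that GradioBackend.__init__ also loads the kokoro TTS pipeline, so kokoro must be installed:

backend = GradioBackend()
state = {'video_path': 'demo/sources/howto_fix_laptop_mute_1080p.mp4'}
for playback_position in (3.0, 6.0, 9.0):  # pretend the browser reports these timestamps
    state['video_timestamp'] = playback_position
    for (start, stop), response, state in backend(query='Please describe the video.', state=state, mode='Real-Time Commentary'):
        print(f'{start:.1f}s-{stop:.1f}s: {response}')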
demo/infer.py ADDED
@@ -0,0 +1,212 @@
+import functools, torch, os, tqdm
+from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
+apply_liger_kernel_to_qwen2_vl()
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
+from livecc_utils import prepare_multiturn_multimodal_inputs_for_generation, get_smart_resized_clip, get_smart_resized_video_reader, _read_video_decord_plus, _spatial_resize_video
+from qwen_vl_utils import process_vision_info
+
+logger = logging.get_logger(__name__)
+
+class ThresholdLogitsProcessor(LogitsProcessor):
+    # masks `token_id` whenever its probability is at or below a threshold that grows by `step` per generated token
+    def __init__(self, token_id: int, base_threshold: float, step: float):
+        self.token_id = token_id
+        self.base_threshold = base_threshold
+        self.step = step
+        self.count = 0
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        threshold = self.base_threshold + self.step * self.count
+        low_confidence = torch.softmax(scores, dim=-1)[:, self.token_id] <= threshold
+        if low_confidence.any():
+            scores[low_confidence, self.token_id] = -float("inf")
+        self.count += 1
+        return scores
+
+class LiveCCDemoInfer:
+    VIDEO_PLAY_END = object()
+    VIDEO_PLAY_CONTINUE = object()
+    fps = 2
+    initial_fps_frames = 6
+    streaming_fps_frames = 2
+    initial_time_interval = initial_fps_frames / fps
+    streaming_time_interval = streaming_fps_frames / fps
+    frame_time_interval = 1 / fps
+    def __init__(self, model_path: str = None, device_id: int = 0):
+        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path, torch_dtype="auto",
+            device_map=f'cuda:{device_id}',
+            attn_implementation='flash_attention_2'
+        )
+        self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
+        self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
+        self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)
+        message = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": 'livecc'},
+            ]
+        }
+        texts = self.processor.apply_chat_template([message], tokenize=False)
+        self.system_prompt_offset = texts.index('<|im_start|>user')
+        self._cached_video_readers_with_hw = {}
+
+    @torch.inference_mode()
+    def live_cc(
+        self,
+        query: str,
+        state: dict,
+        max_pixels: int = 384 * 28 * 28,
+        default_query: str = 'Please describe the video.',
+        do_sample: bool = False,
+        repetition_penalty: float = 1.05,
+        streaming_eos_base_threshold: float = None,
+        streaming_eos_threshold_step: float = None,
+        **kwargs,
+    ):
+        """
+        state: dict, (maybe) with keys:
+            video_path: str, video path
+            video_timestamp: float, current video timestamp
+            last_timestamp: float, last processed video timestamp
+            last_video_pts_index: int, last processed video frame index
+            video_pts: np.ndarray, video pts
+            last_history: list, last processed history
+        """
+        # 1. preparation: video_reader, and last processing info
+        video_timestamp, last_timestamp = state.get('video_timestamp', 0), state.get('last_timestamp', -1 / self.fps)
+        video_path = state['video_path']
+        if video_path not in self._cached_video_readers_with_hw:
+            self._cached_video_readers_with_hw[video_path] = get_smart_resized_video_reader(video_path, max_pixels)
+            video_reader = self._cached_video_readers_with_hw[video_path][0]
+            video_reader.get_frame_timestamp(0)
+            state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
+            state['last_video_pts_index'] = -1
+        video_pts = state['video_pts']
+        if last_timestamp + self.frame_time_interval > video_pts[-1]:
+            state['video_end'] = True
+            return
+        video_reader, resized_height, resized_width = self._cached_video_readers_with_hw[video_path]
+        last_video_pts_index = state['last_video_pts_index']
+
+        # 2. which frames will be processed
+        initialized = last_timestamp >= 0
+        if not initialized:
+            video_timestamp = max(video_timestamp, self.initial_time_interval)
+        if video_timestamp <= last_timestamp + self.frame_time_interval:
+            return
+        timestamps = torch.arange(last_timestamp + self.frame_time_interval, video_timestamp, self.frame_time_interval)  # add compensation
+
+        # 3. fetch frames in required timestamps
+        clip, clip_timestamps, clip_idxs = get_smart_resized_clip(video_reader, resized_height, resized_width, timestamps, video_pts, video_pts_index_from=last_video_pts_index+1)
+        state['last_video_pts_index'] = clip_idxs[-1]
+        state['last_timestamp'] = clip_timestamps[-1]
+
+        # 4. organize to interleave frames
+        interleave_clips, interleave_timestamps = [], []
+        if not initialized:
+            interleave_clips.append(clip[:self.initial_fps_frames])
+            interleave_timestamps.append(clip_timestamps[:self.initial_fps_frames])
+            clip = clip[self.initial_fps_frames:]
+            clip_timestamps = clip_timestamps[self.initial_fps_frames:]
+        if len(clip) > 0:
+            interleave_clips.extend(list(clip.split(self.streaming_fps_frames)))
+            interleave_timestamps.extend(list(clip_timestamps.split(self.streaming_fps_frames)))
+
+        # 5. make conversation and send to model
+        for clip, timestamps in zip(interleave_clips, interleave_timestamps):
+            start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
+            message = {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
+                    {"type": "video", "video": clip}
+                ]
+            }
+            if not query and not state.get('query', None):
+                query = default_query
+                logger.warning(f'No query provided, use default_query={default_query}')
+            if query and state.get('query', None) != query:
+                message['content'].append({"type": "text", "text": query})
+                state['query'] = query
+            texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
+            past_ids = state.get('past_ids', None)
+            if past_ids is not None:
+                # continue the running conversation: close the previous turn and drop the system prompt
+                texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
+            inputs = self.processor(
+                text=texts,
+                images=None,
+                videos=[clip],
+                return_tensors="pt",
+                return_attention_mask=False
+            )
+            inputs.to('cuda')
+            if past_ids is not None:
+                inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
+            if streaming_eos_base_threshold is not None:
+                logits_processor = [ThresholdLogitsProcessor(self.streaming_eos_token_id, streaming_eos_base_threshold, streaming_eos_threshold_step)]
+            else:
+                logits_processor = None
+            outputs = self.model.generate(
+                **inputs, past_key_values=state.get('past_key_values', None),
+                return_dict_in_generate=True, do_sample=do_sample,
+                repetition_penalty=repetition_penalty,
+                logits_processor=logits_processor,
+            )
+            state['past_key_values'] = outputs.past_key_values
+            state['past_ids'] = outputs.sequences[:, :-1]
+            yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state
+
+    def video_qa(
+        self,
+        query: str,
+        state: dict,
+        answer_prefix: str = '',
+        video_start: float = None,
+        video_end: float = None,
+        strict_fps: bool = False,
+        strict_abcd_ids: list[int] = None,
+        do_sample: bool = False,
+        max_new_tokens: int = 128,
+        **kwargs,
+    ):
+        # conversation mode: answer a query over the whole (or clipped) video, returning (response, state)
+        # so that GradioBackend can dispatch it the same way as live_cc
+        model, processor = self.model, self.processor
+        video_path = state['video_path']
+        if strict_fps:
+            video_inputs, _ = _read_video_decord_plus({'video': video_path, 'video_start': video_start, 'video_end': video_end}, strict_fps=True, drop_last=False)
+            video_inputs = _spatial_resize_video(video_inputs)
+            conversation = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video", "video": video_inputs},
+                        {"type": "text", "text": query},
+                    ],
+                }
+            ]
+            image_inputs = None
+        else:
+            conversation = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "video", "video": video_path, "video_start": video_start, "video_end": video_end},
+                        {"type": "text", "text": query},
+                    ],
+                }
+            ]
+            image_inputs, video_inputs = process_vision_info(conversation)
+        text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + answer_prefix
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            return_tensors="pt",
+        )
+        print(text)
+        inputs = inputs.to("cuda")
+        if not strict_abcd_ids:
+            generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
+            output_text = processor.decode(generated_ids[0, inputs.input_ids.size(1):], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        else:
+            # multiple-choice: take one decoding step and compare the scores of the candidate option tokens only
+            outputs = model.generate(**inputs, do_sample=do_sample, top_p=None, temperature=None, top_k=None, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, repetition_penalty=1)
+            print(outputs.scores[0][0, strict_abcd_ids])
+            output_text = ['A', 'B', 'C', 'D'][outputs.scores[0][0, strict_abcd_ids].argmax()]
+        return output_text, state
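
A minimal usage sketch of the multiple-choice path of video_qa, reusing the example video from the demo. The way the A/B/C/D candidate ids are computed here (last token id of each letter under this tokenizer) is an illustrative assumption, not something this commit defines:

infer = LiveCCDemoInfer('chenjoya/LiveCC-7B-Instruct')
abcd_ids = [infer.processor.tokenizer(letter).input_ids[-1] for letter in 'ABCD']  # assumed mapping from option letters to token ids
state = {'video_path': 'demo/sources/howto_fix_laptop_mute_1080p.mp4'}
answer, state = infer.video_qa(
    query='What is the video mainly about? A. Cooking B. Fixing a muted laptop C. Gardening D. Traveling',
    state=state,
    answer_prefix='Answer:',
    strict_abcd_ids=abcd_ids,
)
print(answer)  # one of 'A', 'B', 'C', 'D'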