Spaces: Running on Zero

Commit: fix

Files changed:
- app.py +143 -4
- demo/infer.py +212 -0
app.py CHANGED
@@ -1,7 +1,146 @@

The commit removes the default Gradio greeting stub (everything after `import gradio as gr`, including the `return "Hello " + name + "!!"` line) and replaces the file with the full demo app:
```python
import os
import gradio as gr

from demo.infer import LiveCCDemoInfer

class GradioBackend:
    waiting_video_response = 'Waiting for video input...'
    not_found_video_response = 'Video does not exist...'
    mode2api = {
        'Real-Time Commentary': 'live_cc',
        'Conversation': 'video_qa'
    }
    def __init__(self, model_path: str = 'chenjoya/LiveCC-7B-Instruct'):
        self.infer = LiveCCDemoInfer(model_path)
        from kokoro import KPipeline
        self.audio_pipeline = KPipeline(lang_code='a')

    def __call__(self, query: str = None, state: dict = {}, mode: str = 'Real-Time Commentary', **kwargs):
        # dispatch to LiveCCDemoInfer.live_cc or .video_qa depending on the selected mode
        return getattr(self.infer, self.mode2api[mode])(query=query, state=state, **kwargs)

gradio_backend = GradioBackend()

with gr.Blocks() as demo:
    gr.Markdown("## LiveCC Real-Time Commentary and Conversation - Gradio Demo")
    gr.Markdown("#### [LiveCC: Learning Video LLM with Streaming Speech Transcription at Scale](https://showlab.github.io/livecc/)")
    gr_state = gr.State({}, render=False) # controls all useful state, including the kv cache
    gr_video_state = gr.JSON({}, visible=False) # only records video state; belongs to gr_state but lightweight
    gr_static_trigger = gr.Number(value=0, visible=False) # controls starting/stopping streaming
    gr_dynamic_trigger = gr.Number(value=0, visible=False) # flipped on every pass for continuous refresh
    with gr.Row():
        with gr.Column():
            gr_video = gr.Video(
                label="video",
                elem_id="gr_video",
                visible=True,
                sources=['upload'],
                autoplay=True,
                include_audio=False,
                width=720,
                height=480,
            )
            gr_examples = gr.Examples(
                examples=[
                    'demo/sources/howto_fix_laptop_mute_1080p.mp4',
                ],
                inputs=[gr_video],
            )
            gr_clean_button = gr.Button("Clean (Press me before changing video)", elem_id="gr_button")

        with gr.Column():
            with gr.Row():
                gr_radio_mode = gr.Radio(label="Select Mode", choices=["Real-Time Commentary", "Conversation"], elem_id="gr_radio_mode", value='Real-Time Commentary', interactive=True)

            def gr_chatinterface_fn(message, history, state, mode):
                response, state = gradio_backend(query=message, state=state, mode=mode)
                return response, state

            def gr_chatinterface_chatbot_clear_fn():
                return {}, {}, 0, 0

            gr_chatinterface = gr.ChatInterface(
                fn=gr_chatinterface_fn,
                type="messages",
                additional_inputs=[gr_state, gr_radio_mode],
                additional_outputs=[gr_state],
            )
            gr_chatinterface.chatbot.clear(fn=gr_chatinterface_chatbot_clear_fn, outputs=[gr_video_state, gr_state, gr_static_trigger, gr_dynamic_trigger])
            # the clean button also empties the chatbot itself, so it is listed first in the outputs
            gr_clean_button.click(fn=lambda: [[], *gr_chatinterface_chatbot_clear_fn()], outputs=[gr_chatinterface.chatbot, gr_video_state, gr_state, gr_static_trigger, gr_dynamic_trigger])

            def gr_for_streaming(history: list[gr.ChatMessage], video_state: dict, state: dict, mode: str, static_trigger: int, dynamic_trigger: int):
                # if static_trigger == 0:
                #     return gr_chatinterface_chatbot_clear_fn()
                # if video_state['video_path'] != state.get('video_path', None):
                #     return gr_chatinterface_chatbot_clear_fn()
                state.update(video_state)
                query, assistant_waiting_message = None, None
                for message in history[::-1]:
                    if message['role'] == 'user':
                        if message['metadata'] is None or message['metadata'].get('status', '') == '':
                            query = message['content']
                            if message['metadata'] is None:
                                message['metadata'] = {}
                            message['metadata']['status'] = 'pending'
                            continue
                        if query is not None: # mark the remaining messages as done
                            message['metadata']['status'] = 'done'
                    elif message['content'] == GradioBackend.waiting_video_response:
                        assistant_waiting_message = message

                for (start_timestamp, stop_timestamp), response, state in gradio_backend(query=query, state=state, mode=mode):
                    if start_timestamp >= 0:
                        response_with_timestamp = f'{start_timestamp:.1f}s-{stop_timestamp:.1f}s: {response}'
                        if assistant_waiting_message is None:
                            history.append(gr.ChatMessage(role="assistant", content=response_with_timestamp))
                        else:
                            assistant_waiting_message['content'] = response_with_timestamp
                            assistant_waiting_message = None
                        yield history, state, dynamic_trigger
                yield history, state, 1 - dynamic_trigger

            js_video_timestamp_fetcher = """
                (state, video_state) => {
                    const videoEl = document.querySelector("#gr_video video");
                    return { video_path: videoEl.currentSrc, video_timestamp: videoEl.currentTime };
                }
            """
            gr_video.change(fn=lambda: [1, 1], outputs=[gr_static_trigger, gr_dynamic_trigger])

            def gr_get_video_state(video_state):
                print(video_state)
                if 'file=' in video_state['video_path']:
                    video_state['video_path'] = video_state['video_path'].split('file=')[1]
                return video_state

            gr_dynamic_trigger.change(
                fn=gr_get_video_state,
                inputs=[gr_video_state],
                outputs=[gr_video_state],
                js=js_video_timestamp_fetcher
            ).then(
                fn=gr_for_streaming,
                inputs=[gr_chatinterface.chatbot, gr_video_state, gr_state, gr_radio_mode, gr_static_trigger, gr_dynamic_trigger],
                outputs=[gr_chatinterface.chatbot, gr_state, gr_dynamic_trigger],
            )

demo.queue(max_size=5, default_concurrency_limit=5)
demo.launch(share=True)


# --- for streaming TTS (currently disabled) ---

# gr_tts = gr.Audio(visible=False, elem_id="gr_tts", streaming=True, autoplay=True)
# def tts():
#     while True:
#         contents = ''
#         while not gradio_backend.contents.empty():
#             content = gradio_backend.contents.get()
#             contents += ' ' + content.rstrip(' ...')
#         contents = contents.strip()
#         if contents:
#             generator = gradio_backend.audio_pipeline(contents, voice='af_heart', speed=1.2)
#             for _, _, audio_torch in generator:
#                 audio_np = audio_torch.cpu().numpy()
#                 max_val = np.max(np.abs(audio_np))
#                 if max_val > 0:
#                     audio_np = audio_np / max_val
#                 audio_int16 = (audio_np * 32767).astype(np.int16)
#                 yield (24000, audio_int16)
# gr_video.change(fn=tts, outputs=[gr_tts])
```
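For context on how the two files fit together: `gr_for_streaming` drives `LiveCCDemoInfer.live_cc` (defined in demo/infer.py below) as a generator, passing the shared `state` dict back in on every call so the KV cache and video position persist across refreshes. A minimal sketch of driving the same generator without Gradio; this driver is hypothetical, not part of the commit, and assumes the Space's GPU environment plus the example video shipped in the repo:

```python
# Hypothetical standalone driver for LiveCCDemoInfer.live_cc (illustration only).
from demo.infer import LiveCCDemoInfer

infer = LiveCCDemoInfer('chenjoya/LiveCC-7B-Instruct')
state = {'video_path': 'demo/sources/howto_fix_laptop_mute_1080p.mp4'}

# Simulate the player's timestamp advancing; each call commentates only the
# frames that appeared since the previous call, reusing the cached state.
for playhead in (3.0, 4.0, 5.0):
    state['video_timestamp'] = playhead
    for (start, stop), response, state in infer.live_cc(query='Please describe the video.', state=state):
        print(f'{start:.1f}s-{stop:.1f}s: {response}')
```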
demo/infer.py ADDED
@@ -0,0 +1,212 @@
```python
import functools, torch, os, tqdm
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
apply_liger_kernel_to_qwen2_vl()
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, LogitsProcessor, logging
from livecc_utils import prepare_multiturn_multimodal_inputs_for_generation, get_smart_resized_clip, get_smart_resized_video_reader
from livecc_utils import _read_video_decord_plus, _spatial_resize_video  # used by video_qa below; missing from the original import
from qwen_vl_utils import process_vision_info

logger = logging.get_logger(__name__)

class ThresholdLogitsProcessor(LogitsProcessor):
    """Bans `token_id` unless its probability exceeds a threshold that moves by `step` on every decoding step."""
    def __init__(self, token_id: int, base_threshold: float, step: float):
        self.token_id = token_id
        self.base_threshold = base_threshold
        self.step = step
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        threshold = self.base_threshold + self.step * self.count
        low_confidence = torch.softmax(scores, dim=-1)[:, self.token_id] <= threshold
        if low_confidence.any():
            scores[low_confidence, self.token_id] = -float("inf")
        self.count += 1
        return scores

class LiveCCDemoInfer:
    VIDEO_PLAY_END = object()
    VIDEO_PLAY_CONTINUE = object()
    fps = 2
    initial_fps_frames = 6
    streaming_fps_frames = 2
    initial_time_interval = initial_fps_frames / fps
    streaming_time_interval = streaming_fps_frames / fps
    frame_time_interval = 1 / fps

    def __init__(self, model_path: str = None, device_id: int = 0):
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype="auto",
            device_map=f'cuda:{device_id}',
            attn_implementation='flash_attention_2'
        )
        self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
        self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
        self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)
        message = {
            "role": "user",
            "content": [
                {"type": "text", "text": 'livecc'},
            ]
        }
        texts = self.processor.apply_chat_template([message], tokenize=False)
        self.system_prompt_offset = texts.index('<|im_start|>user')
        self._cached_video_readers_with_hw = {}

    @torch.inference_mode()
    def live_cc(
        self,
        query: str,
        state: dict,
        max_pixels: int = 384 * 28 * 28,
        default_query: str = 'Please describe the video.',
        do_sample: bool = False,
        repetition_penalty: float = 1.05,
        streaming_eos_base_threshold: float = None,
        streaming_eos_threshold_step: float = None,
        **kwargs,
    ):
        """
        state: dict, (maybe) with keys:
            video_path: str, video path
            video_timestamp: float, current video timestamp
            last_timestamp: float, last processed video timestamp
            last_video_pts_index: int, last processed video frame index
            video_pts: np.ndarray, video pts
            last_history: list, last processed history
        """
        # 1. preparation: video_reader, and last processing info
        video_timestamp, last_timestamp = state.get('video_timestamp', 0), state.get('last_timestamp', -1 / self.fps)
        video_path = state['video_path']
        if video_path not in self._cached_video_readers_with_hw:
            self._cached_video_readers_with_hw[video_path] = get_smart_resized_video_reader(video_path, max_pixels)
            video_reader = self._cached_video_readers_with_hw[video_path][0]
            video_reader.get_frame_timestamp(0)
            state['video_pts'] = torch.from_numpy(video_reader._frame_pts[:, 1])
            state['last_video_pts_index'] = -1
        video_pts = state['video_pts']
        if last_timestamp + self.frame_time_interval > video_pts[-1]:
            state['video_end'] = True
            return
        video_reader, resized_height, resized_width = self._cached_video_readers_with_hw[video_path]
        last_video_pts_index = state['last_video_pts_index']

        # 2. decide which frames will be processed
        initialized = last_timestamp >= 0
        if not initialized:
            video_timestamp = max(video_timestamp, self.initial_time_interval)
        if video_timestamp <= last_timestamp + self.frame_time_interval:
            return
        timestamps = torch.arange(last_timestamp + self.frame_time_interval, video_timestamp, self.frame_time_interval) # add compensation

        # 3. fetch frames at the required timestamps
        clip, clip_timestamps, clip_idxs = get_smart_resized_clip(video_reader, resized_height, resized_width, timestamps, video_pts, video_pts_index_from=last_video_pts_index+1)
        state['last_video_pts_index'] = clip_idxs[-1]
        state['last_timestamp'] = clip_timestamps[-1]

        # 4. organize into interleaved frame chunks
        interleave_clips, interleave_timestamps = [], []
        if not initialized:
            interleave_clips.append(clip[:self.initial_fps_frames])
            interleave_timestamps.append(clip_timestamps[:self.initial_fps_frames])
            clip = clip[self.initial_fps_frames:]
            clip_timestamps = clip_timestamps[self.initial_fps_frames:]
        if len(clip) > 0:
            interleave_clips.extend(list(clip.split(self.streaming_fps_frames)))
            interleave_timestamps.extend(list(clip_timestamps.split(self.streaming_fps_frames)))

        # 5. make conversation and send to model
        for clip, timestamps in zip(interleave_clips, interleave_timestamps):
            start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
            message = {
                "role": "user",
                "content": [
                    {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
                    {"type": "video", "video": clip}
                ]
            }
            if not query and not state.get('query', None):
                query = default_query
                logger.warning(f'No query provided, use default_query={default_query}')
            if query and state.get('query', None) != query:
                message['content'].append({"type": "text", "text": query})
                state['query'] = query
            texts = self.processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True, return_tensors='pt')
            past_ids = state.get('past_ids', None)
            if past_ids is not None:
                texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
            inputs = self.processor(
                text=texts,
                images=None,
                videos=[clip],
                return_tensors="pt",
                return_attention_mask=False
            )
            inputs.to('cuda')
            if past_ids is not None:
                inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
            if streaming_eos_base_threshold is not None:
                logits_processor = [ThresholdLogitsProcessor(self.streaming_eos_token_id, streaming_eos_base_threshold, streaming_eos_threshold_step)]
            else:
                logits_processor = None
            outputs = self.model.generate(
                **inputs, past_key_values=state.get('past_key_values', None),
                return_dict_in_generate=True, do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                logits_processor=logits_processor,
            )
            state['past_key_values'] = outputs.past_key_values
            state['past_ids'] = outputs.sequences[:, :-1]
            yield (start_timestamp, stop_timestamp), self.processor.decode(outputs.sequences[0, inputs.input_ids.size(1):], skip_special_tokens=True), state

    @torch.inference_mode()
    def video_qa(
        self,
        query: str,
        state: dict = {},
        video_path: str = None,
        answer_prefix: str = '',
        video_start: float = None,
        video_end: float = None,
        strict_fps: bool = False,
        strict_abcd_ids: list[int] = None,
        do_sample: bool = False,
        max_new_tokens: int = 128,
        **kwargs,
    ):
        # invoked by GradioBackend with (query=..., state=...); fall back to the video recorded in state
        video_path = video_path or state.get('video_path', None)
        if strict_fps:
            video_inputs, _ = _read_video_decord_plus({'video': video_path, 'video_start': video_start, 'video_end': video_end}, strict_fps=True, drop_last=False)
            video_inputs = _spatial_resize_video(video_inputs)
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": video_inputs},
                        {"type": "text", "text": query},
                    ],
                }
            ]
            image_inputs = None
        else:
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "video", "video": video_path, "video_start": video_start, "video_end": video_end},
                        {"type": "text", "text": query},
                    ],
                }
            ]
            image_inputs, video_inputs = process_vision_info(conversation)
        text = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) + answer_prefix
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
        )
        print(text)
        inputs = inputs.to("cuda")
        if not strict_abcd_ids:
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
            output_text = self.processor.decode(generated_ids[0, inputs.input_ids.size(1):], clean_up_tokenization_spaces=False)
        else:
            # multiple-choice mode: score only the first generated token against the A/B/C/D ids
            outputs = self.model.generate(**inputs, do_sample=do_sample, top_p=None, temperature=None, top_k=None, max_new_tokens=1, return_dict_in_generate=True, output_scores=True, repetition_penalty=1)
            print(outputs.scores[0][0, strict_abcd_ids])
            output_text = ['A', 'B', 'C', 'D'][outputs.scores[0][0, strict_abcd_ids].argmax()]
        return output_text, state  # return state too, so GradioBackend callers can unpack (response, state)
```
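To isolate the decoding control used above: `ThresholdLogitsProcessor` masks the streaming-EOS token (the id of ` ...`) whenever the model's probability for it falls at or below `base_threshold + step * count`, with the cutoff shifting on every decoding step. A minimal sketch against a plain causal LM; the model name and threshold values here are illustrative assumptions, not from this commit:

```python
# Hypothetical demo of ThresholdLogitsProcessor on a small text-only model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from demo.infer import ThresholdLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B-Instruct')
model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2-0.5B-Instruct')

# Same trick infer.py uses to pick the streaming-EOS token id.
eos_like_id = tokenizer(' ...').input_ids[-1]
processor = ThresholdLogitsProcessor(token_id=eos_like_id, base_threshold=0.5, step=0.02)

inputs = tokenizer('Live commentary:', return_tensors='pt')
out = model.generate(**inputs, max_new_tokens=32, do_sample=False,
                     logits_processor=LogitsProcessorList([processor]))
print(tokenizer.decode(out[0], skip_special_tokens=True))
```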