liuguilin committed
Commit 37df18e · 1 Parent(s): d62b227
Files changed (3)
  1. app.py +41 -26
  2. eagle_vl/serve/chat_utils.py +42 -20
  3. eagle_vl/serve/inference.py +32 -13
app.py CHANGED
@@ -39,7 +39,7 @@ logger = configure_logger()
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="Eagle-2.5-8B")
+    parser.add_argument("--model", type=str, default="Eagle2.5-VL-8B-Preview")
     parser.add_argument(
         "--local-path",
         type=str,
@@ -57,7 +57,7 @@ def fetch_model(model_name: str):
     if args.local_path:
         model_path = args.local_path
     else:
-        model_path = f"nvidia/{args.model}"
+        model_path = f"NVEagle/{args.model}"
 
     if model_name in DEPLOY_MODELS:
         model_info = DEPLOY_MODELS[model_name]
@@ -100,6 +100,7 @@ def predict(
     temperature,
     max_length_tokens,
     max_context_length_tokens,
+    video_nframes,
     chunk_size: int = 512,
 ):
     """
@@ -116,18 +117,7 @@ def predict(
         max_context_length_tokens (int): The max context length tokens.
         chunk_size (int): The chunk size.
     """
-    print("running the prediction function")
-    try:
-        logger.info("fetching model")
-        model, processor = fetch_model(args.model)
-        logger.info("model fetched")
-        if text == "":
-            yield chatbot, history, "Empty context."
-            return
-    except KeyError:
-        logger.info("no model found")
-        yield [[text, "No Model Found"]], [], "No Model Found"
-        return
+
 
     if images is None:
         images = []
@@ -136,15 +126,33 @@ def predict(
     pil_images = []
     for img_or_file in images:
         try:
+            logger.info(f"img_or_file: {img_or_file}")
             # load as pil image
             if isinstance(images, Image.Image):
                 pil_images.append(img_or_file)
-            else:
-                image = Image.open(img_or_file.name).convert("RGB")
-                pil_images.append(image)
+            elif isinstance(img_or_file, str):
+                if img_or_file.endswith((".mp4", ".mov", ".avi", ".webm")):
+                    pil_images.append(img_or_file)
+                else:
+                    image = Image.open(img_or_file.name).convert("RGB")
+                    pil_images.append(image)
         except Exception as e:
             print(f"Error loading image: {e}")
 
+
+    print("running the prediction function")
+    try:
+        logger.info("fetching model")
+        model, processor = fetch_model(args.model)
+        logger.info("model fetched")
+        if text == "":
+            yield chatbot, history, "Empty context."
+            return
+    except KeyError:
+        logger.info("no model found")
+        yield [[text, "No Model Found"]], [], "No Model Found"
+        return
+
     # generate prompt
     conversation = generate_prompt_with_history(
         text,
@@ -166,6 +174,7 @@ def predict(
         max_length=max_length_tokens,
         temperature=temperature,
         top_p=top_p,
+        video_nframes=video_nframes,
     ):
         full_response += x
         response = strip_stop_words(full_response, stop_words)
@@ -174,12 +183,12 @@ def predict(
 
         yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
 
-    if last_image is not None:
-        vg_image = parse_ref_bbox(response, last_image)
-        if vg_image is not None:
-            vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
-            gradio_chatbot_output[-1][1] += vg_base64
-            yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
+    # if last_image is not None:
+    #     vg_image = parse_ref_bbox(response, last_image)
+    #     if vg_image is not None:
+    #         vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
+    #         gradio_chatbot_output[-1][1] += vg_base64
+    #         yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
 
     logger.info("flushed result to gradio")
 
@@ -202,6 +211,7 @@ def retry(
     temperature,
     max_length_tokens,
     max_context_length_tokens,
+    video_nframes,
     chunk_size: int = 512,
 ):
     """
@@ -226,6 +236,7 @@ def retry(
         temperature,
         max_length_tokens,
         max_context_length_tokens,
+        video_nframes,
         chunk_size,
     )
 
@@ -265,9 +276,10 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
                 with gr.Column():
                     # add note no more than 2 images once
                     # gr.Markdown("Note: you can upload no more than 2 images once")
-                    upload_images = gr.Files(file_types=["image"], show_label=True)
+                    upload_images = gr.Files(file_types=["image", "video"], show_label=True)
                     gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
                     upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
+
                 # Parameter Setting Tab for control the generation parameters
                 with gr.Tab(label="Parameter Setting"):
                     top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p")
@@ -280,7 +292,9 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
                     max_context_length_tokens = gr.Slider(
                         minimum=512, maximum=16384, value=4096, step=64, interactive=True, label="Max Context Length Tokens"
                     )
-
+                    video_nframes = gr.Slider(
+                        minimum=1, maximum=128, value=16, step=1, interactive=True, label="Video Nframes"
+                    )
                     show_images = gr.HTML(visible=False)
 
         gr.Examples(
@@ -298,6 +312,7 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
            temperature,
            max_length_tokens,
            max_context_length_tokens,
+           video_nframes
        ]
        output_widgets = [chatbot, history, status_display]
 
@@ -336,7 +351,7 @@ def main(args: argparse.Namespace):
    demo.queue().launch(
        favicon_path=favicon_path,
        server_name=args.ip,
-       server_port=args.port
+       server_port=args.port,
    )
 
 
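With these app.py changes, predict routes uploads by type: video file paths are passed through as plain strings while image files are decoded to PIL objects, and the new Video Nframes slider value travels alongside them into generation. A minimal standalone sketch of that routing idea (load_media and VIDEO_EXTS are hypothetical names for illustration, not part of app.py):

# Hypothetical sketch: videos stay as file paths, images are decoded to PIL objects.
from PIL import Image

VIDEO_EXTS = (".mp4", ".mov", ".avi", ".webm")

def load_media(paths):
    media = []
    for path in paths:
        if path.endswith(VIDEO_EXTS):
            media.append(path)  # video: keep the path, frame extraction is deferred downstream
        else:
            media.append(Image.open(path).convert("RGB"))  # image: decode eagerly
    return media

Deferring video decoding keeps the Gradio callback cheap; only the selected number of frames is materialized later, at preprocessing time.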
eagle_vl/serve/chat_utils.py CHANGED
@@ -13,7 +13,7 @@ import gradio as gr
 import torch
 import os
 from .utils import pil_to_base64
-
+import mimetypes
 IMAGE_TOKEN = "<image>"
 logger = logging.getLogger("gradio_logger")
 
@@ -324,6 +324,7 @@ def convert_conversation_to_prompts(conversation: Conversation):
     Convert the conversation to prompts.
     """
     conv_prompts = []
+
     last_image = None
 
     messages = conversation.messages
@@ -342,34 +343,55 @@ def convert_conversation_to_prompts(conversation: Conversation):
 
 
 def to_gradio_chatbot(conversation: Conversation) -> list:
-    """Convert the conversation to gradio chatbot format."""
+    """Convert the conversation to gradio chatbot format, supporting images and video."""
     ret = []
     for i, (_, msg) in enumerate(conversation.messages[conversation.offset :]):
+        # User message
        if i % 2 == 0:
-            if type(msg) is tuple:
-                msg, images = copy.deepcopy(msg)
-
-                if isinstance(images, list):
-                    img_str = ""
-                    for j, image in enumerate(images):
-                        if isinstance(image, str):
-                            with open(image, "rb") as f:
-                                data = f.read()
-                            img_b64_str = base64.b64encode(data).decode()
-                            image_str = (
-                                f'<img src="data:image/png;base64,{img_b64_str}" '
-                                f'alt="user upload image" style="max-width: 300px; height: auto;" />'
+            if isinstance(msg, tuple):
+                msg_text, media = copy.deepcopy(msg)
+                media_str = ""
+
+                # Handle list of media items
+                if isinstance(media, list):
+                    items = media
+                else:
+                    items = [media]
+
+                for j, item in enumerate(items):
+                    # If string path, determine type
+                    if isinstance(item, str):
+                        mime, _ = mimetypes.guess_type(item)
+                        with open(item, "rb") as f:
+                            data = f.read()
+                        b64 = base64.b64encode(data).decode()
+
+                        if mime and mime.startswith("image/"):
+                            media_str += (
+                                f'<img src="data:{mime};base64,{b64}" '
+                                f'alt="user upload image_{j}" '
+                                f'style="max-width:300px;height:auto;" />'
+                            )
+                        elif mime and mime.startswith("video/"):
+                            media_str += (
+                                f'<video controls '
+                                f'style="max-width:300px;height:auto;" '
+                                f'src="data:{mime};base64,{b64}"></video>'
                             )
                         else:
-                            image_str = pil_to_base64(image, f"user upload image_{j}", max_size=800, min_size=400)
+                            # Fallback to link
+                            media_str += f'<a href="{item}" target="_blank">{item}</a>'
 
-                        img_str += image_str
-                    msg = img_str + msg
-                else:
-                    pass
+                    # If PIL image
+                    else:
+                        media_str += pil_to_base64(item, f"user upload image_{j}", max_size=800, min_size=400)
+
+                msg = media_str + msg_text
 
+            # Append user side
            ret.append([msg, None])
        else:
+            # Assistant side, fill previous tuple
            ret[-1][-1] = msg
    return ret
 
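The rewritten to_gradio_chatbot decides between an img and a video tag by guessing the MIME type from the file name and inlining the bytes as a base64 data URI. A self-contained sketch of that logic under the same assumptions (render_media is a hypothetical helper, not part of chat_utils.py):

# Hypothetical helper mirroring the MIME-based embedding above.
import base64
import mimetypes

def render_media(path: str) -> str:
    mime, _ = mimetypes.guess_type(path)  # e.g. "image/png" or "video/mp4", based on extension
    with open(path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    if mime and mime.startswith("image/"):
        return f'<img src="data:{mime};base64,{b64}" style="max-width:300px;height:auto;" />'
    if mime and mime.startswith("video/"):
        return f'<video controls style="max-width:300px;height:auto;" src="data:{mime};base64,{b64}"></video>'
    return f'<a href="{path}" target="_blank">{path}</a>'  # unknown type: fall back to a plain link

Note that inlining whole videos as data URIs makes the chatbot HTML large for long clips; it trades payload size for not having to serve the file separately.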
eagle_vl/serve/inference.py CHANGED
@@ -12,7 +12,7 @@ from transformers import (
     StoppingCriteriaList,
     TextIteratorStreamer,
 )
-
+from PIL import Image
 from .chat_utils import Conversation, get_conv_template
 
 logger = logging.getLogger(__name__)
@@ -91,6 +91,7 @@ class StoppingCriteriaSub(StoppingCriteria):
 def preprocess(
     messages: list[dict],
     processor,
+    video_nframes: int = 16,
 ):
     """
     Build messages from the conversations and images.
@@ -110,12 +111,28 @@ def preprocess(
         if "images" in message:
             per_round_images = message["images"]
             for image in per_round_images:
-                record["content"].append(
-                    {
-                        "type": "image",
-                        "image": image,
-                    }
-                )
+                if isinstance(image, Image.Image):
+                    record["content"].append(
+                        {
+                            "type": "image",
+                            "image": image,
+                        }
+                    )
+                elif isinstance(image, str) and image.endswith((".jpeg", ".jpg", ".png", ".gif")):
+                    record["content"].append(
+                        {
+                            "type": "image",
+                            "image": image,
+                        }
+                    )
+                elif isinstance(image, str) and image.endswith((".mp4", ".mov", ".avi", ".webm")):
+                    record["content"].append(
+                        {
+                            "type": "video",
+                            "video": image,
+                            "nframes": video_nframes,
+                        }
+                    )
         if 'content' in message:
             record["content"].append(
                 {
@@ -148,12 +165,12 @@ def preprocess(
         formatted_answer.count(processor.image_token) == 0
     ), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}"
 
-    print(f"messages = {results}")
+    # print(f"messages = {results}")
     text = processor.apply_chat_template(results, add_generation_prompt=False)
-    print(f"raw text = {text}")
+    # print(f"raw text = {text}")
+
+    image_inputs, video_inputs, video_kwargs = processor.process_vision_info(results, return_video_kwargs=True)
 
-    image_inputs, video_inputs = processor.process_vision_info(results)
-
     inputs = processor(
         images=image_inputs,
         videos=video_inputs,
@@ -161,6 +178,7 @@ def preprocess(
         return_tensors="pt",
         padding=True,
         truncation=True,
+        videos_kwargs=video_kwargs,
     )
     return inputs
 
@@ -176,10 +194,11 @@ def eagle_vl_generate(
     temperature: float = 1.0,
     top_p: float = 1.0,
     chunk_size: int = -1,
+    video_nframes: int = 16,
 ):
     # convert conversation to inputs
     print(f"conversations = {conversations}")
-    inputs = preprocess(conversations, processor=processor)
+    inputs = preprocess(conversations, processor=processor, video_nframes=video_nframes)
     inputs = inputs.to(model.device)
 
     return generate(
@@ -202,7 +221,7 @@ def generate(
     temperature: float = 0,
     top_p: float = 0.95,
     stop_words: List[str] = [],
-    chunk_size: int = -1,
+    chunk_size: int = -1
 ):
     """Stream the text output from the multimodality model with prompt and image inputs."""
     tokenizer = processor.tokenizer
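With these changes, preprocess emits one content entry per media item: PIL images and image paths become "image" entries, while video paths become "video" entries carrying the user-selected frame count. An illustrative example of the message structure it builds (file names and the trailing text entry are assumptions for illustration; the image/video entries mirror the diff):

# Illustrative shape of the messages that preprocess assembles.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "photo.png"},
            {"type": "video", "video": "clip.mp4", "nframes": 16},
            {"type": "text", "text": "Describe the clip and the photo."},
        ],
    }
]

These messages are then formatted with processor.apply_chat_template(...) and expanded by processor.process_vision_info(messages, return_video_kwargs=True), with the returned video kwargs forwarded to the processor call, as shown in the hunks above.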