khang119966 committed on
Commit
888d672
·
verified ·
1 Parent(s): 6968b05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -240
app.py CHANGED
@@ -35,13 +35,140 @@ from concurrent.futures import ProcessPoolExecutor
35
 
36
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
37
 
38
- env = {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}
 
 
 
39
 
40
- subprocess.run('apt-get install -y fonts-noto-cjk', env=env, shell=True)
41
- subprocess.run('apt-get update -y', env=env, shell=True)
42
- subprocess.run('apt-get install -y wkhtmltopdf', env=env, shell=True)
43
- subprocess.run('apt-get install -y xvfb', env=env, shell=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  torch.set_default_device('cuda')
47
 
@@ -181,113 +308,6 @@ def visualize_attention_hiddenstate(attention_tensor, head=None, start_img_token
181
 
182
  return heat_maps, top_5_tokens
183
 
184
- def generate_next_token_table_image(model, tokenizer, response, index_focus):
185
- next_token_table = []
186
- for layer_index in range(len(response.hidden_states[index_focus])):
187
- h_out = model.language_model.lm_head(
188
- model.language_model.model.norm(response.hidden_states[index_focus][layer_index][0])
189
- )
190
- h_out = torch.softmax(h_out, -1)
191
- top_tokens = []
192
- for token_index in h_out.argsort(descending=True)[0, :3]: # Top 3
193
- token_str = tokenizer.decode(token_index)
194
- prob = float(h_out[0, int(token_index)])
195
- top_tokens.append((token_str, prob))
196
- next_token_table.append((layer_index, top_tokens))
197
- next_token_table = next_token_table[::-1]
198
-
199
- html_rows = ""
200
- last_layer_index = len(next_token_table) - 1
201
-
202
- for i, (layer_index, tokens) in enumerate(next_token_table):
203
- row = f"<tr><td style='font-weight: bold'>Layer {layer_index}</td>"
204
-
205
- # For the first column (Top 1)
206
- token_str, prob = tokens[0]
207
-
208
- # If this is the last layer in the table, make the text blue
209
- if layer_index == last_layer_index:
210
- row += f"<td><span style='color: red; font-weight: bold'>{token_str}</span> ({prob:.2%})</td>"
211
- else:
212
- row += f"<td><span style='color: blue; font-weight: bold'>{token_str}</span> ({prob:.2%})</td>"
213
-
214
- # For the other columns, keep normal formatting
215
- for token_str, prob in tokens[1:]:
216
- row += f"<td>{token_str} ({prob:.2%})</td>"
217
-
218
- row += "</tr>"
219
- html_rows += row
220
-
221
- html_code = f'''
222
- <html>
223
- <head>
224
- <meta charset="utf-8">
225
- <style>
226
- table {{
227
- font-family: 'Noto Sans';
228
- font-size: 12px;
229
- border-collapse: collapse;
230
- table-layout: fixed;
231
- width: 100%;
232
- }}
233
- th, td {{
234
- border: 1px solid black;
235
- padding: 8px;
236
- width: 150px;
237
- height: 30px;
238
- overflow: hidden;
239
- text-overflow: ellipsis;
240
- white-space: nowrap;
241
- text-align: center;
242
- }}
243
- th.layer {{
244
- width: 100px;
245
- }}
246
- th.title {{
247
- font-size: 14px;
248
- padding: 10px;
249
- height: auto;
250
- white-space: normal;
251
- overflow: visible;
252
- }}
253
- </style>
254
- </head>
255
- <body style="background-color: white;">
256
- <table>
257
- <tr>
258
- <th colspan="4" class="title">
259
- Top hidden tokens per layer for the Prediction
260
- </th>
261
- </tr>
262
- <tr>
263
- <th class="layer">Layer ⬆️</th>
264
- <th>Top 1</th>
265
- <th>Top 2</th>
266
- <th>Top 3</th>
267
- </tr>
268
- {html_rows}
269
- </table>
270
- </body>
271
- </html>
272
- '''
273
-
274
-
275
- with tempfile.TemporaryDirectory() as tmpdir:
276
- hti = Html2Image(output_path=tmpdir)
277
- hti.browser_flags = [
278
- "--headless=new", # ← Dùng chế độ headless mới
279
- "--disable-gpu", # ← Tắt GPU
280
- "--disable-software-rasterizer", # ← Tránh dùng fallback GPU software
281
- "--no-sandbox", # ← Tránh lỗi sandbox đa luồng
282
- ]
283
- filename = str(uuid.uuid4())+".png"
284
- # filename = 'next_token_table.png'
285
- hti.screenshot(html_str=html_code, save_as=filename, size=(500, 1000))
286
- img_path = os.path.join(tmpdir, filename)
287
- img_cv2 = cv2.imread(img_path)[:,:,::-1]
288
- os.remove(img_path)
289
- return img_cv2
290
-
291
  def adjust_overlay(overlay, text_img):
292
  h_o, w_o = overlay.shape[:2]
293
  h_t, w_t = text_img.shape[:2]
@@ -313,36 +333,6 @@ def adjust_overlay(overlay, text_img):
313
 
314
  return overlay_resized
315
 
316
- def generate_text_image_with_html2image(old_text, input_token, new_token, image_width=400, min_height=1000, font_size=16):
317
- full_text = old_text + f"<span style='color:blue; font-weight:bold'>[{input_token}]</span>"+ "→" + f"<span style='color:red; font-weight:bold'>[{new_token}]</span>"
318
-
319
- # Thay \n bằng thẻ HTML <br> để xuống dòng
320
- full_text = full_text.replace('\n', '<br>')
321
-
322
- html_code = f'''
323
- <html>
324
- <head>
325
- <meta charset="utf-8">
326
- </head>
327
- <body style="font-family: 'DejaVu Sans', sans-serif; font-size: {font_size}px; width: {image_width}px; min-height: {min_height}px; padding: 10px; background-color: white; line-height: 1.4;">
328
- {full_text}
329
- </body>
330
- </html>
331
- '''
332
- save_path = str(uuid.uuid4())+".png"
333
- hti = Html2Image(output_path='.')
334
- hti.browser_flags = [
335
- "--headless=new", # ← Dùng chế độ headless mới
336
- "--disable-gpu", # ← Tắt GPU
337
- "--disable-software-rasterizer", # ← Tránh dùng fallback GPU software
338
- "--no-sandbox", # ← Tránh lỗi sandbox đa luồng
339
- ]
340
- hti.screenshot(html_str=html_code, save_as=save_path, size=(image_width, min_height))
341
- text_img = cv2.imread(save_path)
342
- text_img = cv2.cvtColor(text_img, cv2.COLOR_BGR2RGB)
343
- os.remove(save_path)
344
- return text_img
345
-
346
  def extract_next_token_table_data(model, tokenizer, response, index_focus):
347
  next_token_table = []
348
  for layer_index in range(len(response.hidden_states[index_focus])):
@@ -359,98 +349,6 @@ def extract_next_token_table_data(model, tokenizer, response, index_focus):
359
  next_token_table = next_token_table[::-1]
360
  return next_token_table
361
 
362
- def render_next_token_table_image(table_data, predict_token):
363
- import tempfile, uuid, os
364
- from html2image import Html2Image
365
- import cv2
366
-
367
- html_rows = ""
368
- last_layer_index = len(table_data)
369
- for layer_index, tokens in table_data:
370
- row = f"<tr><td style='font-weight: bold'>Layer {layer_index+1}</td>"
371
-
372
- token_str, prob = tokens[0]
373
- if token_str == predict_token:
374
- style = "color: red; font-weight: bold"
375
- else:
376
- style = "color: blue; font-weight: bold"
377
- row += f"<td><span style='{style}'>{token_str}</span> ({prob:.2%})</td>"
378
-
379
- for token_str, prob in tokens[1:]:
380
- row += f"<td>{token_str} ({prob:.2%})</td>"
381
-
382
- row += "</tr>"
383
- html_rows += row
384
-
385
- html_code = f'''
386
- <html>
387
- <head>
388
- <meta charset="utf-8">
389
- <style>
390
- table {{
391
- font-family: 'Noto Sans';
392
- font-size: 12px;
393
- border-collapse: collapse;
394
- table-layout: fixed;
395
- width: 100%;
396
- }}
397
- th, td {{
398
- border: 1px solid black;
399
- padding: 8px;
400
- width: 150px;
401
- height: 30px;
402
- overflow: hidden;
403
- text-overflow: ellipsis;
404
- white-space: nowrap;
405
- text-align: center;
406
- }}
407
- th.layer {{
408
- width: 100px;
409
- }}
410
- th.title {{
411
- font-size: 14px;
412
- padding: 10px;
413
- height: auto;
414
- white-space: normal;
415
- overflow: visible;
416
- }}
417
- </style>
418
- </head>
419
- <body style="background-color: white;">
420
- <table>
421
- <tr>
422
- <th colspan="4" class="title">
423
- Hidden states per Transformer layer (LLM) for Prediction
424
- </th>
425
- </tr>
426
- <tr>
427
- <th class="layer">Layer ⬆️</th>
428
- <th>Top 1</th>
429
- <th>Top 2</th>
430
- <th>Top 3</th>
431
- </tr>
432
- {html_rows}
433
- </table>
434
- </body>
435
- </html>
436
- '''
437
-
438
- with tempfile.TemporaryDirectory() as tmpdir:
439
- hti = Html2Image(output_path=tmpdir)
440
- hti.browser_flags = [
441
- "--headless=new",
442
- "--disable-gpu",
443
- "--disable-software-rasterizer",
444
- "--no-sandbox",
445
- ]
446
- filename = str(uuid.uuid4()) + ".png"
447
- hti.screenshot(html_str=html_code, save_as=filename, size=(500, 1000))
448
- img_path = os.path.join(tmpdir, filename)
449
- img_cv2 = cv2.imread(img_path)[:, :, ::-1]
450
- os.remove(img_path)
451
- return img_cv2
452
-
453
-
454
  model = AutoModel.from_pretrained(
455
  "khang119966/Vintern-1B-v3_5-explainableAI",
456
  torch_dtype=torch.bfloat16,
@@ -460,9 +358,8 @@ model = AutoModel.from_pretrained(
460
  ).eval().cuda()
461
  tokenizer = AutoTokenizer.from_pretrained("khang119966/Vintern-1B-v3_5-explainableAI", trust_remote_code=True, use_fast=False)
462
 
463
- # Hàm bao để truyền vào multiprocessing
464
  def generate_text_img_wrapper(args):
465
- return generate_text_image_with_html2image(*args, image_width=500, min_height=1000)
466
 
467
  def generate_hidden_img_wrapper(args):
468
  return render_next_token_table_image(*args)
@@ -568,16 +465,21 @@ def generate_video(image, prompt, max_tokens):
568
  for frame in visualization_frames:
569
  frame = cv2.resize(frame,(visualization_frames[0].shape[1],visualization_frames[0].shape[0]))
570
  resized_visualization_frames.append(frame)
571
-
572
  # Lưu thành video MP4 bằng imageio
573
  imageio.mimsave(
574
- 'heatmap_animation.mp4',
575
  resized_visualization_frames, # dạng RGB
576
  fps=5
577
  )
578
-
579
 
580
- return "heatmap_animation.mp4"
 
 
 
 
 
 
581
 
582
  with gr.Blocks() as demo:
583
  gr.Markdown("""# 🎥 Visualizing How Multimodal Models Think
 
35
 
36
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
37
 
38
+ from PIL import Image, ImageDraw, ImageFont
39
+ import textwrap
40
+ import uuid
41
+ import os
42
 
43
def generate_text_image_with_pil(old_text, input_token, new_token, image_width=400, min_height=1000, font_size=16):
    """Render the text so far plus a colored ``[input]→[new]`` marker as an image.

    Every newline-separated paragraph of ``old_text`` is wrapped at 60
    characters and drawn in black. The marker is drawn directly after the
    last wrapped line (or on a fresh line when it would not fit inside the
    wrap width): ``[input_token]`` in blue, the arrow in black and
    ``[new_token]`` in red.

    The marker is laid out as a single unit and is never handed to
    ``textwrap``. Tokens may contain spaces or newlines; wrapping the marker
    together with the text let ``textwrap`` rewrite or split it, after which
    it no longer matched and was drawn without its colors.

    Args:
        old_text: Text to show before the marker; ``"\\n"`` starts a new line.
        input_token: Token shown in blue inside the first pair of brackets.
        new_token: Token shown in red inside the second pair of brackets.
        image_width: Canvas width in pixels.
        min_height: Canvas height in pixels. The canvas does not grow, so
            lines that fall below this height are not visible.
        font_size: Font size; the line pitch is ``font_size + 8`` pixels.

    Returns:
        ``numpy`` array built from the RGB canvas (``min_height`` rows by
        ``image_width`` columns).

    Raises:
        OSError: If the Noto Sans CJK font file cannot be loaded.
    """
    import textwrap
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    wrap_width = 60
    left_margin = 10
    top_margin = 10
    line_pitch = font_size + 8

    # textwrap turns each of these characters into a plain space; do the same
    # for the marker so a newline/tab token cannot produce multi-line text
    # (which draw.textlength refuses to measure).
    to_space = str.maketrans({ch: " " for ch in "\t\n\x0b\x0c\r"})
    input_label = f"[{input_token}]".translate(to_space)
    new_label = f"[{new_token}]".translate(to_space)
    arrow = "→"
    marker_chars = len(input_label) + len(arrow) + len(new_label)

    # str.split always yields at least one element, so the unpacking is safe.
    *leading_paragraphs, final_paragraph = old_text.split("\n")

    plain_lines = []
    for paragraph in leading_paragraphs:
        # Blank paragraphs wrap to [] and are kept as an empty line.
        plain_lines.extend(textwrap.wrap(paragraph, width=wrap_width) or [""])

    final_lines = textwrap.wrap(final_paragraph, width=wrap_width)
    marker_prefix = ""
    if final_lines:
        candidate = final_lines[-1]
        # textwrap strips trailing blanks; keep one so the marker does not
        # touch the preceding word when the text ended with whitespace.
        if final_paragraph[-1:].isspace():
            candidate += " "
        if len(candidate) + marker_chars <= wrap_width:
            final_lines.pop()
            marker_prefix = candidate
    plain_lines.extend(final_lines)

    img = Image.new('RGB', (image_width, min_height), color='white')
    draw = ImageDraw.Draw(img)

    font_path = "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc"
    font = ImageFont.truetype(font_path, font_size)

    y = top_margin
    for line in plain_lines:
        if line:
            draw.text((left_margin, y), line, fill="black", font=font)
        y += line_pitch

    # Marker line: optional black prefix, then the colored segments placed
    # one after another using their measured pixel widths.
    x = left_margin
    segments = (
        (marker_prefix, "black"),
        (input_label, "blue"),
        (arrow, "black"),
        (new_label, "red"),
    )
    for text, color in segments:
        if not text:
            continue
        draw.text((x, y), text, fill=color, font=font)
        x += draw.textlength(text, font=font)

    return np.array(img)
111
+
112
+
113
+ from PIL import Image, ImageDraw, ImageFont
114
+
115
+
116
def render_next_token_table_image(table_data, predict_token, image_width=500, row_height=40, font_size=14):
    """Draw the per-layer top-token table as an image.

    The table has a title row, a header row (``Layer``, ``Top 1`` ..
    ``Top 3``) and one row per entry of ``table_data``. In the ``Top 1``
    column the token is red when it equals ``predict_token`` and blue
    otherwise; the remaining columns are black. Newlines, spaces and tabs
    inside a token are shown as ``\\n``, ``\\s`` and ``\\t``.

    Cell borders are clamped to the last pixel row/column of the canvas.
    Without the clamp the right edge of the last column and the bottom edge
    of the last row fall one pixel outside the image and are not drawn.

    Args:
        table_data: Iterable of ``(layer_index, tokens)`` where ``tokens`` is
            a sequence of ``(token_str, prob)`` pairs; at most the first
            three pairs are shown. The layer is labelled ``layer_index + 1``.
        predict_token: Token string that marks a matching ``Top 1`` cell red.
        image_width: Canvas width in pixels; split into four equal columns.
        row_height: Height of every table row in pixels.
        font_size: Font size used for all text.

    Returns:
        ``numpy`` array built from the RGB canvas
        (``(len(table_data) + 2) * row_height`` rows by ``image_width``
        columns).

    Raises:
        OSError: If the Noto Sans CJK font file cannot be loaded.
    """
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    # Font with multilingual coverage (adjust the path if needed).
    font_path = "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc"
    font = ImageFont.truetype(font_path, font_size)

    num_cols = 4  # Layer | Top1 | Top2 | Top3
    num_rows = len(table_data) + 2  # +2 for the title and header rows
    table_width = image_width
    col_width = table_width // num_cols
    table_height = num_rows * row_height

    img = Image.new("RGB", (table_width, table_height), "white")
    draw = ImageDraw.Draw(img)

    def draw_border(x0, y0, x1, y1):
        # Keep the far edges on the canvas so the outer border stays visible.
        right = min(x1, table_width - 1)
        bottom = min(y1, table_height - 1)
        draw.rectangle([x0, y0, right, bottom], outline="black")

    def draw_cell(x, y, text, color="black"):
        # 5px padding from the cell's top-left corner.
        draw.text((x + 5, y + 5), text, font=font, fill=color)

    # Title row spanning the full width.
    draw_border(0, 0, table_width, row_height)
    draw_cell(5, 5, "Hidden states per Transformer layer (LLM) for Prediction")

    # Column headers.
    headers = ["Layer ⬆️", "Top 1", "Top 2", "Top 3"]
    for col, header in enumerate(headers):
        x0 = col * col_width
        draw_border(x0, row_height, x0 + col_width, 2 * row_height)
        draw_cell(x0, row_height, header)

    # One row per layer, starting below the title and header rows.
    for row, (layer_index, tokens) in enumerate(table_data, start=2):
        y = row * row_height

        for col in range(num_cols):
            x = col * col_width
            draw_border(x, y, x + col_width, y + row_height)

        draw_cell(0, y, f"Layer {layer_index+1}")

        for col, (token_str, prob) in enumerate(tokens[:num_cols - 1], start=1):
            if col != 1:
                color = "black"
            elif token_str == predict_token:
                color = "red"
            else:
                color = "blue"
            # Make whitespace-only tokens visible in the cell.
            shown = token_str.replace("\n", "\\n").replace(" ", "\\s").replace("\t", "\\t")
            draw_cell(col * col_width, y, f"{shown} ({prob:.1%})", color=color)

    return np.array(img)
171
+
172
 
173
  torch.set_default_device('cuda')
174
 
 
308
 
309
  return heat_maps, top_5_tokens
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  def adjust_overlay(overlay, text_img):
312
  h_o, w_o = overlay.shape[:2]
313
  h_t, w_t = text_img.shape[:2]
 
333
 
334
  return overlay_resized
335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  def extract_next_token_table_data(model, tokenizer, response, index_focus):
337
  next_token_table = []
338
  for layer_index in range(len(response.hidden_states[index_focus])):
 
349
  next_token_table = next_token_table[::-1]
350
  return next_token_table
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  model = AutoModel.from_pretrained(
353
  "khang119966/Vintern-1B-v3_5-explainableAI",
354
  torch_dtype=torch.bfloat16,
 
358
  ).eval().cuda()
359
  tokenizer = AutoTokenizer.from_pretrained("khang119966/Vintern-1B-v3_5-explainableAI", trust_remote_code=True, use_fast=False)
360
 
 
361
  def generate_text_img_wrapper(args):
362
+ return generate_text_image_with_pil(*args, image_width=500, min_height=1000)
363
 
364
  def generate_hidden_img_wrapper(args):
365
  return render_next_token_table_image(*args)
 
465
  for frame in visualization_frames:
466
  frame = cv2.resize(frame,(visualization_frames[0].shape[1],visualization_frames[0].shape[0]))
467
  resized_visualization_frames.append(frame)
468
+
469
  # Lưu thành video MP4 bằng imageio
470
  imageio.mimsave(
471
+ 'heatmap_with_music.mp4',
472
  resized_visualization_frames, # dạng RGB
473
  fps=5
474
  )
 
475
 
476
+ # Nối video và nhạc
477
+ video = VideoFileClip("heatmap_animation.mp4")
478
+ audio = AudioFileClip("legacy-of-the-century-background-cinematic-music-for-video-46-second-319542.mp3").set_duration(video.duration)
479
+ final = video.set_audio(audio)
480
+ final.write_videofile("heatmap_with_music.mp4", codec="libx264", audio_codec="aac")
481
+
482
+ return "heatmap_with_music.mp4"
483
 
484
  with gr.Blocks() as demo:
485
  gr.Markdown("""# 🎥 Visualizing How Multimodal Models Think