aisafe committed on
Commit 0876906 · verified · 1 parent: fada0da

Create app.py

Files changed (1)
  1. app.py +803 -0
app.py ADDED
@@ -0,0 +1,803 @@
+ import re
+ import time
+
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import requests
+ import streamlit as st
+ import tiktoken
+ import torch
+ from tqdm import tqdm
+ from transformers import GPT2LMHeadModel, AutoTokenizer, AutoModelForSeq2SeqLM
+ # import openai  # only needed for the commented-out OpenAI fallback below
+
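+ # App flow, in brief:
+ #  1. The user picks (or types) a prompt.
+ #  2. flair/chunk-english (via the HF Inference API) shallow-parses the prompt;
+ #     [PAUSE] markers are inserted after prepositional-phrase (PP) chunks.
+ #  3. Integrated gradients over GPT-2 embeddings give per-token importance
+ #     scores, rendered as text heatmaps and bar charts.
+ #  4. A T5 paraphraser (humarin/chatgpt_paraphraser_on_T5_base) produces five
+ #     rephrasings, each analyzed the same way alongside text generated by
+ #     bigscience/bloom.
+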
+ ######################################
+ def colorize_tokens(token_data, sentence):
+     colored_sentence = ""
+     start = 0
+
+     for token in token_data:
+         entity_group = token["entity_group"]
+         word = token["word"]
+         tag = f"[{entity_group}]"
+         tag_color = tag_colors.get(entity_group, "white")  # default to white if color not found
+         colored_chunk = f'<span style="color:black;background-color:{tag_color}">{word} {tag}</span>'
+         colored_sentence += sentence[start:token["start"]] + colored_chunk
+         start = token["end"]
+
+     # Add the remaining part of the sentence
+     colored_sentence += sentence[start:]
+
+     return colored_sentence
+
+ # Define colors for the tags
+ tag_colors = {
+     "ADJP": "#0000FF",  # Blue
+     "ADVP": "#008000",  # Green
+     "CONJP": "#FF0000", # Red
+     "INTJ": "#00FFFF",  # Cyan
+     "LST": "#FF00FF",   # Magenta
+     "NP": "#FFFF00",    # Yellow
+     "PP": "#800080",    # Purple
+     "PRT": "#00008B",   # Dark Blue
+     "SBAR": "#006400",  # Dark Green
+     "VP": "#008B8B",    # Dark Cyan
+ }
+ ##################
+
+ ###################
+ def generate_tagged_sentence(sentence, entity_tags):
+     # Create a list to hold the tagged tokens
+     tagged_tokens = []
+
+     # Process the entity tags to annotate the sentence
+     for tag in entity_tags:
+         start = tag['start']
+         end = tag['end']
+         token = sentence[start - 1:end]  # prompts are always passed in with a leading space, so start - 1 is safe
+         tag_name = f"[{tag['entity_group']}]"
+
+         tagged_tokens.append(f"{token} {tag_name}")
+
+     # Return the tagged sentence
+     return " ".join(tagged_tokens)
+
+
+ def replace_pp_with_pause(sentence, entity_tags):
+     # Create a list to hold the tagged tokens
+     tagged_tokens = []
+
+     # Replace [PP] tags with [PAUSE] and drop all other tags
+     for tag in entity_tags:
+         start = tag['start']
+         end = tag['end']
+         token = sentence[start - 1:end]  # see note in generate_tagged_sentence
+         if tag['entity_group'] == 'PP':
+             tag_name = '[PAUSE]'
+         else:
+             tag_name = ''
+
+         tagged_tokens.append(f"{token}{tag_name}")
+
+     # Return the sentence with the [PAUSE] replacements
+     return " ".join(tagged_tokens)
+
+
+ def get_split_sentences(sentence, entity_tags):
+     split_sentences = []
+
+     # Holds the chunk tokens of the segment being built
+     current_sentence = []
+
+     # Split the sentence at every PP chunk
+     for tag in entity_tags:
+         start = tag['start']
+         end = tag['end']
+         token = sentence[start - 1:end]  # see note in generate_tagged_sentence
+         current_sentence.append(token)
+         if tag['entity_group'] == 'PP':
+             split_sentences.append(" ".join(current_sentence))
+             current_sentence = []  # reset for the next segment
+
+     # If the sentence does not end at a PP chunk, add the final segment
+     if current_sentence:
+         split_sentences.append(" ".join(current_sentence))
+
+     return split_sentences
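+ # Illustrative example: if the chunker tags "for" in
+ # " Robert used PDF for his math homework." as a PP chunk, then
+ # replace_pp_with_pause() yields roughly "Robert used PDF for[PAUSE] his math homework."
+ # and get_split_sentences() yields roughly ["Robert used PDF for", "his math homework."]
+ # (exact spacing depends on the chunker's character offsets).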
+ ##################
+
+ st.set_page_config(page_title="Hallucination", layout="wide")
+ st.title(':blue[Sorry, come again! This time slowly, please]')
+ st.header("Rephrasing LLM Prompts for Better Comprehension Reduces :blue[Hallucination]")
+ ############################
+ with open('machine.mp4', 'rb') as f:
+     video_bytes1 = f.read()
+ with open('Pause 3 Out1.mp4', 'rb') as f:
+     video_bytes2 = f.read()
+ col1a, col1b = st.columns(2)
+ with col1a:
+     st.caption("Original")
+     st.video(video_bytes1)
+ with col1b:
+     st.caption("Paraphrased, with [PAUSE] added")
+     st.video(video_bytes2)
+ #############################
+ HF_SPACES_API_KEY = st.secrets["HF_token"]
+
+ # API_URL = "https://api-inference.huggingface.co/models/openlm-research/open_llama_3b"
+ API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
+ # The Inference API expects "Bearer <token>"; adjust if the secret already includes the prefix
+ headers = {"Authorization": f"Bearer {HF_SPACES_API_KEY}"}
+
+ def query(payload):
+     response = requests.post(API_URL, headers=headers, json=payload)
+     return response.json()
+
+ API_URL_chunk = "https://api-inference.huggingface.co/models/flair/chunk-english"
+
+ def query_chunk(payload):
+     response = requests.post(API_URL_chunk, headers=headers, json=payload)
+     return response.json()
+
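+ # The chunk endpoint returns a list of spans shaped like (illustrative):
+ # [{"entity_group": "NP", "score": 0.99, "word": "the ships", "start": 27, "end": 36}, ...]
+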
+ from tenacity import (
+     retry,
+     stop_after_attempt,
+     wait_random_exponential,
+ )  # for exponential backoff
+ # OpenAI alternative for text generation:
+ # openai.api_key = f"{st.secrets['OpenAI_API']}"
+ # model_engine = "gpt-4"
+ # @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
+ # def get_answers(prompt):
+ #     completion = openai.ChatCompletion.create(
+ #         model='gpt-3.5-turbo',
+ #         messages=[
+ #             {'role': 'user', 'content': prompt}
+ #         ],
+ #         temperature=0, max_tokens=200,
+ #     )
+ #     return completion['choices'][0]['message']['content']
+ prompt = '''Generate a story from the given text.
+ Text : '''
+ # paraphrase_prompt = '''Rephrase the given text: '''
+
+ # _gpt3tokenizer = tiktoken.get_encoding("cl100k_base")
+ ##########################
+
+ def render_heatmap(original_text, importance_scores_df):
+     # Extract the importance scores
+     importance_values = importance_scores_df['importance_value'].values
+
+     # Guard against division by zero during normalization
+     min_val = np.min(importance_values)
+     max_val = np.max(importance_values)
+
+     if max_val - min_val != 0:
+         normalized_importance_values = (importance_values - min_val) / (max_val - min_val)
+     else:
+         normalized_importance_values = np.zeros_like(importance_values)  # fallback: all-zero array
+
+     # Colormap for the heatmap ("Blues")
+     cmap = matplotlib.colormaps['Blues']
+
+     # Pick a readable text color for a given background color
+     def get_text_color(bg_color):
+         brightness = 0.299 * bg_color[0] + 0.587 * bg_color[1] + 0.114 * bg_color[2]
+         return 'white' if brightness < 0.5 else 'black'
+
+     # Walk the original text and the token list in parallel
+     original_pointer = 0
+     token_pointer = 0
+
+     # Build an HTML representation
+     html = ""
+     while original_pointer < len(original_text) and token_pointer < len(importance_scores_df):
+         token = importance_scores_df.loc[token_pointer, 'token']
+         if original_pointer == original_text.find(token, original_pointer):
+             importance = normalized_importance_values[token_pointer]
+             rgba = cmap(importance)
+             bg_color = rgba[:3]
+             text_color = get_text_color(bg_color)
+             html += f'<span style="background-color: rgba({int(bg_color[0]*255)}, {int(bg_color[1]*255)}, {int(bg_color[2]*255)}, 1); color: {text_color};">{token}</span>'
+             original_pointer += len(token)
+             token_pointer += 1
+         else:
+             html += original_text[original_pointer]
+             original_pointer += 1
+     # Emit whatever text remains once all tokens are placed
+     html += original_text[original_pointer:]
+
+     st.markdown(html, unsafe_allow_html=True)
+
+ ##########################
+ # Prompt selector
+
+ prompt_list = ["Which individuals possessed the ships that were part of the Boston Tea Party?",
+                "Freddie Frith", "Robert used PDF for his math homework."
+                ]
+
+ options = [f"Prompt #{i+1}: {prompt_list[i]}" for i in range(len(prompt_list))] + ["Another Prompt..."]
+ selection = st.selectbox("Choose a prompt from the dropdown below. Click on :blue['Another Prompt...'] if you want to enter your own custom prompt.", options=options)
+ if selection == "Another Prompt...":
+     check = st.text_input("Enter your custom prompt...")
+     check = " " + check
+     if check.strip():
+         st.caption(f""":white_check_mark: Your input prompt is: {check}""")
+         st.caption(':green[Kindly hold on for a few minutes while the AI text is being generated]')
+ else:
+     check = re.split(r'#\d+:', selection, maxsplit=1)[1]
+     if check:
+         st.caption(f""":white_check_mark: Your input prompt is: {check}""")
+         st.caption(':green[Kindly hold on for a few minutes while the AI text is being generated]')
+
+ # @st.cache_data
+ def load_chunk_model(check):
+     """Query the chunking model, retrying while the hosted model is still loading."""
+     iden = ['error']
+     while 'error' in iden:
+         time.sleep(1)
+         try:
+             output = query_chunk({"inputs": check})
+             iden = output  # the API returns {"error": ...} while the model is loading
+         except Exception as e:
+             print(f"An exception occurred: {e}")
+
+     return output
+
+ ##################################
+
+ # @st.cache_resource
+ def load_text_gen_model(check):
+     """Generate a continuation for the prompt with the hosted BLOOM model."""
+     iden = ['error']
+     while 'error' in iden:
+         time.sleep(1)
+         try:
+             output = query({
+                 "inputs": check,
+                 "parameters": {
+                     "min_new_tokens": 30,
+                     "max_new_tokens": 100,
+                     "do_sample": True,
+                     # "remove_invalid_values": True
+                     # "temperature": 0.6
+                     # "top_k": 1
+                     # "num_beams": 2,
+                     # "no_repeat_ngram_size": 2,
+                     # "early_stopping": True
+                 }
+             })
+             iden = output  # the API returns {"error": ...} while the model is loading
+         except Exception as e:
+             print(f"An exception occurred: {e}")
+
+     return output[0]['generated_text']
+
+ # OpenAI alternative:
+ # @st.cache_data
+ # def load_text_gen_model(check):
+ #     return get_answers(prompt + check)
+
+ def decoded_tokens(string, tokenizer):
+     return [tokenizer.decode([x]) for x in tokenizer.encode(string)]
+
+ @st.cache_data
+ def analyze_heatmap(df_input):
+     df = df_input.copy()
+     df["Position"] = range(len(df))
+
+     # Get the Blues colormap
+     blues = matplotlib.colormaps["Blues"]
+     # Create a Matplotlib figure and axis
+     fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Normalize the importance values
+     min_val = df["importance_value"].min()
+     max_val = df["importance_value"].max()
+     normalized_values = (df["importance_value"] - min_val) / (max_val - min_val)
+
+     # Draw one bar per token, colored by its normalized importance_value
+     for i, (token, norm_value) in enumerate(zip(df["token"], normalized_values)):
+         color = blues(norm_value)
+         ax.bar(
+             x=[i],  # use the index for the x-axis
+             height=[df["importance_value"].iloc[i]],
+             width=1.0,  # make the bars touch each other
+             color=[color],
+         )
+
+     # Additional styling
+     ax.set_title("Importance Score per Token", size=25)
+     ax.set_xlabel("Token")
+     ax.set_ylabel("Importance Value")
+     ax.set_xticks(range(len(df["token"])))
+     ax.set_xticklabels(df["token"], rotation=45)
+
+     return fig
+
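+ # Integrated gradients (Sundararajan et al., 2017), in brief: attribute the
+ # model output to each input token as (input - baseline) * average gradient,
+ # where the average is taken along the straight-line path from the baseline
+ # embedding to the input embedding (approximated below with n_steps points).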
+ ############################
+ # @st.cache_data
+ def integrated_gradients(input_ids, baseline, model, n_steps=10):  # 100
+     # Convert input_ids and baseline to LongTensors
+     input_ids = input_ids.long()
+     baseline = baseline.long()
+
+     # Tensor to accumulate gradients across interpolation steps
+     accumulated_grads = None
+
+     # Create interpolated inputs between the baseline and the input
+     alphas = torch.linspace(0, 1, n_steps)
+     delta = input_ids - baseline
+     interpolates = [(baseline + (alpha * delta).long()).long() for alpha in alphas]  # explicitly cast to LongTensor
+
+     # Initialize tqdm progress bar
+     pbar = tqdm(total=n_steps, desc="Calculating Integrated Gradients")
+
+     for interpolate in interpolates:
+         # Update tqdm progress bar
+         pbar.update(1)
+
+         # Convert the interpolated sample to embeddings
+         interpolate_embedding = model.transformer.wte(interpolate).clone().detach().requires_grad_(True)
+
+         # Forward pass
+         output = model(inputs_embeds=interpolate_embedding, output_attentions=False)[0]
+
+         # Aggregate the logits across all positions (sum, in this example)
+         aggregated_logit = output.sum()
+
+         # Backward pass to calculate gradients
+         aggregated_logit.backward()
+
+         # Accumulate gradients
+         if accumulated_grads is None:
+             accumulated_grads = interpolate_embedding.grad.clone()
+         else:
+             accumulated_grads += interpolate_embedding.grad
+
+         # Clear gradients
+         model.zero_grad()
+         interpolate_embedding.grad.zero_()
+
+     # Close tqdm progress bar
+     pbar.close()
+
+     # Average the gradients over all steps
+     avg_grads = accumulated_grads / n_steps
+
+     # Compute attributions
+     with torch.no_grad():
+         input_embedding = model.transformer.wte(input_ids)
+         baseline_embedding = model.transformer.wte(baseline)
+         attributions = (input_embedding - baseline_embedding) * avg_grads
+
+     return attributions
+
+ # @st.cache_data
+ def process_integrated_gradients(input_text, _gpt2tokenizer, model):
+     inputs = torch.tensor([_gpt2tokenizer.encode(input_text)])
+
+     gpt2tokens = decoded_tokens(input_text, _gpt2tokenizer)
+
+     # Initialize the baseline as a zero tensor
+     baseline = torch.zeros_like(inputs).long()
+
+     # Compute Integrated Gradients targeting the aggregated sequence output
+     attributions = integrated_gradients(inputs, baseline, model)
+
+     # Sum across the embedding dimensions to get a single attribution score per token
+     attributions_sum = attributions.sum(axis=2).squeeze(0).detach().numpy()
+
+     # L2-normalize, then clamp negative scores to zero
+     l2_norm_attributions = np.linalg.norm(attributions_sum, 2)
+     normalized_attributions_sum = attributions_sum / l2_norm_attributions
+
+     clamped_attributions_sum = np.where(normalized_attributions_sum < 0, 0, normalized_attributions_sum)
+
+     attribution_df = pd.DataFrame({
+         'token': gpt2tokens,
+         'importance_value': clamped_attributions_sum
+     })
+     return attribution_df
+ ########################
+ model_version = 'gpt2'
+ model = GPT2LMHeadModel.from_pretrained(model_version, output_attentions=True)
+ _gpt2tokenizer = tiktoken.get_encoding("gpt2")
+ #######################
+ para_tokenizer = AutoTokenizer.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
+ para_model = AutoModelForSeq2SeqLM.from_pretrained("humarin/chatgpt_paraphraser_on_T5_base")
+ ######################
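+ # Note: tiktoken's "gpt2" encoding implements the same BPE as the Hugging Face
+ # GPT-2 tokenizer, so the ids it produces can be fed straight into the model.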
+ @st.cache_resource
+ def paraphrase(
+     question,
+     num_beams=5,
+     num_beam_groups=5,
+     num_return_sequences=5,
+     repetition_penalty=10.0,
+     diversity_penalty=3.0,
+     no_repeat_ngram_size=2,
+     temperature=0.7,
+     max_length=64  # 128
+ ):
+     input_ids = para_tokenizer(
+         f'paraphrase: {question}',
+         return_tensors="pt", padding="longest",
+         max_length=max_length,
+         truncation=True,
+     ).input_ids
+
+     outputs = para_model.generate(
+         input_ids, temperature=temperature, repetition_penalty=repetition_penalty,
+         num_return_sequences=num_return_sequences, no_repeat_ngram_size=no_repeat_ngram_size,
+         num_beams=num_beams, num_beam_groups=num_beam_groups,
+         max_length=max_length, diversity_penalty=diversity_penalty
+     )
+
+     res = para_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+     return res
+
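+ # Example (illustrative): paraphrase("Robert used PDF for his math homework.")
+ # returns num_return_sequences (5) diverse rewrites, produced by diverse beam
+ # search (num_beams split into num_beam_groups, with diversity_penalty=3.0).
+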
+ ###########################
+
+ class SentenceAnalyzer:
+     def __init__(self, check, original, _gpt2tokenizer, model):
+         self.check = check
+         self.original = original
+         self._gpt2tokenizer = _gpt2tokenizer
+         self.model = model
+         self.entity_tags = load_chunk_model(check)
+         self.tagged_sentence = generate_tagged_sentence(check, self.entity_tags)
+         self.sentence_with_pause = replace_pp_with_pause(check, self.entity_tags)
+         self.split_sentences = get_split_sentences(check, self.entity_tags)
+         self.colored_output = colorize_tokens(self.entity_tags, check)
+
+     def analyze(self):
+         # Heatmap and bar chart for the full prompt
+         attribution_df1 = process_integrated_gradients(self.check, self._gpt2tokenizer, self.model)
+         st.caption(f":blue[{self.original}]:")
+         render_heatmap(self.check, attribution_df1)
+         st.pyplot(analyze_heatmap(attribution_df1))
+
+         # Re-run the attribution on each PP-delimited segment
+         dataframes_list = []
+         for i, split_sentence in enumerate(self.split_sentences):
+             attribution_df1 = process_integrated_gradients(split_sentence, self._gpt2tokenizer, self.model)
+             if i < len(self.split_sentences) - 1:
+                 # Add a [PAUSE] row with importance 0 after each segment
+                 pause_row = pd.DataFrame({'token': '[PAUSE]', 'importance_value': 0}, index=[len(attribution_df1)])
+                 attribution_df1 = pd.concat([attribution_df1, pause_row], ignore_index=True)
+
+             dataframes_list.append(attribution_df1)
+
+         # Concatenate the per-segment dataframes and plot the combined result
+         if dataframes_list:
+             combined_dataframe = pd.concat(dataframes_list, axis=0)
+             combined_dataframe = combined_dataframe[combined_dataframe['token'] != " "].reset_index(drop=True)
+             combined_dataframe1 = combined_dataframe[combined_dataframe['token'] != "[PAUSE]"]
+             combined_dataframe1 = combined_dataframe1.reset_index(drop=True)
+             st.write("Sentence with [PAUSE] Replacement:")
+             render_heatmap(self.sentence_with_pause, combined_dataframe1)
+             st.pyplot(analyze_heatmap(combined_dataframe))
+
+
+ paraphrase_list = paraphrase(check)
+ ######################
+
+ col1, col2 = st.columns(2)
+ with col1:
+     analyzer = SentenceAnalyzer(check, "Original Prompt", _gpt2tokenizer, model)
+     analyzer.analyze()
+ with col2:
+     ai_gen_text = load_text_gen_model(check)
+     st.caption(':blue[AI generated text by BLOOM]')
+     st.write(ai_gen_text)
+
+ st.markdown("""<hr style="height:5px;border:none;color:lightblue;background-color:lightblue;" /> """, unsafe_allow_html=True)
+
+
+ # One row per paraphrase: token-importance analysis on the left,
+ # generated text on the right
+ for idx, paraphrase_text in enumerate(paraphrase_list, start=1):
+     left, right = st.columns(2)
+     with left:
+         analyzer = SentenceAnalyzer(" " + paraphrase_text, f"Paraphrase {idx}", _gpt2tokenizer, model)
+         analyzer.analyze()
+     with right:
+         ai_gen_text = load_text_gen_model(paraphrase_text)
+         st.caption(':blue[AI generated text by BLOOM]')
+         st.write(ai_gen_text)
+
+     if idx < len(paraphrase_list):
+         st.markdown("""<hr style="height:5px;border:none;color:lightblue;background-color:skyblue;" /> """, unsafe_allow_html=True)