anonymousdemo commited on
Commit
e65ff69
·
verified ·
1 Parent(s): 8fb58bc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +869 -0
app.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import json
5
+
6
+ import time
7
+ import requests
8
+
9
+ import os
10
+ import glob
11
+ import re
12
+ #import smart_open
13
+ import plotly.express as px
14
+ import random
15
+ #import difflib
16
+ import pdb
17
+
18
+ from sentence_transformers import SentenceTransformer, models, util
19
+
20
# ---------------------------------------------------------------------------
# Module-level configuration for the evaluation viewer.
# ---------------------------------------------------------------------------

# Feature flags for the Streamlit UI / reporting helpers.
enable_summary_button = True
dump_pos_data_for_reporting = True

# GCS bucket that originally hosted the result files (now read locally).
bucket_name = "paper_n1"

# File-name prefixes of the PatentGPT-J model variants under evaluation.
prefix_lst = [
    "pgj_d_4096",
    "pgj_d_2048",
    "pgj_d_1024_v2",
    "pgj_d_1024_layer_14",
    "pgj_d_1024_layer_7",
    "pgj_d_1024_layer_2",
    "pgj_d_1024_layer_1" ]

# Human-readable model names keyed by file prefix.
model_names = {
    prefix_lst[0]: 'PatentGPT-J-6B',
    prefix_lst[1]: 'PatentGPT-J-1.6B',
    prefix_lst[2]: 'PatentGPT-J-456M',
    prefix_lst[3]: 'PatentGPT-J-279M',
    prefix_lst[4]: 'PatentGPT-J-191M',
    prefix_lst[5]: 'PatentGPT-J-128M',
    prefix_lst[6]: 'PatentGPT-J-115M',}

# Experiment 1: first-claim-only evaluation on the ipg22_500 data set.
# (Earlier experiments used 'experiments/non_patent' and
# 'experiments/ipg20220104_500' — see version control history.)
folder = os.path.join('experiments', 'ipg22_500')
id_to_scroll = 1  # which prefix in prefix_lst to scroll through
first_claim_only = True

ignore_outscope = True  # treat picks beyond the top 10 as out of scope
85
def handle_char_return(text):
    """Normalize a token's display text.

    The placeholder '(none)' (used for unicorn text) is mapped to the empty
    string; every other token is returned unchanged.
    """
    if text == '(none)':  # unicorn text placeholder
        # BUG FIX: the original used '==' (a no-op comparison) here instead
        # of an assignment, so '(none)' was returned unchanged.
        text = ''

    return text
97
def get_remaining(lst, pos):
    """Concatenate token texts starting at *pos* up to (but not including)
    the first token whose text begins with a space.

    Each element of *lst* is a dict carrying an 'actual_next_token_text'
    entry; a leading space marks the start of the next word.
    """
    pieces = []
    for entry in lst[pos:]:
        token = entry['actual_next_token_text']
        if token.startswith(' '):
            break
        pieces.append(token)
    return ''.join(pieces)
107
+
108
def calc_details(base_fn):
    """Load one per-claim evaluation JSON file and compute keystroke metrics.

    Args:
        base_fn: file name (under the module-level ``folder``) of a
            ``*_forward.json`` evaluation dump.

    Returns an 11-tuple:
        (result, avg_pick, avg_prob, token_count, sum_pick, sum_prob,
         sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len,
         full_text)
    When the file does not exist, ``result`` is None and the remaining
    fields are neutral values — callers must check ``result is None``.
    """
    full_fn = os.path.join(folder, base_fn)

    if not os.path.exists(full_fn):
        # BUG FIX: the original returned only 8 values here while the normal
        # path returns 11, which made callers' tuple unpacking raise
        # ValueError on a missing file.
        return None, -1, -1, 0, 0, 0, 0, 0, 0, 0, ''

    with open(full_fn) as f:
        result = json.loads(f.read())
        print("Loaded: %s" % full_fn)

    lst = result['output']
    recv = result['recv']       # original request payload (kept in result for callers)
    sum_pick = 0                # keystrokes spent selecting within the top 10
    sum_prob = 0                # accumulated probability of the actual tokens
    sum_outscope_count = 0      # tokens ranked outside the top-10 suggestions
    sum_outscope_len = 0        # characters typed in full for out-of-scope tokens
    sum_hit_1 = 0               # tokens where the top-1 suggestion was correct
    sum_top_10_len = 0          # characters covered by top-10 suggestions
    full_text = ''
    token_count = 0

    # The last element is excluded: it has no "next token" to predict.
    for i, tk in enumerate(lst[:-1]):
        token_text = handle_char_return(tk['actual_next_token_text'])

        next_top_seq = int(tk['actual_next_token_top_seq'])
        next_top_prob = float(tk['actual_next_token_top_prob'])

        full_text += token_text
        if next_top_seq == 0:
            sum_hit_1 += 1  # press "tab" for the top pick

        if ignore_outscope and next_top_seq >= 10:
            sum_outscope_count += 1
            sum_outscope_len += len(token_text)  # use length as keystrokes
        else:
            # Selecting suggestion k costs k+1 keystrokes ("down" presses
            # plus "tab"), but never more than typing the token in full.
            sum_pick += min(next_top_seq + 1, len(token_text))
            sum_prob += next_top_prob
            sum_top_10_len += len(token_text)

        token_count += 1

    # Guard the averages for empty outputs (single-token files); the
    # original only guarded the ignore_outscope branch.
    if token_count == 0:
        avg_pick = 0
        avg_prob = 0
    else:
        avg_pick = float(sum_pick) / token_count
        avg_prob = float(sum_prob) / token_count

    return result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text
198
+
199
def show_avg(base_fn, model_name, patent_claim_num, show_pick=False):
    """Render one model's token-level evaluation for a patent claim.

    Shows a keystroke-saving summary, a color legend, and the claim text
    with each token highlighted by how highly the model ranked it.

    Args:
        base_fn: evaluation file name, passed to calc_details().
        model_name: display name of the model.
        patent_claim_num: claim identifier (display context only).
        show_pick: when True, append the pick rank after each token.

    Returns:
        [top-1 count, top-2~10 count, out-of-top-10 count], or None when
        the underlying result file is missing.
    """
    result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)

    # BUG FIX: check for a missing result before anything else; the original
    # inspected token_count first and dropped into pdb for missing files.
    if result is None:
        return None

    if token_count == 0:
        print('debug 2')
        pdb.set_trace()

    # Legend entries: [background color, foreground color, label].
    colors = [
        ['00ff00', '000000', '1'],
        ['008800', 'ffffff', '2-10'],
        ['ff0000', 'ffffff', 'out of top 10'],
    ]

    # BUG FIX: initialize the per-bucket counters ONCE, outside the token
    # loop.  The original re-defined colors and reset sum_all on every
    # iteration, so the returned histogram only ever counted the last token.
    sum_all = {}
    for item in colors:
        sum_all[item[2]] = 0

    lst = result['output']
    result = ''
    for i, tk in enumerate(lst):
        token_text = handle_char_return(tk['actual_next_token_text'])

        if token_text == '<|end_of_claim|>':
            break
        if token_text == '(none)':  # for unicorn text
            break

        pick = int(tk['actual_next_token_top_seq'])

        if pick == 0:
            bg_color, fg_color, tag = colors[0]
        elif pick >= 1 and pick < 10:
            bg_color, fg_color, tag = colors[1]
        else:  # pick >= 10
            bg_color, fg_color, tag = colors[2]
        sum_all[tag] += 1

        if show_pick:
            pick = '[%s]' % pick
        else:
            pick = ''

        result += "<span style=background-color:#%s;color:#%s;border-radius:5px;>%s%s</span> " % (bg_color, fg_color, token_text, pick)

    color_msg = ''
    for i, v in enumerate(colors):
        color_msg += "<span style=background-color:#%s;color:#%s;border-radius:5px;>&nbsp;%s&nbsp;</span> " % (v[0], v[1], v[2])

    # sum_pick counts keystrokes for top 1~10 picks; out-of-scope tokens are
    # charged their full character length.
    keys_with_auto = (sum_pick+sum_outscope_len)
    keys_without_auto = len(full_text)
    saved_ratio = float(keys_without_auto-keys_with_auto)/keys_without_auto * 100
    s = 'model: %s\n' \
        'Autocomplete Effectiveness: %.1f%% (keystrokes saved)\n' \
        'Total keystrokes: %s (with autocomplete), %s (without autocomplete)\n' \
        'Keystroke distribution: top 1~10: %s (top 1: %s), out of top 10: %s' % (model_name, saved_ratio, keys_with_auto, keys_without_auto, sum_pick, sum_hit_1, sum_outscope_len)
    st.text(s)

    st.markdown(color_msg, unsafe_allow_html=True)
    st.markdown(result, unsafe_allow_html=True)

    sum_lst = [sum_all['1'], sum_all['2-10'], sum_all['out of top 10']]
    return sum_lst
316
+
317
def show_overall_summary(prefix_lst, select_lst):
    """Accumulate keystroke metrics over all selected claims and render a
    per-model summary plus a ready-to-paste LaTeX table row via Streamlit.

    Args:
        prefix_lst: model file-name prefixes (keys of ``model_names``).
        select_lst: claim identifiers, each combined with a prefix to form
            a ``*_forward.json`` file name for calc_details().
    """
    # Note: the original kept an unused leftover local (pre_full_text) from
    # a commented-out debugging pass; it has been removed.
    for prefix in prefix_lst:
        acc_token_count = 0
        acc_sum_pick = 0
        acc_sum_prob = 0
        acc_sum_outscope_count = 0
        acc_sum_outscope_len = 0
        acc_sum_hit_1 = 0
        acc_sum_top_10_len = 0
        acc_full_text_len = 0

        for i, num in enumerate(select_lst):
            base_fn = '%s_%s_forward.json' % (prefix, num)
            result, avg_pick, avg_prob, token_count, sum_pick, sum_prob, sum_outscope_count, sum_outscope_len, sum_hit_1, sum_top_10_len, full_text = calc_details(base_fn)

            acc_token_count += token_count
            acc_sum_pick += sum_pick
            acc_sum_prob += sum_prob
            acc_sum_outscope_count += sum_outscope_count
            acc_sum_outscope_len += sum_outscope_len
            acc_sum_hit_1 += sum_hit_1
            acc_sum_top_10_len += sum_top_10_len
            acc_full_text_len += len(full_text)

        if acc_token_count > 0:
            # acc_sum_pick covers keystrokes spent on top 1~10 picks;
            # out-of-scope tokens are charged their full length.
            keys_with_auto = acc_sum_pick + acc_sum_outscope_len
            keys_without_auto = acc_full_text_len
            saved_ratio = float(keys_without_auto-keys_with_auto)/keys_without_auto * 100

            st.text('[ %s ]\n' \
                'Autocomplete Effectiveness: %.1f%% (ratio of saving keystroke)\n' \
                '(sum) keys_with_auto: %s, top_10_keys: %s, out_of_scope: %s, sum_hit_1: %s\n' \
                'keys_without_auto: %s, top_10_len: %s, prob: %.2f' % (
                model_names[prefix], saved_ratio,
                '{:,}'.format(keys_with_auto),
                '{:,}'.format(acc_sum_pick),
                '{:,}'.format(acc_sum_outscope_len),
                '{:,}'.format(acc_sum_hit_1),
                '{:,}'.format(keys_without_auto),
                '{:,}'.format(acc_sum_top_10_len),
                acc_sum_prob,
                ))

            # Same numbers formatted as a LaTeX table row (for the paper).
            st.text('%s & %.1f\\%% & %s & %s & %s & %s & %s \\\\' % (model_names[prefix], saved_ratio, '{:,}'.format(keys_with_auto), '{:,}'.format(acc_sum_pick), '{:,}'.format(acc_sum_outscope_len), '{:,}'.format(acc_sum_hit_1), '{:,}'.format(keys_without_auto)))
389
+
390
def calc_height(s):
    """Heuristic pixel height for a Streamlit text area holding *s*."""
    # Integer-arithmetic equivalent of int(len(s) / 10 * 3) + 30.
    return len(s) * 3 // 10 + 30
392
+
393
def remove_end_of_claim_text(gen_text):
    """Truncate generated text just after the first end-marker tag.

    '<|end_of_claim|>' takes precedence over '<|endoftext|>'.  Matching the
    original behavior, a marker sitting at position 0 does NOT trigger
    truncation (the check is ``pos > 0``, not ``>= 0``).
    """
    for marker in ('<|end_of_claim|>', '<|endoftext|>'):
        idx = gen_text.find(marker)
        if idx > 0:
            return gen_text[:idx + len(marker)]
    return gen_text
406
+
407
def dump_pos_data(prefix_lst, select_lst):
    """Write per-token-position suggestion-rank counts to 'dump4.txt'.

    Only the 456M model variant is reported.  For every token position j
    (up to 2048) it counts how often the actual token was the top-1
    suggestion, ranked 2-10, or outside the top 10, then dumps one line
    per (position, bucket) pair.
    """
    # statistics[j] = [top-1 hits, top-2~10 hits, out-of-scope hits]
    statistics = [[0, 0, 0] for _ in range(2048)]

    max_len = -1  # highest token position seen across all loaded files
    for prefix in prefix_lst:
        model_name = model_names[prefix].replace('PatentGPT-J-', '')
        if model_name != '456M':
            continue  # report only the 456M variant

        for num in select_lst:
            full_fn = os.path.join(folder, '%s_%s_forward.json' % (prefix, num))
            if os.path.exists(full_fn) == False:
                continue

            with open(full_fn) as f:
                result = json.loads(f.read())
                print("Loaded: %s" % full_fn)

            # Last element has no next-token prediction; skip it.
            for j, tk in enumerate(result['output'][:-1]):
                max_len = max(j, max_len)
                rank = int(tk['actual_next_token_top_seq'])
                if rank == 0:
                    bucket = 0
                elif 0 < rank < 10:
                    bucket = 1
                else:
                    bucket = 2
                statistics[j][bucket] += 1

    dump_file = 'dump4.txt'
    with open(dump_file, 'w') as f:
        for i in range(max_len + 1):
            f.write('%s, top-1, %s\n' % (i + 1, statistics[i][0]))
            f.write('%s, top-2~10, %s\n' % (i + 1, statistics[i][1]))
            f.write('%s, out_of_scope, %s\n' % (i + 1, statistics[i][2]))
    print('saved: %s' % dump_file)
476
+
477
+
478
def calc_sentence_similarity(sent_model, sent1, sent2):
    """Compute the cosine similarity between two sentences.

    Args:
        sent_model: a loaded SentenceTransformer model.
        sent1, sent2: the sentences to compare.

    Returns:
        A 0-dim tensor holding the cosine similarity of the two embeddings.
    """
    # Removed an unused local ('rewards = []') left over from earlier code.
    embedding1 = sent_model.encode(sent1, convert_to_tensor=True)
    embedding2 = sent_model.encode(sent2, convert_to_tensor=True)
    # util.cos_sim returns a (1, 1) similarity matrix; extract the scalar.
    similarity = util.cos_sim(embedding1, embedding2)[0][0]

    return similarity
487
+
488
# Sentence-similarity model used to score generated vs. actual claims.
# NOTE(review): this loads (and may download) the model at import time —
# confirm that is intended for the deployment environment.  The name
# 'sent_model' is also shadowed by calc_sentence_similarity's parameter.
sent_model = 'patent/st-aipd-nlp-g'
print('loading SentenceTransformer: %s' % sent_model)
sent_aipd = SentenceTransformer(sent_model)
491
+
492
def load_data(demo):
    """Load the PPO delta rows and filter them for one demo direction.

    Args:
        demo: 'demo1' keeps rows whose instruction mentions 'child';
              'demo2' keeps rows whose instruction mentions 'parent';
              anything else yields an empty list.

    Note: matching the original, str.find(...) > 0 means a keyword at
    position 0 of the instruction does NOT match.
    """
    fn = 'ppo_open_llama_3b_v2.run.12.delta.txt'
    with open(fn, 'r') as f:
        rows = json.load(f)

    keyword = {'demo1': 'child', 'demo2': 'parent'}.get(demo)
    if keyword is None:
        return []
    return [row for row in rows if row['instruction'].find(keyword) > 0]
506
+
507
# CSS for bordered demo containers.
# NOTE(review): this string is defined but never injected (e.g. via
# st.markdown) anywhere in the visible code — confirm whether it is still
# needed or is a leftover.
container_style = """
<style>
.container1 {
    border: 2px solid #3498db;
    border-radius: 8px;
    padding: 10px;
    margin-bottom: 20px;
}
.container2 {
    /* Add styles for Container 2 if needed */
}
</style>
"""
520
+
521
def main():
    """Streamlit entry point: browse PPO parent/child claim comparisons.

    The user picks a demo direction (parent --> child or child --> parent),
    scrolls through the loaded rows with a slider, and sees the selected
    row's parent claim, the actual vs. generated child claim, and the two
    similarity scores plus their delta.
    """
    st.set_page_config(  # Alternate names: setup_page, page, layout
        layout="wide",  # Can be "centered" or "wide".
        initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
        page_title="Demo 1",  # String or None. Strings get appended with "• Streamlit".
        page_icon=None,  # String, anything supported by st.image, or None.
    )

    opt_1 = 'parent --> child'
    opt_2 = 'child --> parent'
    options = [opt_1, opt_2]
    # Defaults shown before the slider loop below assigns real row values.
    rows = None
    pos = None
    patent_num = ''
    claim_num1 = ''
    claim_num2 = ''
    instruction= ''
    input_text = ''
    output_text = ''
    response = ''
    query = ''
    score_lst_1 = 0
    score_lst_2 = 0
    rewards = ''
    with st.container():
        col1, col2, col3 = st.columns([3, 5, 2])
        with col1:
            selected_option = st.selectbox('Select a demo:', options)
            if selected_option == opt_1:
                rows = load_data('demo1')
                msg = 'novelty = sim1-sim2'
                # Tags: pc = parent claim; cc1/cc2 = actual/generated child claim.
                c1_tag = 'pc'
                c2_tag = 'cc1'
                c3_tag = 'cc2'
            elif selected_option == opt_2:
                rows = load_data('demo2')
                msg = 'similarity of<br>(pc1) and (pc2)'
                c1_tag = 'cc'
                c2_tag = 'pc1'
                c3_tag = 'pc2'
            else:
                st.text('Unknown option')
                return

        with col2:
            pos = st.slider("", 1, len(rows))
            # Walk rows up to the slider position; after the loop the
            # variables hold the row at index pos-1 (the one displayed).
            for i in range(pos):
                patent_num = rows[i]['patent_num']
                claim_num1 = rows[i]['claim_num1']
                claim_num2 = rows[i]['claim_num2']
                instruction= rows[i]['instruction']
                input_text = rows[i]['input']
                output_text = rows[i]['output']
                response = rows[i]['response']
                query = rows[i]['query']
                score_lst_1 = rows[i]['score_lst_1']
                score_lst_2 = rows[i]['score_lst_2']
                delta = rows[i]['delta']
                rewards = rows[i]['rewards']
        with col3:
            # Show the metric label and the selected row's delta value.
            st.markdown("<center><h7>%s<br>%s</h7></center>" % (msg, delta), unsafe_allow_html=True)

    st.write('(instruction) ', instruction)

    with st.container():
        with st.container(border=True):
            st.write('(%s) [ %s ]\n%s' % (c1_tag, patent_num, input_text))

        # Actual vs. generated claim, side by side.
        col1, col2 = st.columns(2)
        with col1:
            with st.container(border=True):
                st.write('(%s) (actual)' % c2_tag)
                st.write(output_text)
        with col2:
            with st.container(border=True):
                st.write('(%s) (generated)' % c3_tag)
                st.write(response)

        # The two similarity scores, side by side.
        col1, col2 = st.columns(2)
        with col1:
            with st.container(border=True):
                st.write('(sim1) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c2_tag, str(score_lst_1)))
        with col2:
            with st.container(border=True):
                st.write('(sim2) similarity between %s and %s+%s: %s' % (c1_tag, c1_tag, c3_tag, str(score_lst_2)))

    # NOTE(review): a large commented-out legacy section for inspecting
    # PatentGPT-J evaluation files (summary button, per-claim charts,
    # token-by-token scrolling) was removed here; recover it from version
    # control history if needed.
816
+
817
# Script entry point.
if __name__ == "__main__":
    main()
819
+
820
+ #def load_data_pre(demo):
821
+ # fn = 'ppo_output/ppo_open_llama_3b_v2.run.12.keep.txt'
822
+ # with open(fn, 'r') as f:
823
+ # rows = json.load(f)
824
+
825
+ # new_rows = []
826
+ # for i, row in enumerate(rows):
827
+ # item1 = {}
828
+ # item2 = {}
829
+ # if demo == 'demo1':
830
+ # item1[ 'delta' ] = abs(row['score_lst_1'][0] - row['score_lst_2'][0])
831
+ # item2[ 'delta' ] = abs(row['score_lst_1'][1] - row['score_lst_2'][1])
832
+ # elif demo == 'demo2':
833
+ # #pdb.set_trace()
834
+ # item1[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][0], row['response'][0])
835
+ # item2[ 'delta' ] = calc_sentence_similarity(sent_aipd, row['output'][1], row['response'][1])
836
+
837
+ # print('[ %s ] detla = %s' % (i, item1[ 'delta' ]))
838
+
839
+ # for k in row.keys():
840
+ # item1[ k ] = row[ k ][0]
841
+ # item2[ k ] = row[ k ][1]
842
+
843
+ # if demo == 'demo1':
844
+ # if item1['instruction'].find('child') > 0:
845
+ # new_rows.append(item1)
846
+ # if item2['instruction'].find('child') > 0:
847
+ # new_rows.append(item2)
848
+ # elif demo == 'demo2':
849
+ # if item1['instruction'].find('parent') > 0:
850
+ # new_rows.append(item1)
851
+ # if item2['instruction'].find('parent') > 0:
852
+ # new_rows.append(item2)
853
+
854
+ # # Assuming new_rows is your list of dictionaries
855
+ # sorted_rows = sorted(new_rows, key=lambda x: x['delta'])
856
+
857
+ # # kv = {}
858
+ # # for i, row in enumerate(new_rows):
859
+ # # if diff > 0.0001:
860
+ # # kv[i] = round(diff, 4)
861
+
862
+ # # sorted_rows = []
863
+ # # sorted_kv = sorted(kv.items(), key=lambda x:x[1])
864
+ # # for k, v in sorted_kv:
865
+ # # sorted_rows.append(new_rows[k])
866
+
867
+ # #pdb.set_trace()
868
+
869
+ # return sorted_rows