
Ruurd committed · Commit f86092a · 1 Parent(s): f5bf8f9

Highlight noised tokens

Files changed (1): app.py (+17 -13)
app.py CHANGED
@@ -62,14 +62,13 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clust
     num_to_noise = int(threshold * answer_len)
 
     if num_to_noise == 0:
-        return noised
+        return noised, []
 
     mixed_probs = token_probabilities.copy()
     mixed_probs[eot_token_id] *= eot_weight
     mixed_probs /= mixed_probs.sum()
 
-    # Determine number of clusters and average cluster size
-    num_clusters = max(1, int((1 - clustering) * num_to_noise))  # fewer clusters if more intensity
+    num_clusters = max(1, int((1 - clustering) * num_to_noise))
     cluster_size = max(1, int(num_to_noise / num_clusters))
 
     noised_indices = set()
@@ -79,15 +78,13 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clust
         span_end = min(len(noised), span_start + cluster_size)
         noised_indices.update(range(span_start, span_end))
 
-    # Trim in case we overshot due to overlapping clusters
     noised_indices = sorted(list(noised_indices))[:num_to_noise]
 
     noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
     for idx, val in zip(noised_indices, noise):
         noised[idx] = val
 
-    return noised
-
+    return noised, noised_indices
 
 
 # Add new noising function
@@ -165,7 +162,9 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
     input_ids = input_ids[:256]
 
     ori_input_tokens = input_ids
-    current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
+    current_tokens, just_noised_indices = noisify_answer(
+        ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
+    )
     prev_decoded_tokens = []
     last_tokens = []
 
@@ -178,14 +177,19 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
         decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
         filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
         filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
-
+        just_noised_indices = []
         if filtered_prev_tokens:
             highlighted = []
-            for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
-                if tok_new != tok_old:
-                    highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
+            for i, tok in enumerate(decoded_tokens):
+                token_str = tokenizer.convert_tokens_to_string([tok])
+
+                abs_idx = answer_start + i
+                if abs_idx in just_noised_indices:
+                    highlighted.append(f'<span style="color:red">{token_str}</span>')
+                elif prev_decoded_tokens and i < len(prev_decoded_tokens) and tok != prev_decoded_tokens[i]:
+                    highlighted.append(f'<span style="color:green">{token_str}</span>')
                 else:
-                    highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
+                    highlighted.append(token_str)
         else:
             highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]
 
@@ -203,7 +207,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
         if use_confidence_noising:
             current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
         else:
-            current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering)
+            current_tokens, just_noised_indices = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering)
 
         time.sleep(0.01)
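For context, here is a minimal, self-contained sketch of the updated noisify_answer contract: it now returns both the noised sequence and the indices it replaced, so the UI can paint exactly those tokens red. The answer_len computation, the span-placement line, and the token_probabilities / eot_token_id parameters are assumptions filled in around the visible hunks (those lines sit outside the diff), not the Space's exact code.

import numpy as np

rng = np.random.default_rng(0)

def noisify_answer(input_ids, answer_start, token_probabilities, eot_token_id,
                   threshold=1.0, eot_weight=1.0, clustering=0.5):
    # Copy the sequence; only positions from answer_start onward get noised.
    noised = list(input_ids)
    answer_len = len(noised) - answer_start  # assumed; not visible in the hunks
    num_to_noise = int(threshold * answer_len)
    if num_to_noise == 0:
        return noised, []

    # Reweight the end-of-text token, then renormalize into a distribution.
    mixed_probs = token_probabilities.copy()
    mixed_probs[eot_token_id] *= eot_weight
    mixed_probs /= mixed_probs.sum()

    # Higher clustering -> fewer clusters -> longer contiguous noise spans.
    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))

    noised_indices = set()
    for _ in range(num_clusters):
        # Assumed span placement; the loop header lies above the visible hunk.
        span_start = int(rng.integers(answer_start, len(noised)))
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))

    # Trim in case overlapping clusters overshot the noise budget.
    noised_indices = sorted(noised_indices)[:num_to_noise]

    # Sample replacement tokens and report which positions were touched.
    noise = rng.choice(np.arange(len(mixed_probs)), size=len(noised_indices), p=mixed_probs)
    for idx, val in zip(noised_indices, noise):
        noised[idx] = int(val)
    return noised, noised_indices

probs = np.full(100, 1.0 / 100)  # toy uniform unigram distribution
tokens, red = noisify_answer(list(range(30)), answer_start=20,
                             token_probabilities=probs, eot_token_id=0,
                             threshold=0.5, eot_weight=0.1, clustering=0.8)
# red might be [23, 24, 25, 26, 27]: one contiguous span of noised positions

To see the cluster arithmetic concretely: with num_to_noise = 10 and clustering = 0.8, num_clusters = max(1, int(0.2 * 10)) = 2 and cluster_size = int(10 / 2) = 5, i.e. the noise lands as two contiguous five-token spans rather than ten scattered positions.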
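The diffusion_chat loop then uses those indices to color the streamed tokens. Below is a reduced sketch of the highlighting rule, operating on plain token strings rather than tokenizer IDs (a hypothetical helper, not the committed code): red marks positions that noisify_answer just re-noised, green marks tokens that changed since the previous decode, and everything else is left plain.

def highlight(decoded, previous, just_noised_indices, answer_start):
    # Red: position was just replaced by noise.
    # Green: token differs from the previous iteration's decode.
    parts = []
    for i, tok in enumerate(decoded):
        if answer_start + i in just_noised_indices:
            parts.append(f'<span style="color:red">{tok}</span>')
        elif i < len(previous) and tok != previous[i]:
            parts.append(f'<span style="color:green">{tok}</span>')
        else:
            parts.append(tok)
    return "".join(parts)

print(highlight(["The", " cat", " sat"], ["The", " dog", " sat"],
                just_noised_indices={2}, answer_start=0))
# The<span style="color:green"> cat</span><span style="color:red"> sat</span>

In the committed code the same rule runs inside the display loop, with token_str produced by tokenizer.convert_tokens_to_string and abs_idx = answer_start + i matching the absolute indices that noisify_answer returns.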