Spaces:
Running on Zero

Ruurd committed on
Commit
2ba8b3f
·
1 Parent(s): a5ca1bf

Add confidence based noising

Browse files
Files changed (1) hide show
  1. app.py +30 -7
app.py CHANGED
@@ -73,6 +73,21 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
73
  noised[idx] = val
74
  return noised
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  @spaces.GPU
77
  def generate_diffusion_text(input_ids, answer_start):
78
  with torch.no_grad():
@@ -81,18 +96,21 @@ def generate_diffusion_text(input_ids, answer_start):
81
  probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
82
  probs = torch.clamp(probs, min=1e-8, max=1.0)
83
  sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
84
- return input_ids[:answer_start] + sampled[answer_start:]
 
 
 
85
 
86
  # --- Inference Wrapper ---
87
 
88
 
89
- def diffusion_chat(question, eot_weight, max_it, sharpness):
 
90
  placeholder = "What do you know about the city of New York?"
91
  if question.strip() == "":
92
  question = placeholder
93
-
94
- print('started generation')
95
 
 
96
  prompt = f"User: {question}\nAssistant:"
97
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
98
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
@@ -112,7 +130,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
112
 
113
  for i in range(max_it):
114
  print('Generating output')
115
- generated_tokens = generate_diffusion_text(current_tokens, answer_start)
116
  current_tokens = generated_tokens
117
 
118
  decoded_ids = current_tokens[answer_start:]
@@ -141,7 +159,11 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
141
  break
142
 
143
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
144
- current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
 
 
 
 
145
  time.sleep(0.01)
146
 
147
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
@@ -162,7 +184,8 @@ demo = gr.Interface(
162
  gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
163
  gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
164
  gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
165
- gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")
 
166
  ],
167
  outputs=[gr.HTML(label="Diffusion Output")],
168
  title="Diffusion Language Model Chat",
 
73
  noised[idx] = val
74
  return noised
75
 
76
+ # Add new noising function
77
def confidence_guided_noising(input_ids, answer_start, confidences, eot_weight):
    """Re-noise answer tokens, replacing low-confidence positions more often.

    Every position from ``answer_start`` onward is replaced with probability
    ``1 - confidence``. Replacement ids are drawn from the module-level
    ``token_probabilities`` distribution after its ``eot_token_id`` entry has
    been scaled by ``eot_weight`` and the distribution renormalised.

    Args:
        input_ids: Token ids of the full sequence. The input is not modified;
            a copy is made and returned.
        answer_start: Index of the first position eligible for noising;
            earlier positions are left untouched.
        confidences: Per-position confidence values, indexed like
            ``input_ids``. Presumably probabilities in [0, 1] as produced by
            the sampler -- verify against the caller.
        eot_weight: Multiplier applied to the ``eot_token_id`` probability
            before renormalising.

    Returns:
        A copy of ``input_ids`` in which some answer positions have been
        resampled.
    """
    noised = input_ids.copy()

    # Work on a copy so the shared base distribution is never mutated.
    mixed_probs = token_probabilities.copy()
    mixed_probs[eot_token_id] *= eot_weight
    mixed_probs /= mixed_probs.sum()

    # Loop-invariant: build the candidate id array once instead of
    # allocating a vocabulary-sized array for every resampled token.
    candidate_ids = np.arange(vocab_size)

    for offset, conf in enumerate(confidences[answer_start:]):
        # Low confidence -> high probability of being replaced.
        if rng.random() < 1.0 - conf:
            noised[answer_start + offset] = rng.choice(candidate_ids, p=mixed_probs)

    return noised
90
+
91
  @spaces.GPU
92
  def generate_diffusion_text(input_ids, answer_start):
93
  with torch.no_grad():
 
96
  probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
97
  probs = torch.clamp(probs, min=1e-8, max=1.0)
98
  sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
99
+
100
+ # Extract confidence of selected tokens
101
+ conf = probs[range(len(sampled)), sampled].cpu().numpy()
102
+ return sampled, conf # ✅ NEW: Return confidence
103
 
104
  # --- Inference Wrapper ---
105
 
106
 
107
+ # Modify diffusion_chat to use new noising conditionally
108
+ def diffusion_chat(question, eot_weight, max_it, sharpness, use_confidence_noising):
109
  placeholder = "What do you know about the city of New York?"
110
  if question.strip() == "":
111
  question = placeholder
 
 
112
 
113
+ print('started generation')
114
  prompt = f"User: {question}\nAssistant:"
115
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
116
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
 
130
 
131
  for i in range(max_it):
132
  print('Generating output')
133
+ generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
134
  current_tokens = generated_tokens
135
 
136
  decoded_ids = current_tokens[answer_start:]
 
159
  break
160
 
161
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
162
+ if use_confidence_noising:
163
+ current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, eot_weight)
164
+ else:
165
+ current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
166
+
167
  time.sleep(0.01)
168
 
169
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
 
184
  gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
185
  gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
186
  gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
187
+ gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
188
+ gr.Checkbox(value=False, label="Use confidence-guided noising") # ✅ NEW
189
  ],
190
  outputs=[gr.HTML(label="Diffusion Output")],
191
  title="Diffusion Language Model Chat",