Spaces:
Running on Zero

Ruurd commited on
Commit
dc427d9
·
verified ·
1 Parent(s): a721355

Safe sampling

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -135,9 +135,13 @@ def generate_diffusion_text(input_ids):
135
  with torch.no_grad():
136
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
137
  logits = model(input_ids=input_tensor)["logits"]
138
- probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
139
- probs = torch.clamp(probs, min=1e-8, max=1.0)
140
- sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
 
 
 
 
141
 
142
  # Extract confidence of selected tokens
143
  conf = probs[range(len(sampled)), sampled].cpu().numpy()
 
135
  with torch.no_grad():
136
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
137
  logits = model(input_ids=input_tensor)["logits"]
138
+ probs = torch.nn.functional.softmax(logits, dim=-1)[0]
139
+ probs = torch.clamp(probs, min=1e-8, max=1.0)]
140
+ print("probs", probs)
141
+ print("probs min:", probs.min().item(), "max:", probs.max().item(), "sum:", probs.sum().item())
142
+ assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
143
+ assert (probs >= 0).all(), "Negative probs!"
144
+ sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
145
 
146
  # Extract confidence of selected tokens
147
  conf = probs[range(len(sampled)), sampled].cpu().numpy()