Spaces:
Running
on
Zero
Running
on
Zero
Safe sampling
Browse files
app.py
CHANGED
@@ -135,9 +135,13 @@ def generate_diffusion_text(input_ids):
|
|
135 |
with torch.no_grad():
|
136 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
137 |
logits = model(input_ids=input_tensor)["logits"]
|
138 |
-
probs = torch.nn.functional.softmax(logits, dim=-1)
|
139 |
-
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
140 |
-
|
|
|
|
|
|
|
|
|
141 |
|
142 |
# Extract confidence of selected tokens
|
143 |
conf = probs[range(len(sampled)), sampled].cpu().numpy()
|
|
|
135 |
with torch.no_grad():
|
136 |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
|
137 |
logits = model(input_ids=input_tensor)["logits"]
|
138 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
139 |
+
probs = torch.clamp(probs, min=1e-8, max=1.0)]
|
140 |
+
print("probs", probs)
|
141 |
+
print("probs min:", probs.min().item(), "max:", probs.max().item(), "sum:", probs.sum().item())
|
142 |
+
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|
143 |
+
assert (probs >= 0).all(), "Negative probs!"
|
144 |
+
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
|
145 |
|
146 |
# Extract confidence of selected tokens
|
147 |
conf = probs[range(len(sampled)), sampled].cpu().numpy()
|