Fix batch generation
Browse files
- generation_utils.py +10 -7
generation_utils.py
CHANGED
@@ -433,18 +433,21 @@ class DreamGenerationMixin:
|
|
433 |
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
434 |
else:
|
435 |
raise RuntimeError(f"Unknown alg: {alg}")
|
436 |
-
num_mask_token = mask_index.sum()
|
437 |
-
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
|
|
|
|
|
438 |
if number_transfer_tokens > 0:
|
439 |
if alg_temp is None or alg_temp == 0:
|
440 |
-
_, transfer_index = torch.topk(
|
441 |
else:
|
442 |
confidence = confidence / alg_temp
|
443 |
confidence = F.softmax(confidence, dim=-1)
|
444 |
-
transfer_index = torch.multinomial(
|
445 |
-
|
446 |
-
|
447 |
-
x
|
|
|
448 |
|
449 |
# this allows user-defined token control of the intermediate steps
|
450 |
x = generation_tokens_hook_func(i, x, logits)
|
|
|
433 |
confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
|
434 |
else:
|
435 |
raise RuntimeError(f"Unknown alg: {alg}")
|
436 |
+
num_mask_token = mask_index.sum() / mask_index.shape[0]
|
437 |
+
number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
|
438 |
+
full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
|
439 |
+
full_confidence[mask_index] = confidence
|
440 |
if number_transfer_tokens > 0:
|
441 |
if alg_temp is None or alg_temp == 0:
|
442 |
+
_, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
|
443 |
else:
|
444 |
confidence = confidence / alg_temp
|
445 |
confidence = F.softmax(confidence, dim=-1)
|
446 |
+
transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
|
447 |
+
x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
|
448 |
+
x_[mask_index] = x0.clone()
|
449 |
+
row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
|
450 |
+
x[row_indices,transfer_index] = x_[row_indices,transfer_index]
|
451 |
|
452 |
# this allows user-defined token control of the intermediate steps
|
453 |
x = generation_tokens_hook_func(i, x, logits)
|