jiacheng-ye committed on
Commit
d6fa7d6
·
verified ·
1 Parent(s): 9ccfc13

Fix batch generation

Browse files
Files changed (1) hide show
  1. generation_utils.py +10 -7
generation_utils.py CHANGED
@@ -433,18 +433,21 @@ class DreamGenerationMixin:
433
  confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
434
  else:
435
  raise RuntimeError(f"Unknown alg: {alg}")
436
- num_mask_token = mask_index.sum()
437
- number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else num_mask_token
 
 
438
  if number_transfer_tokens > 0:
439
  if alg_temp is None or alg_temp == 0:
440
- _, transfer_index = torch.topk(confidence, number_transfer_tokens)
441
  else:
442
  confidence = confidence / alg_temp
443
  confidence = F.softmax(confidence, dim=-1)
444
- transfer_index = torch.multinomial(confidence, num_samples=number_transfer_tokens)
445
- x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_token_id
446
- x0_[transfer_index] = x0[transfer_index].clone()
447
- x[mask_index] = x0_
 
448
 
449
  # this allows user-defined token control of the intermediate steps
450
  x = generation_tokens_hook_func(i, x, logits)
 
433
  confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
434
  else:
435
  raise RuntimeError(f"Unknown alg: {alg}")
436
+ num_mask_token = mask_index.sum() / mask_index.shape[0]
437
+ number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
438
+ full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
439
+ full_confidence[mask_index] = confidence
440
  if number_transfer_tokens > 0:
441
  if alg_temp is None or alg_temp == 0:
442
+ _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
443
  else:
444
  confidence = confidence / alg_temp
445
  confidence = F.softmax(confidence, dim=-1)
446
+ transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
447
+ x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
448
+ x_[mask_index] = x0.clone()
449
+ row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
450
+ x[row_indices,transfer_index] = x_[row_indices,transfer_index]
451
 
452
  # this allows user-defined token control of the intermediate steps
453
  x = generation_tokens_hook_func(i, x, logits)