import torch
import torch.nn.functional as F


def top_k_filtering(logits, top_k: int = 1):
    """
    Filter a distribution of logits using top-k filtering.

    The input logits tensor is modified in-place.

    Args:
        logits: A tensor of logits to be filtered. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with the highest probability.

    Returns:
        A tensor of logits where values below the top-k threshold are set to -inf.
    """
    if top_k > 0:
        # Mask every position whose logit is smaller than the k-th largest logit
        # along the vocab dimension.
        idx_to_remove = logits < logits.topk(top_k, largest=True, sorted=False, dim=-1)[
            0
        ].amin(dim=-1, keepdim=True)
        logits.masked_fill_(idx_to_remove, -torch.inf)
    return logits
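

# A quick illustration of the masking above (the tensor values are made up for
# this sketch, not taken from the original module): with top_k=2, only the two
# largest logits survive, and the rest are driven to -inf so that softmax later
# assigns them zero probability.
#
#     >>> x = torch.tensor([1.0, 3.0, 2.0, 0.5])
#     >>> top_k_filtering(x, top_k=2)
#     tensor([-inf, 3., 2., -inf])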


def process_logits(
    logits,
    top_k: int = 1,
):
    """
    Process logits by optionally applying top-k filtering.

    The final probabilities are returned after applying softmax to the filtered logits.

    Args:
        logits: A tensor of logits to process. Expected shape is [..., vocab_size].
        top_k: If > 0, only keep the top k tokens with the highest probability.

    Returns:
        A tensor of probabilities after filtering, with the same shape as the input logits.
    """
    # Filtering mutates `logits` in-place; softmax then renormalizes the
    # surviving entries (the -inf positions get zero probability).
    logits = top_k_filtering(logits, top_k=top_k)
    probs = F.softmax(logits, dim=-1)
    return probs
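

if __name__ == "__main__":
    # A minimal sketch of how these helpers compose for greedy and top-k
    # sampling. The batch size, vocab size, and seed below are illustrative
    # assumptions, not part of the original module.
    torch.manual_seed(0)
    logits = torch.randn(2, 10)  # [batch, vocab_size]

    # With top_k=1, every logit except the argmax becomes -inf, so the
    # resulting distribution is one-hot and sampling is effectively greedy.
    # `.clone()` matters because the filtering modifies its input in-place.
    greedy_probs = process_logits(logits.clone(), top_k=1)
    print(greedy_probs.argmax(dim=-1))

    # With top_k=5, five candidates per row keep nonzero probability;
    # torch.multinomial then samples one token id per row among them.
    probs = process_logits(logits.clone(), top_k=5)
    print(torch.multinomial(probs, num_samples=1))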