|
import random

import torch

from masking_methods import MaskingProcessor
|
|
|
class SamplingProcessor:
    """Fills [MASK] tokens in masked sentences by sampling from per-mask logits."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
        """
        Fill each [MASK] in the masked sentence using the specified sampling technique.

        Args:
            original_sentence (str): The masked input sentence.
            mask_logits (dict): Mapping from each [MASK] token to its logits over the vocabulary.
            sampling_technique (str): One of "inverse_transform", "exponential_minimum",
                "temperature", or "greedy".
            temperature (float): Temperature parameter for the probabilistic techniques.

        Returns:
            str: The sentence with every mask filled.
        """
|
        sentence_tokens = self.tokenizer.tokenize(original_sentence)
        mask_token_indices = [i for i, token in enumerate(sentence_tokens)
                              if token == self.tokenizer.mask_token]

        if len(mask_token_indices) != len(mask_logits):
            raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")

        for mask_idx, filtered_logits in zip(mask_token_indices, mask_logits.values()):
            filtered_logits = torch.tensor(filtered_logits)

            if sampling_technique == "inverse_transform":
                # Inverse-transform sampling: draw a uniform random number and pick
                # the first token whose cumulative probability reaches it.
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
            elif sampling_technique == "exponential_minimum":
                # Exponential-races trick: draw Exp(1) noise for every token and
                # take argmin(noise / p_i), which selects index i with probability p_i.
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                exp_noise = -torch.log(torch.rand_like(probs))
                sampled_index = torch.argmin(exp_noise / probs).item()
            elif sampling_technique == "temperature":
                # Temperature sampling: rescale the logits, then draw from the
                # resulting multinomial distribution.
                filtered_logits = torch.clamp(filtered_logits, min=-1e8, max=1e8)
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                # Guard against zero probabilities, then renormalise.
                probs = torch.max(probs, torch.tensor(1e-8, device=filtered_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()
            elif sampling_technique == "greedy":
                # Greedy decoding: always take the highest-scoring token.
                sampled_index = torch.argmax(filtered_logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # The logits are assumed to span the full vocabulary, so the sampled
            # index can be decoded directly as a token id.
            sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
            sentence_tokens[mask_idx] = sampled_token

        return self.tokenizer.convert_tokens_to_string(sentence_tokens)

    def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
        """
        Fill the masks in multiple masked sentences using the specified sampling technique.

        Args:
            masked_sentences (list): List of masked sentences.
            mask_logits (list): One logits dict per sentence, aligned with masked_sentences.
            sampling_technique (str): Sampling technique to use.
            temperature (float): Temperature parameter for the probabilistic techniques.

        Returns:
            list: Sentences with their masks filled.
        """
        filled_sentences = []
        for sentence, logits in zip(masked_sentences, mask_logits):
            filled_sentence = self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
            filled_sentences.append(filled_sentence)
        return filled_sentences


if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    processor = SamplingProcessor(tokenizer)

    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
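
    # Walk through each masking result and fill its masks with every technique.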
    for sentence, result in masking_results.items():
        masked_sentence = result["masked_sentence"]
        mask_logits = result["mask_logits"]
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {masked_sentence}")
        print(f"Number of [MASK] tokens with logits: {len(mask_logits)}")
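
        # Compare how each sampling technique fills the same masked sentence.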
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = processor.fill_masked_sentence(
                original_sentence=masked_sentence,
                mask_logits=mask_logits,
                sampling_technique=technique,
                temperature=1.0
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print("--------------------------------")
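
    # A minimal sketch of the batch path: process_samples fills every sentence in
    # one call, assuming each masking result exposes "masked_sentence" and
    # "mask_logits" exactly as used above.
    all_masked = [res["masked_sentence"] for res in masking_results.values()]
    all_logits = [res["mask_logits"] for res in masking_results.values()]
    for filled in processor.process_samples(all_masked, all_logits, "greedy"):
        print(f"Batch-filled Sentence: {filled}")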