import torch
import random
from masking_methods import MaskingProcessor


class SamplingProcessor:
    """Fill ``[MASK]`` slots in masked sentences by sampling from candidate tokens."""

    # Techniques accepted by ``sample_tokens``; anything else is rejected up front.
    SAMPLING_TECHNIQUES = ("temperature", "greedy", "inverse_transform", "exponential_minimum")

    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: Tokenizer instance (presumably the BERT tokenizer used for
                masking — confirm with caller). Stored on the instance; the
                sampling logic in this class does not use it.
        """
        self.tokenizer = tokenizer

    @staticmethod
    def _sample_index(logits, sampling_technique, temperature):
        """
        Choose one candidate index from a vector of logits.

        Args:
            logits (torch.Tensor): 1-D floating-point tensor, one score per
                candidate token.
            sampling_technique (str): One of ``SAMPLING_TECHNIQUES``.
            temperature (float): Softmax temperature; ignored by ``"greedy"``.

        Returns:
            int: Index into the candidate token list.

        Raises:
            ValueError: If the softmax probabilities contain NaN/inf, or the
                technique is unknown.
        """
        if sampling_technique == "greedy":
            return torch.argmax(logits).item()

        if sampling_technique not in ("temperature", "inverse_transform", "exponential_minimum"):
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")

        if sampling_technique == "temperature":
            # Keep extreme logits finite before dividing by the temperature.
            logits = torch.clamp(logits, min=-1e8, max=1e8)

        probs = torch.softmax(logits / temperature, dim=-1)
        if not torch.isfinite(probs).all():
            raise ValueError("The computed probabilities contain NaN or inf values.")

        if sampling_technique == "temperature":
            # Floor every probability so multinomial never sees all zeros, then renormalize.
            probs = torch.clamp(probs, min=1e-8)
            probs = probs / probs.sum()
            if probs.numel() > 1:
                return torch.multinomial(probs, 1).item()
            return torch.argmax(probs).item()

        if sampling_technique == "inverse_transform":
            cumulative_probs = torch.cumsum(probs, dim=-1)
            draw = torch.tensor(
                [random.random()],
                dtype=cumulative_probs.dtype,
                device=cumulative_probs.device,
            )
            # First index whose CDF value is >= the draw. Rounding can leave the
            # last CDF entry slightly below 1.0, so clamp instead of overrunning.
            index = torch.searchsorted(cumulative_probs, draw).item()
            return min(index, probs.numel() - 1)

        # exponential_minimum: with E_i ~ Exp(1), argmin_i E_i / p_i is distributed
        # exactly as Categorical(p). E_i = -log(U_i) with U_i ~ U[0, 1) is strictly
        # positive (possibly inf), so the ratio never becomes NaN.
        exp_noise = -torch.log(torch.rand_like(probs))
        return torch.argmin(exp_noise / probs).item()

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample a token for each mask in the sentence using the specified technique.

        Args:
            mask_logits_dict (dict): Maps a mask position (index into
                ``masked_sentence.split()``) to a dict with ``'logits'`` (one
                score per candidate) and ``'tokens'`` (the candidate strings).
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): One of ``SAMPLING_TECHNIQUES``.
            temperature (float): Temperature for the stochastic techniques;
                must be positive unless ``sampling_technique == "greedy"``.

        Returns:
            str: Sentence with sampled tokens replacing masks. A mask whose
            sampling fails is left in place and the error is printed.

        Raises:
            ValueError: If ``sampling_technique`` is unknown or ``temperature``
                is not positive for a stochastic technique.
        """
        # Configuration errors apply to every mask, so fail loudly once instead
        # of printing the same error per mask and returning an unfilled sentence.
        if sampling_technique not in self.SAMPLING_TECHNIQUES:
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")
        if sampling_technique != "greedy" and not temperature > 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        words = masked_sentence.split()

        for mask_pos in sorted(mask_logits_dict):
            mask_data = mask_logits_dict[mask_pos]
            # Flatten so every technique indexes candidate_tokens the same way.
            mask_logits = torch.as_tensor(mask_data['logits']).flatten()
            if not mask_logits.is_floating_point():
                mask_logits = mask_logits.float()
            candidate_tokens = mask_data['tokens']

            try:
                sampled_index = self._sample_index(mask_logits, sampling_technique, temperature)
                sampled_token = candidate_tokens[sampled_index]
                # Strip the WordPiece continuation marker, which only appears as a prefix.
                if sampled_token.startswith('##'):
                    sampled_token = sampled_token[2:]
                words[mask_pos] = sampled_token
            except Exception as e:
                # Best effort: report the failure and leave this mask unfilled.
                print(f"Error sampling for position {mask_pos}: {str(e)}")
                continue

        return " ".join(words)

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Maps each original sentence to a dict with
                ``"masked_sentence"`` and ``"mask_logits"`` entries.
            sampling_technique (str): Sampling method to use.
            temperature (float): Temperature parameter for sampling.

        Returns:
            dict: Maps each original sentence to
            ``{"masked_sentence": ..., "sampled_sentence": ...}``.
        """
        processed_results = {}
        for original_sentence, data in results_dict.items():
            masked_sentence = data["masked_sentence"]
            sampled_sentence = self.sample_tokens(
                data["mask_logits"],
                masked_sentence,
                sampling_technique,
                temperature,
            )
            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence,
            }
        return processed_results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog.",
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    # First, mask the sentences
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try different sampling techniques. Each run returns
    # {original_sentence: {"masked_sentence": ..., "sampled_sentence": ...}}.
    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]

    for technique in sampling_techniques:
        print(f"\nSampling using {technique}:")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0,
        )

        for original_sentence, result in sampled_results.items():
            print(f"Original: {original_sentence}")
            print(f"Masked: {result['masked_sentence']}")
            print(f"Sampled: {result['sampled_sentence']}")
            print("---")