import logging
import random

import torch
from tqdm import tqdm

from utils.masking_methods import MaskingProcessor

# Configure logging to suppress INFO-level messages on the console.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class SamplingProcessor:
    """Fill [MASK] positions in a sentence by sampling from per-mask candidate logits."""

    # Sampling techniques accepted by sample_tokens / process_masked_sentences.
    SAMPLING_TECHNIQUES = ("temperature", "greedy", "inverse_transform", "exponential_minimum")

    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance. It is stored on the instance but is
                not used by the sampling methods of this class.
        """
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[SamplingProcessor] Initialized on device: {self.device}")

    def _sample_index(self, mask_logits, sampling_technique, temperature):
        """
        Pick one candidate index from a 1-D tensor of candidate logits.

        Args:
            mask_logits (torch.Tensor): 1-D float logits, one entry per candidate token.
            sampling_technique (str): One of SAMPLING_TECHNIQUES.
            temperature (float): Softmax temperature; ignored by "greedy".

        Returns:
            int: Index into the candidate token list.

        Raises:
            ValueError: If there are no logits, the probabilities are not finite,
                or the technique is unknown.
        """
        if mask_logits.numel() == 0:
            raise ValueError("No candidate logits to sample from.")

        if sampling_technique == "greedy":
            return torch.argmax(mask_logits).item()

        # Clamp extreme/infinite logits so softmax(logits / T) stays finite.
        mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
        probs = torch.softmax(mask_logits / temperature, dim=-1)
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            raise ValueError("The computed probabilities contain NaN or inf values.")

        if sampling_technique == "temperature":
            # Floor every probability at 1e-8 and renormalise before drawing.
            probs = torch.clamp(probs, min=1e-8)
            probs = probs / probs.sum()
            if probs.numel() > 1:
                return torch.multinomial(probs, 1).item()
            return torch.argmax(probs).item()

        if sampling_technique == "inverse_transform":
            cumulative_probs = torch.cumsum(probs, dim=-1)
            threshold = torch.tensor(
                [random.random()], dtype=cumulative_probs.dtype, device=cumulative_probs.device
            )
            # First index whose cumulative probability reaches the uniform draw.
            index = torch.searchsorted(cumulative_probs, threshold).item()
            # Rounding can leave the final cumulative sum slightly below 1.0; fall
            # back to the last candidate instead of indexing past the end.
            return min(index, cumulative_probs.numel() - 1)

        if sampling_technique == "exponential_minimum":
            # With E_i ~ Exp(1) i.i.d., E_i / p_i ~ Exp(rate=p_i), so the index of the
            # minimum is distributed exactly as Categorical(p).
            exp_noise = torch.empty_like(probs).exponential_()
            return torch.argmin(exp_noise / probs).item()

        raise ValueError(f"Unknown sampling technique: {sampling_technique}")

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample tokens for each mask in the sentence using the specified sampling technique.

        Args:
            mask_logits_dict (dict): Maps a mask position to a dict with keys 'logits'
                (candidate logits) and 'tokens' (candidate token strings, same order).
                Positions are used to index masked_sentence.split(); presumably these
                are word indices produced by MaskingProcessor. TODO confirm.
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): One of SAMPLING_TECHNIQUES.
            temperature (float): Temperature parameter for sampling. It must be
                positive for every technique except "greedy".

        Returns:
            str: Sentence with sampled tokens replacing masks. A mask whose sampling
            fails is left as-is, and the error is logged.

        Raises:
            ValueError: If sampling_technique is unknown, or temperature <= 0 for a
                non-greedy technique.
            KeyError: If an entry of mask_logits_dict lacks 'logits' or 'tokens'.
        """
        # Validate once up front: otherwise the error would be swallowed for every mask.
        if sampling_technique not in self.SAMPLING_TECHNIQUES:
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")
        if sampling_technique != "greedy" and temperature <= 0:
            raise ValueError(
                f"temperature must be positive for '{sampling_technique}', got {temperature}"
            )

        tqdm.write(f"[SamplingProcessor] Sampling tokens for: {masked_sentence}")
        words = masked_sentence.split()
        logger.debug("words: %s", words)

        # Process masks in ascending position order.
        mask_positions = sorted(mask_logits_dict.keys())
        logger.debug("mask_positions: %s", mask_positions)

        for mask_pos in mask_positions:
            mask_data = mask_logits_dict[mask_pos]
            # Move logits to self.device (GPU if available, else CPU). Flatten so
            # every technique sees a 1-D candidate vector even if a leading
            # dimension was stored.
            mask_logits = torch.as_tensor(
                mask_data["logits"], dtype=torch.float32, device=self.device
            ).flatten()
            candidate_tokens = mask_data["tokens"]

            try:
                sampled_index = self._sample_index(mask_logits, sampling_technique, temperature)
                # Remove ## if it's a subword token.
                sampled_token = candidate_tokens[sampled_index].replace("##", "")
                words[mask_pos] = sampled_token
                logger.info("Sampled token '%s' for mask position %s.", sampled_token, mask_pos)
            except Exception as e:
                # Best effort: leave this mask unfilled and continue with the rest.
                logger.error("Error sampling for position %s: %s", mask_pos, e)

        sampled_sentence = " ".join(words)
        tqdm.write(f"[SamplingProcessor] Sampled sentence: {sampled_sentence}")
        return sampled_sentence

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Maps each original sentence to a dict with keys
                'masked_sentence' and 'mask_logits' (the latter in the format
                expected by sample_tokens).
            sampling_technique (str): Sampling method to use.
            temperature (float): Temperature parameter for sampling.

        Returns:
            dict: Maps each original sentence to
            {"masked_sentence": str, "sampled_sentence": str}.
        """
        tqdm.write("[SamplingProcessor] Starting sampling for masked sentences.")
        processed_results = {}

        # Wrap the iteration over each original sentence with tqdm.
        for original_sentence, data in tqdm(results_dict.items(), desc="Sampling Masked Sentences"):
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(mask_logits, masked_sentence, sampling_technique, temperature)
            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence,
            }
            logger.info("Processed sampling for sentence: %s", original_sentence)

        tqdm.write("[SamplingProcessor] Completed sampling for all sentences.")
        return processed_results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog.",
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    # First, mask the sentences.
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks.
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try every sampling technique. Each call returns:
    # {original_sentence: {"masked_sentence": ..., "sampled_sentence": ...}, ...}
    # Results are written with tqdm.write because logging is configured at
    # WARNING, so logger.info output would not appear on the console.
    for technique in SamplingProcessor.SAMPLING_TECHNIQUES:
        tqdm.write(f"=== Sampling technique: {technique} ===")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0,
        )
        for original_sentence, result in sampled_results.items():
            tqdm.write(f"Original: {original_sentence}")
            tqdm.write(f"Masked: {result['masked_sentence']}")
            tqdm.write(f"Sampled: {result['sampled_sentence']}")
            tqdm.write("---")