"""Fill ``[MASK]`` tokens in sentences with BERT using several sampling strategies."""

import random

import torch
from transformers import BertForMaskedLM, BertTokenizer, pipeline

from masking_methods import MaskingProcessor


class SamplingProcessorWithModel:
    """Fill masked tokens with a BERT masked-language model.

    Supported sampling techniques:

    * ``"inverse_transform"``   - inverse-CDF sampling from the softmax distribution.
    * ``"exponential_minimum"`` - exponential-race sampling (argmin of ``E_i / p_i``).
    * ``"temperature"``         - multinomial sampling from the tempered softmax.
    * ``"greedy"``              - always pick the highest-scoring token.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and masked-LM weights for ``model_name``."""
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # Set the model to evaluation mode
        # The fill-mask pipeline is only built if the pipeline variant is used.
        self.unmasker = None

    @staticmethod
    def _sample_index(logits, sampling_technique, temperature=1.0):
        """Pick one index from a 1-D tensor of logits.

        Args:
            logits (torch.Tensor): Unnormalised scores, one per candidate.
            sampling_technique (str): One of "inverse_transform",
                "exponential_minimum", "temperature", "greedy".
            temperature (float): Softmax temperature; must be positive for the
                stochastic techniques (ignored by "greedy").

        Returns:
            int: Index of the selected candidate.

        Raises:
            ValueError: On an unknown technique, empty logits, a non-positive
                temperature, or NaN/inf probabilities.
        """
        logits = logits.flatten()
        if logits.numel() == 0:
            raise ValueError("Cannot sample from an empty set of logits.")

        if sampling_technique == 'greedy':
            return torch.argmax(logits).item()

        if sampling_technique not in ("inverse_transform", "exponential_minimum", "temperature"):
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")

        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        if sampling_technique == "temperature":
            # Keep extreme logits finite before scaling.
            logits = torch.clamp(logits, min=-1e8, max=1e8)

        probs = torch.softmax(logits / temperature, dim=-1)
        if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
            raise ValueError("The computed probabilities contain NaN or inf values.")

        if sampling_technique == "inverse_transform":
            cumulative_probs = torch.cumsum(probs, dim=-1)
            random_prob = random.random()
            hits = torch.where(cumulative_probs >= random_prob)[0]
            # Rounding can leave the last cumulative value slightly below the
            # drawn threshold; fall back to the final index in that case.
            if hits.numel() == 0:
                return probs.numel() - 1
            return hits[0].item()

        if sampling_technique == "exponential_minimum":
            # Exponential race: with E_i ~ Exp(1), argmin(E_i / p_i) selects
            # index i with probability p_i. Clamp U away from 0 so log is finite.
            uniform = torch.rand_like(probs).clamp_min(torch.finfo(probs.dtype).tiny)
            return torch.argmin(-torch.log(uniform) / probs).item()

        # "temperature": floor the probabilities, renormalise, then sample.
        probs = torch.clamp(probs, min=1e-8)
        probs = probs / torch.sum(probs)
        if probs.numel() > 1:
            return torch.multinomial(probs, 1).item()
        return 0

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Masks are filled left to right; the model is re-run after every fill so
        later predictions are conditioned on the tokens already chosen.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.

        Raises:
            ValueError: If the sampling technique is unknown or sampling fails.
        """
        input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")
        input_ids = input_ids.to(self.model.device)
        mask_token_id = self.tokenizer.mask_token_id

        # Resolve the mask positions once so the loop is bounded by the number
        # of masks in the input.
        mask_positions = torch.where(input_ids[0] == mask_token_id)[0].tolist()

        for mask_index in mask_positions:
            with torch.no_grad():
                logits = self.model(input_ids).logits

            mask_logits = logits[0, mask_index].clone()
            # Suppress [MASK] itself as a candidate replacement.
            mask_logits[mask_token_id] = float("-inf")

            input_ids[0, mask_index] = self._sample_index(
                mask_logits, sampling_technique, temperature
            )

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

    def fill_masked_sentence_with_pipeline(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask using the candidates returned by a fill-mask pipeline.

        Only the pipeline's returned candidates are considered for each mask.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use (e.g., "inverse_transform", "exponential_minimum", "temperature", "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.

        Raises:
            ValueError: If the pipeline output has an unexpected structure or
                the sampling technique is unknown.
        """
        if self.unmasker is None:
            # Reuse the already-loaded model and tokenizer.
            self.unmasker = pipeline('fill-mask', model=self.model, tokenizer=self.tokenizer)

        mask_token = self.tokenizer.mask_token

        while mask_token in masked_sentence:
            predictions = self.unmasker(masked_sentence)

            # With several masks the pipeline returns one candidate list per
            # mask; work on the first mask only.
            if isinstance(predictions, list) and predictions and isinstance(predictions[0], list):
                predictions = predictions[0]

            if (
                not isinstance(predictions, list)
                or not predictions
                or not all(isinstance(pred, dict) for pred in predictions)
            ):
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # Never substitute the mask token for itself, otherwise the loop
            # could not make progress.
            predictions = [pred for pred in predictions if pred['token_str'] != mask_token]
            if not predictions:
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # Pipeline scores are probabilities; take logs to recover logits
            # before tempering and re-normalising.
            scores = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
            logits = torch.log(scores.clamp_min(1e-12))

            sampled_index = self._sample_index(logits, sampling_technique, temperature)

            # Replace the first [MASK] with the selected word
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace(mask_token, sampled_token, 1)

        return masked_sentence


def main():
    """Mask a few example sentences and fill them with every sampling technique."""
    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithModel()

    # Iterate through masking results to apply sampling
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')


# Example usage
if __name__ == "__main__":
    main()