import torch import random from masking_methods import MaskingProcessor import nltk from nltk.corpus import words import torch.nn.functional as F class SamplingProcessor: def __init__(self, tokenizer): """ Initialize the SamplingProcessor. Args: tokenizer: BERT tokenizer instance """ self.tokenizer = tokenizer self.subtoken_prefix = self._get_subtoken_prefix() self.subtoken_ids = self._get_subtoken_ids() try: nltk.data.find('corpora/words') except LookupError: nltk.download('words') self.english_words = set(words.words()) # def _get_subtoken_prefix(self): # """ # Identify the subtoken prefix based on the tokenizer. # Returns: # str: The prefix used for subtokens (e.g., "##" for BERT). # """ # # This method assumes that the tokenizer uses a consistent subtoken prefix. # # Adjust accordingly if using different tokenizers. # # For BERT's WordPiece tokenizer: # if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs: # return self.tokenizer.init_kwargs["wordpiece_prefix"] # elif hasattr(self.tokenizer, "prefix_tokens"): # return self.tokenizer.prefix_tokens # else: # # Default to BERT's subtoken prefix # return "##" def _get_subtoken_prefix(self): """ Identify the subtoken prefix based on the tokenizer. Returns: str: The prefix used for subtokens (e.g., "##" for BERT). """ # This method assumes that the tokenizer uses a consistent subtoken prefix. # Adjust accordingly if using different tokenizers. # For BERT's WordPiece tokenizer: if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs: return self.tokenizer.init_kwargs["wordpiece_prefix"] elif hasattr(self.tokenizer, "prefix_tokens"): return self.tokenizer.prefix_tokens else: # Default to BERT's subtoken prefix return "##" # def _get_subtoken_ids(self): # """ # Retrieve all token IDs that correspond to subtokens. # Returns: # set: A set of subtoken IDs. # """ # vocab = self.tokenizer.get_vocab() # subtoken_ids = set() # for token, idx in vocab.items(): # if token.startswith(self.subtoken_prefix): # subtoken_ids.add(idx) # return subtoken_ids def _get_subtoken_ids(self): """ Retrieve all token IDs that correspond to subtokens. Returns: list: A list of subtoken IDs. """ vocab = self.tokenizer.get_vocab() subtoken_ids = [] for token, idx in vocab.items(): if token.startswith(self.subtoken_prefix): subtoken_ids.append(idx) return subtoken_ids # Changed from set to list def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0): tokens = self.tokenizer.tokenize(masked_sentence) for mask_pos in sorted(mask_logits_dict.keys()): try: # Get logits and squeeze extra dimension mask_logits = torch.tensor(mask_logits_dict[mask_pos]).squeeze(0) # Remove the extra dimension # Create a mask for valid tokens (no special tokens, no subwords) valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool) for idx in range(len(mask_logits)): token = self.tokenizer.convert_ids_to_tokens([idx])[0] # Only allow regular words (no special tokens, no subwords) if token.isalpha() and not token.startswith('[') and not token.startswith('##'): valid_mask[idx] = True # Get valid logits valid_logits = mask_logits[valid_mask] valid_indices = torch.where(valid_mask)[0] if len(valid_logits) == 0: print(f"Warning: No valid tokens found for position {mask_pos}") continue if sampling_technique == "inverse_transform": probs = torch.softmax(valid_logits / temperature, dim=-1) cumulative_probs = torch.cumsum(probs, dim=-1) random_prob = random.random() sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item() sampled_index = valid_indices[sampled_idx].item() elif sampling_technique == "exponential_minimum": probs = torch.softmax(valid_logits / temperature, dim=-1) exp_probs = torch.exp(-torch.log(probs)) random_probs = torch.rand_like(exp_probs) sampled_idx = torch.argmax(random_probs * exp_probs).item() sampled_index = valid_indices[sampled_idx].item() elif sampling_technique == "temperature": valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8) probs = torch.softmax(valid_logits / temperature, dim=-1) if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)): raise ValueError("The computed probabilities contain NaN or inf values.") probs = torch.max(probs, torch.tensor(1e-8)) probs = probs / torch.sum(probs) sampled_idx = torch.multinomial(probs, 1)[0].item() sampled_index = valid_indices[sampled_idx].item() elif sampling_technique == 'greedy': sampled_idx = torch.argmax(valid_logits).item() sampled_index = valid_indices[sampled_idx].item() # Replace mask with sampled token sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0] tokens[mask_pos] = sampled_token except Exception as e: print(f"Error sampling for position {mask_pos}: {str(e)}") continue return self.tokenizer.convert_tokens_to_string(tokens) def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0): """ Process all masked sentences in the results dictionary. Args: results_dict (dict): Dictionary containing masked sentences and their logits sampling_technique (str): Sampling method to use temperature (float): Temperature parameter for sampling Returns: dict: Dictionary containing original, masked, and sampled sentences """ processed_results = {} for original_sentence, data in results_dict.items(): masked_sentence = data["masked_sentence"] mask_logits = data["mask_logits"] sampled_sentence = self.sample_tokens( mask_logits, masked_sentence, sampling_technique, temperature ) processed_results[original_sentence] = { "masked_sentence": masked_sentence, "sampled_sentence": sampled_sentence } return processed_results if __name__ == "__main__": sentences = [ "The quick brown fox jumps over the lazy dog everyday.", ] result_dict = { 'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}, } # First, mask the sentences masking_processor = MaskingProcessor() masking_results = masking_processor.process_sentences(sentences, result_dict) # Then, sample replacements for the masks sampling_processor = SamplingProcessor(masking_processor.tokenizer) # Try different sampling techniques sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"] for technique in sampling_techniques: print(f"\nSampling using {technique}:") sampled_results = sampling_processor.process_masked_sentences( masking_results, sampling_technique=technique, temperature=1.0 ) for original_sentence, result in sampled_results.items(): print(f"Original: {original_sentence}") print(f"Masked: {result['masked_sentence']}") print(f"Sampled: {result['sampled_sentence']}") print("---") # -------------------------------------------------------------------------------------------------- # def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100): # words = masked_sentence.split() # mask_positions = sorted(mask_logits_dict.keys()) # for mask_pos in mask_positions: # mask_logits = torch.tensor(mask_logits_dict[mask_pos]) # try: # if sampling_technique == "inverse_transform": # probs = torch.softmax(mask_logits / temperature, dim=-1) # cumulative_probs = torch.cumsum(probs, dim=-1) # random_prob = random.random() # sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item() # elif sampling_technique == "exponential_minimum": # probs = torch.softmax(mask_logits / temperature, dim=-1) # exp_probs = torch.exp(-torch.log(probs)) # random_probs = torch.rand_like(exp_probs) # sampled_index = torch.argmax(random_probs * exp_probs).item() # elif sampling_technique == "temperature": # mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8) # probs = torch.softmax(mask_logits / temperature, dim=-1) # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)): # raise ValueError("The computed probabilities contain NaN or inf values.") # probs = torch.max(probs, torch.tensor(1e-8)) # probs = probs / torch.sum(probs) # sampled_index = torch.multinomial(probs, 1)[0].item() # elif sampling_technique == 'greedy': # sampled_index = torch.argmax(mask_logits).item() # else: # raise ValueError(f"Unknown sampling technique: {sampling_technique}") # # Replace mask with sampled token # sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0] # words[mask_pos] = sampled_token # except Exception as e: # print(f"Error sampling for position {mask_pos}: {str(e)}") # continue # return " ".join(words) ## MORE WEIRD RESULTS # def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0, top_k=100): # words = masked_sentence.split() # mask_positions = sorted(mask_logits_dict.keys()) # for mask_pos in mask_positions: # mask_logits = torch.tensor(mask_logits_dict[mask_pos]) # try: # # Create a mask for valid tokens (no special tokens, no subwords) # valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool) # for idx in range(len(mask_logits)): # token = self.tokenizer.convert_ids_to_tokens([idx])[0] # # Only allow regular words (no special tokens, no subwords) # if token.isalpha() and not token.startswith('[') and not token.startswith('##'): # valid_mask[idx] = True # # Get valid logits # valid_logits = mask_logits[valid_mask] # valid_indices = torch.where(valid_mask)[0] # if len(valid_logits) == 0: # print(f"Warning: No valid tokens found for position {mask_pos}") # continue # if sampling_technique == "inverse_transform": # probs = torch.softmax(valid_logits / temperature, dim=-1) # cumulative_probs = torch.cumsum(probs, dim=-1) # random_prob = random.random() # sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item() # sampled_index = valid_indices[sampled_idx].item() # elif sampling_technique == "exponential_minimum": # probs = torch.softmax(valid_logits / temperature, dim=-1) # exp_probs = torch.exp(-torch.log(probs)) # random_probs = torch.rand_like(exp_probs) # sampled_idx = torch.argmax(random_probs * exp_probs).item() # sampled_index = valid_indices[sampled_idx].item() # elif sampling_technique == "temperature": # valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8) # probs = torch.softmax(valid_logits / temperature, dim=-1) # if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)): # raise ValueError("The computed probabilities contain NaN or inf values.") # probs = torch.max(probs, torch.tensor(1e-8)) # probs = probs / torch.sum(probs) # sampled_idx = torch.multinomial(probs, 1)[0].item() # sampled_index = valid_indices[sampled_idx].item() # elif sampling_technique == 'greedy': # sampled_idx = torch.argmax(valid_logits).item() # sampled_index = valid_indices[sampled_idx].item() # else: # raise ValueError(f"Unknown sampling technique: {sampling_technique}") # # Replace mask with sampled token # sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0] # words[mask_pos] = sampled_token # except Exception as e: # print(f"Error sampling for position {mask_pos}: {str(e)}") # continue # return " ".join(words)