import random

import nltk
import torch
from nltk.corpus import words

from masking_methods import MaskingProcessor
class SamplingProcessor:
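    """
    Samples replacement tokens for masked positions produced by MaskingProcessor,
    using greedy, temperature, inverse-transform, or exponential-minimum sampling.
    """
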
    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance
        """
        self.tokenizer = tokenizer
        self.subtoken_prefix = self._get_subtoken_prefix()
        self.subtoken_ids = self._get_subtoken_ids()

        # Ensure the NLTK word corpus is available (downloaded on first use) and
        # keep a set of English words as a reference vocabulary.
        try:
            nltk.data.find('corpora/words')
        except LookupError:
            nltk.download('words')
        self.english_words = set(words.words())
    def _get_subtoken_prefix(self):
        """
        Identify the subtoken prefix based on the tokenizer.

        Returns:
            The prefix used for subtokens (e.g., "##" for BERT's WordPiece tokenizer).
        """
        # This assumes the tokenizer uses a consistent subtoken prefix.
        # Adjust accordingly if using a different tokenizer.
        if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
            return self.tokenizer.init_kwargs["wordpiece_prefix"]
        elif hasattr(self.tokenizer, "prefix_tokens"):
            # Normalize list-like prefixes to a tuple, since str.startswith
            # (used in _get_subtoken_ids) accepts a str or a tuple of str.
            prefix = self.tokenizer.prefix_tokens
            return tuple(prefix) if isinstance(prefix, list) else prefix
        else:
            # Default to BERT's subtoken prefix.
            return "##"
    def _get_subtoken_ids(self):
        """
        Retrieve all token IDs that correspond to subtokens.

        Returns:
            list: A list of subtoken IDs.
        """
        vocab = self.tokenizer.get_vocab()
        # Collect the ID of every vocabulary entry that starts with the subtoken prefix.
        return [idx for token, idx in vocab.items() if token.startswith(self.subtoken_prefix)]
    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
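        """
        Replace each masked position in a sentence with a sampled token.

        Args:
            mask_logits_dict (dict): Maps a mask's position in the tokenized
                sentence to the logits over the vocabulary at that position.
            masked_sentence (str): Sentence containing mask tokens.
            sampling_technique (str): One of "temperature", "greedy",
                "inverse_transform", or "exponential_minimum".
            temperature (float): Temperature used to scale the logits.

        Returns:
            str: The sentence with masks replaced by sampled tokens.
        """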
        tokens = self.tokenizer.tokenize(masked_sentence)

        for mask_pos in sorted(mask_logits_dict.keys()):
            try:
                # Convert the stored logits to a tensor and drop the leading batch dimension.
                mask_logits = torch.as_tensor(mask_logits_dict[mask_pos]).squeeze(0)

                # Build a mask over the vocabulary that keeps only plain words:
                # no special tokens (e.g. [CLS], [SEP], [MASK]) and no ## subwords.
                vocab_tokens = self.tokenizer.convert_ids_to_tokens(list(range(mask_logits.size(0))))
                valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
                for idx, token in enumerate(vocab_tokens):
                    if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
                        valid_mask[idx] = True

                # Restrict sampling to the valid subset of the vocabulary.
                valid_logits = mask_logits[valid_mask]
                valid_indices = torch.where(valid_mask)[0]

                if len(valid_logits) == 0:
                    print(f"Warning: No valid tokens found for position {mask_pos}")
                    continue
                if sampling_technique == "inverse_transform":
                    # Inverse-transform sampling: draw a uniform value and pick the
                    # first token whose cumulative probability exceeds it.
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    cumulative_probs = torch.cumsum(probs, dim=-1)
                    random_prob = random.random()
                    sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "exponential_minimum":
                    # Weight a uniform draw by the inverse probabilities
                    # (exp(-log p) == 1 / p) and take the argmax.
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    exp_probs = torch.exp(-torch.log(probs))
                    random_probs = torch.rand_like(exp_probs)
                    sampled_idx = torch.argmax(random_probs * exp_probs).item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "temperature":
                    # Standard temperature sampling from the softmax distribution.
                    valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
                    probs = torch.softmax(valid_logits / temperature, dim=-1)
                    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                        raise ValueError("The computed probabilities contain NaN or inf values.")
                    probs = torch.max(probs, torch.tensor(1e-8))
                    probs = probs / torch.sum(probs)
                    sampled_idx = torch.multinomial(probs, 1)[0].item()
                    sampled_index = valid_indices[sampled_idx].item()

                elif sampling_technique == "greedy":
                    # Greedy decoding: take the highest-scoring valid token.
                    sampled_idx = torch.argmax(valid_logits).item()
                    sampled_index = valid_indices[sampled_idx].item()

                else:
                    raise ValueError(f"Unknown sampling technique: {sampling_technique}")

                # Replace the mask at this position with the sampled token.
                sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
                tokens[mask_pos] = sampled_token

            except Exception as e:
                print(f"Error sampling for position {mask_pos}: {str(e)}")
                continue

        return self.tokenizer.convert_tokens_to_string(tokens)
    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Dictionary containing masked sentences and their logits
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Dictionary containing original, masked, and sampled sentences
        """
        processed_results = {}
        for original_sentence, data in results_dict.items():
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(
                mask_logits,
                masked_sentence,
                sampling_technique,
                temperature
            )

            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence
            }

        return processed_results
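

# The helper below is a minimal, optional sketch (not part of the original pipeline):
# it shows one way to build the mask_logits_dict that sample_tokens expects directly
# from a Hugging Face masked-language model, should MaskingProcessor be unavailable.
# The model name and the position convention (indices into tokenizer.tokenize(), i.e.
# excluding [CLS]/[SEP]) are assumptions based on how sample_tokens indexes tokens.
def build_mask_logits_dict_sketch(masked_sentence, model_name="bert-base-uncased"):
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    inputs = tokenizer(masked_sentence, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (seq_len, vocab_size)

    # Map each [MASK] position (indexed over tokenizer.tokenize(masked_sentence))
    # to its logits row; the +1 offset skips the [CLS] token added by the encoder.
    mask_logits_dict = {
        i: logits[i + 1].unsqueeze(0)
        for i, tok in enumerate(tokenizer.tokenize(masked_sentence))
        if tok == tokenizer.mask_token
    }
    return tokenizer, mask_logits_dict


# Hypothetical usage of the sketch above:
#   tok, logits_dict = build_mask_logits_dict_sketch("The quick brown [MASK] jumps over the lazy dog.")
#   sentence = "The quick brown [MASK] jumps over the lazy dog."
#   print(SamplingProcessor(tok).sample_tokens(logits_dict, sentence, "greedy"))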

if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    # First, mask the sentences
    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    # Then, sample replacements for the masks
    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # Try different sampling techniques
    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]
    for technique in sampling_techniques:
        print(f"\nSampling using {technique}:")
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0
        )
        for original_sentence, result in sampled_results.items():
            print(f"Original: {original_sentence}")
            print(f"Masked: {result['masked_sentence']}")
            print(f"Sampled: {result['sampled_sentence']}")
            print("---")