import random

import nltk
import torch
from nltk.corpus import words

from masking_methods import MaskingProcessor

class SamplingProcessor:
def __init__(self, tokenizer):
"""
Initialize the SamplingProcessor.
Args:
tokenizer: BERT tokenizer instance
"""
self.tokenizer = tokenizer
self.subtoken_prefix = self._get_subtoken_prefix()
self.subtoken_ids = self._get_subtoken_ids()
try:
nltk.data.find('corpora/words')
except LookupError:
nltk.download('words')
self.english_words = set(words.words())
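        # NOTE: english_words is loaded so sampled candidates could be
        # checked against a dictionary, but nothing in this class consumes
        # it yet.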
    def _get_subtoken_prefix(self):
        """
        Identify the subtoken prefix used by the tokenizer.

        Returns:
            str: The prefix used for subword tokens (e.g., "##" for BERT's
            WordPiece tokenizer).
        """
        # Assumes the tokenizer uses a consistent subword prefix; adjust for
        # other tokenizer families (e.g., SentencePiece marks word starts
        # with a leading "▁" rather than prefixing continuations).
        if hasattr(self.tokenizer, "init_kwargs") and "wordpiece_prefix" in self.tokenizer.init_kwargs:
            return self.tokenizer.init_kwargs["wordpiece_prefix"]
        # Fall back to BERT's WordPiece subtoken prefix. Note that
        # tokenizer.prefix_tokens, where it exists, holds a list of
        # special-token IDs, not a string prefix, so it is not usable here.
        return "##"
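    # A quick sanity check (illustrative; assumes a WordPiece tokenizer such
    # as "bert-base-uncased", which is not pinned anywhere in this file):
    #
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   SamplingProcessor(tok).subtoken_prefix   # -> "##"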
    def _get_subtoken_ids(self):
        """
        Retrieve all token IDs that correspond to subword tokens.

        Returns:
            list: Token IDs whose string form starts with the subtoken prefix.
        """
        vocab = self.tokenizer.get_vocab()
        return [idx for token, idx in vocab.items() if token.startswith(self.subtoken_prefix)]
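    # Sketch of a possible use (an assumption, not wired in anywhere): the
    # collected IDs could bulk-suppress subword tokens on a logits tensor
    # instead of the per-token filtering loop in sample_tokens:
    #
    #   logits[torch.tensor(self.subtoken_ids)] = float("-inf")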
    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Replace each mask in a tokenized sentence with a sampled token.

        Args:
            mask_logits_dict (dict): Maps mask positions (token indices) to
                vocabulary-sized logit arrays.
            masked_sentence (str): The sentence containing mask tokens.
            sampling_technique (str): One of "temperature", "greedy",
                "inverse_transform", or "exponential_minimum".
            temperature (float): Temperature parameter for sampling.

        Returns:
            str: The sentence with each mask replaced by a sampled token.
        """
        tokens = self.tokenizer.tokenize(masked_sentence)
        for mask_pos in sorted(mask_logits_dict.keys()):
            try:
                # Get the logits and squeeze out the extra batch dimension.
                mask_logits = torch.as_tensor(mask_logits_dict[mask_pos]).squeeze(0)
                # Build a mask over the vocabulary that keeps only regular
                # words (no special tokens such as [SEP], no "##" subwords).
                # Decoding the whole vocabulary in one call avoids a
                # per-token convert_ids_to_tokens round trip.
                valid_mask = torch.zeros_like(mask_logits, dtype=torch.bool)
                vocab_tokens = self.tokenizer.convert_ids_to_tokens(list(range(len(mask_logits))))
                for idx, token in enumerate(vocab_tokens):
                    if token.isalpha() and not token.startswith('[') and not token.startswith('##'):
                        valid_mask[idx] = True
# Get valid logits
valid_logits = mask_logits[valid_mask]
valid_indices = torch.where(valid_mask)[0]
if len(valid_logits) == 0:
print(f"Warning: No valid tokens found for position {mask_pos}")
continue
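                # Sampling strategies (all operate on the filtered logits):
                #   inverse_transform   - draw u ~ U(0,1) and take the first
                #                         token whose cumulative softmax
                #                         probability reaches u
                #   exponential_minimum - scale uniform noise by inverse
                #                         probabilities (exp(-log p) = 1/p)
                #                         and take the argmax
                #   temperature         - multinomial draw from
                #                         softmax(logits / temperature)
                #   greedy              - plain argmax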
if sampling_technique == "inverse_transform":
probs = torch.softmax(valid_logits / temperature, dim=-1)
cumulative_probs = torch.cumsum(probs, dim=-1)
random_prob = random.random()
sampled_idx = torch.where(cumulative_probs >= random_prob)[0][0].item()
sampled_index = valid_indices[sampled_idx].item()
elif sampling_technique == "exponential_minimum":
probs = torch.softmax(valid_logits / temperature, dim=-1)
exp_probs = torch.exp(-torch.log(probs))
random_probs = torch.rand_like(exp_probs)
sampled_idx = torch.argmax(random_probs * exp_probs).item()
sampled_index = valid_indices[sampled_idx].item()
elif sampling_technique == "temperature":
valid_logits = torch.clamp(valid_logits, min=-1e8, max=1e8)
probs = torch.softmax(valid_logits / temperature, dim=-1)
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
raise ValueError("The computed probabilities contain NaN or inf values.")
probs = torch.max(probs, torch.tensor(1e-8))
probs = probs / torch.sum(probs)
sampled_idx = torch.multinomial(probs, 1)[0].item()
sampled_index = valid_indices[sampled_idx].item()
                elif sampling_technique == 'greedy':
                    sampled_idx = torch.argmax(valid_logits).item()
                    sampled_index = valid_indices[sampled_idx].item()
                else:
                    # Guard against typos: without this, an unknown technique
                    # would leave sampled_index undefined.
                    raise ValueError(f"Unknown sampling technique: {sampling_technique}")
# Replace mask with sampled token
sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
tokens[mask_pos] = sampled_token
except Exception as e:
print(f"Error sampling for position {mask_pos}: {str(e)}")
continue
return self.tokenizer.convert_tokens_to_string(tokens)
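    # Example call (a sketch; mask positions index the tokenized sentence,
    # and the logit arrays are produced upstream by MaskingProcessor):
    #
    #   processor.sample_tokens({3: logits}, "the quick brown [MASK] jumps",
    #                           sampling_technique="greedy")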
def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
"""
Process all masked sentences in the results dictionary.
Args:
results_dict (dict): Dictionary containing masked sentences and their logits
sampling_technique (str): Sampling method to use
temperature (float): Temperature parameter for sampling
Returns:
dict: Dictionary containing original, masked, and sampled sentences
"""
processed_results = {}
for original_sentence, data in results_dict.items():
masked_sentence = data["masked_sentence"]
mask_logits = data["mask_logits"]
sampled_sentence = self.sample_tokens(
mask_logits,
masked_sentence,
sampling_technique,
temperature
)
processed_results[original_sentence] = {
"masked_sentence": masked_sentence,
"sampled_sentence": sampled_sentence
}
return processed_results
if __name__ == "__main__":
sentences = [
"The quick brown fox jumps over the lazy dog everyday.",
]
result_dict = {
'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
}
# First, mask the sentences
masking_processor = MaskingProcessor()
masking_results = masking_processor.process_sentences(sentences, result_dict)
# Then, sample replacements for the masks
sampling_processor = SamplingProcessor(masking_processor.tokenizer)
# Try different sampling techniques
sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]
for technique in sampling_techniques:
print(f"\nSampling using {technique}:")
sampled_results = sampling_processor.process_masked_sentences(
masking_results,
sampling_technique=technique,
temperature=1.0
)
for original_sentence, result in sampled_results.items():
print(f"Original: {original_sentence}")
print(f"Masked: {result['masked_sentence']}")
print(f"Sampled: {result['sampled_sentence']}")
print("---")