# Hosting-page header captured together with the file (not part of the code):
#   jgyasu's picture
#   Add entire pipeline
#   060ac52
import torch
import random
import logging
from utils.masking_methods import MaskingProcessor
from tqdm import tqdm
# Configure the root logger at WARNING level so that the INFO-level progress
# messages emitted throughout this module stay off the console; every record
# is rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger, named after this module, used by everything below.
logger = logging.getLogger(__name__)
class SamplingProcessor:
    """Fill the mask slots of a sentence with tokens sampled from candidate logits.

    Each mask position comes with a list of candidate tokens and one logit per
    candidate. A sampling technique picks one candidate index per position and
    the corresponding token is written into the whitespace-split sentence.
    """

    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: BERT tokenizer instance. It is only stored on the
                instance; no method of this class calls it.
        """
        self.tokenizer = tokenizer
        # Prefer the GPU when one is visible to torch, otherwise stay on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[SamplingProcessor] Initialized on device: {self.device}")

    @staticmethod
    def _sample_index(mask_logits, sampling_technique, temperature):
        """
        Pick one candidate index from a logits tensor.

        Args:
            mask_logits (torch.Tensor): Floating-point logits, one per
                candidate (assumed 1-D, as the indexing below requires).
            sampling_technique (str): One of "greedy", "temperature",
                "inverse_transform" or "exponential_minimum".
            temperature (float): Divisor applied to the logits before the
                softmax. Ignored by "greedy".

        Returns:
            int: Index of the selected candidate.

        Raises:
            ValueError: If the logits are empty, the technique is unknown,
                the temperature is not positive for a technique that uses
                it, or the "temperature" probabilities contain NaN/inf.
        """
        if mask_logits.numel() == 0:
            raise ValueError("No candidate logits to sample from.")

        if sampling_technique == "greedy":
            return torch.argmax(mask_logits).item()

        if sampling_technique not in ("inverse_transform", "exponential_minimum", "temperature"):
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")

        # Every remaining technique divides by the temperature, so zero or a
        # negative value would yield inf/NaN or an inverted distribution.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        if sampling_technique == "inverse_transform":
            probs = torch.softmax(mask_logits / temperature, dim=-1)
            cumulative_probs = torch.cumsum(probs, dim=-1)
            random_prob = random.random()
            hits = torch.where(cumulative_probs >= random_prob)[0]
            if hits.numel() == 0:
                # Rounding can leave the final cumulative value slightly below
                # the uniform draw; the draw then belongs to the last bucket.
                return cumulative_probs.numel() - 1
            return hits[0].item()

        if sampling_technique == "exponential_minimum":
            probs = torch.softmax(mask_logits / temperature, dim=-1)
            # Exponential race: with E_i ~ Exp(1), argmin(E_i / p_i) selects
            # index i with probability p_i. rand_like draws from [0, 1), so
            # -log(u) is strictly positive and zero-probability candidates
            # end up at +inf, i.e. they can never win the argmin.
            exponential_draws = -torch.log(torch.rand_like(probs))
            return torch.argmin(exponential_draws / probs).item()

        # sampling_technique == "temperature"
        clamped_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
        probs = torch.softmax(clamped_logits / temperature, dim=-1)
        if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
            raise ValueError("The computed probabilities contain NaN or inf values.")
        # Floor the probabilities so multinomial never sees exact zeros, then
        # renormalise so they sum to one again.
        probs = torch.clamp(probs, min=1e-8)
        probs = (probs / torch.sum(probs)).flatten()
        if probs.size(0) > 1:
            return torch.multinomial(probs, 1).item()
        return torch.argmax(probs).item()

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample tokens for each mask in the sentence using the specified sampling technique.

        Args:
            mask_logits_dict (dict): Maps a mask position (an index into the
                whitespace-split sentence) to a dict with a 'logits' entry
                and a 'tokens' entry holding the candidate tokens.
            masked_sentence (str): Sentence with [MASK] tokens
            sampling_technique (str): Sampling method to use ("temperature",
                "greedy", "inverse_transform" or "exponential_minimum")
            temperature (float): Temperature parameter for sampling

        Returns:
            str: Sentence with sampled tokens replacing masks. A position
            whose sampling fails is logged and left as it was.
        """
        tqdm.write(f"[SamplingProcessor] Sampling tokens for: {masked_sentence}")

        words = masked_sentence.split()
        logger.debug("words: %s", words)

        # Process the masks in ascending position order.
        mask_positions = sorted(mask_logits_dict.keys())
        logger.debug("mask_positions: %s", mask_positions)

        for mask_pos in mask_positions:
            mask_data = mask_logits_dict[mask_pos]
            # as_tensor avoids the copy (and the warning) torch.tensor incurs
            # when the logits already are a tensor; detach drops any autograd
            # history because sampling is not differentiated.
            mask_logits = torch.as_tensor(mask_data['logits'], device=self.device).detach()
            if not mask_logits.is_floating_point():
                # softmax is not defined for integer tensors.
                mask_logits = mask_logits.float()
            candidate_tokens = mask_data['tokens']

            # Best effort per mask: a failure on one position must not stop
            # the remaining positions from being filled.
            try:
                sampled_index = self._sample_index(mask_logits, sampling_technique, temperature)
                # Use the sampled index to get the corresponding token
                sampled_token = candidate_tokens[sampled_index]
                # Remove ## if it's a subword token
                sampled_token = sampled_token.replace('##', '')
                words[mask_pos] = sampled_token
                logger.info("Sampled token '%s' for mask position %s.", sampled_token, mask_pos)
            except Exception as e:
                logger.error("Error sampling for position %s: %s", mask_pos, e)
                continue

        sampled_sentence = " ".join(words)
        tqdm.write(f"[SamplingProcessor] Sampled sentence: {sampled_sentence}")
        return sampled_sentence

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Maps each original sentence to a dict with a
                "masked_sentence" entry and a "mask_logits" entry
            sampling_technique (str): Sampling method to use
            temperature (float): Temperature parameter for sampling

        Returns:
            dict: Maps each original sentence to a dict with its
            "masked_sentence" and the resulting "sampled_sentence"
        """
        tqdm.write("[SamplingProcessor] Starting sampling for masked sentences.")
        processed_results = {}
        # tqdm renders a progress bar over the sentences.
        for original_sentence, data in tqdm(results_dict.items(), desc="Sampling Masked Sentences"):
            masked_sentence = data["masked_sentence"]
            sampled_sentence = self.sample_tokens(
                data["mask_logits"],
                masked_sentence,
                sampling_technique,
                temperature,
            )
            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence,
            }
            logger.info("Processed sampling for sentence: %s", original_sentence)
        tqdm.write("[SamplingProcessor] Completed sampling for all sentences.")
        return processed_results
if __name__ == "__main__":
    # Demo run: mask three paraphrases of one sentence, then refill the masks
    # once per supported sampling technique.
    demo_sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]

    # Per-sentence phrase -> list of index pairs, handed to MaskingProcessor
    # as-is (presumably word spans; verify against MaskingProcessor).
    shared_phrases = {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    phrase_spans = {
        'The quick brown fox jumps over the lazy dog everyday.': dict(shared_phrases),
        'A speedy brown fox jumps over a lazy dog.': dict(shared_phrases),
        'A swift brown fox leaps over the lethargic dog.': dict(shared_phrases),
    }

    # Step 1: produce the masked sentences together with their mask logits.
    masker = MaskingProcessor()
    masked_output = masker.process_sentences(demo_sentences, phrase_spans)

    # Step 2: sample replacements for the masks, reusing the masker's tokenizer.
    sampler = SamplingProcessor(masker.tokenizer)

    for technique in ("temperature", "greedy", "inverse_transform", "exponential_minimum"):
        logger.info(f"Sampling using technique: {technique}")
        sampled_results = sampler.process_masked_sentences(
            masked_output,
            sampling_technique=technique,
            temperature=1.0,
        )

        # Shape of sampled_results:
        # {
        #     "original_sentence_1": {
        #         "masked_sentence": "sentence with [MASK] tokens",
        #         "sampled_sentence": "sentence with sampled tokens",
        #     },
        #     "original_sentence_2": {
        #         "masked_sentence": "sentence with [MASK] tokens",
        #         "sampled_sentence": "sentence with sampled tokens",
        #     },
        #     # ... and so on for each input sentence
        # }
        for original_sentence, result in sampled_results.items():
            logger.info(f"Original: {original_sentence}")
            logger.info(f"Masked: {result['masked_sentence']}")
            logger.info(f"Sampled: {result['sampled_sentence']}")
            logger.info("---")