jgyasu's picture
Add entire pipeline
060ac52
import torch
import random
from masking_methods import MaskingProcessor
class SamplingProcessor:
    """Fill ``[MASK]`` tokens in a sentence by sampling from per-mask logits."""

    def __init__(self, tokenizer):
        """
        Args:
            tokenizer: Object exposing ``tokenize``, ``mask_token``,
                ``convert_ids_to_tokens`` and ``convert_tokens_to_string``
                (e.g. a Hugging Face ``BertTokenizer``).
        """
        self.tokenizer = tokenizer

    @staticmethod
    def _to_logit_vector(logits):
        """Convert one mask's logits into a non-empty 1-D floating-point tensor."""
        # as_tensor + detach avoids the copy-construct warning torch.tensor(tensor)
        # emits, while still working for lists / numpy arrays.
        vector = torch.as_tensor(logits).detach()
        if not vector.is_floating_point():
            vector = vector.float()
        # Assumes a single logit vector per mask; a leading singleton batch dim such
        # as (1, vocab_size) is flattened away so sampled positions are vocab ids.
        vector = vector.reshape(-1)
        if vector.numel() == 0:
            raise ValueError("Empty logits provided for a [MASK] token.")
        return vector

    @staticmethod
    def _tempered_probs(logits, temperature):
        """Return ``softmax(logits / temperature)``, validating temperature and result."""
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        probs = torch.softmax(logits / temperature, dim=-1)
        if not torch.isfinite(probs).all():
            raise ValueError("The computed probabilities contain NaN or inf values.")
        return probs

    def _sample_index(self, logits, sampling_technique, temperature):
        """Pick one vocabulary index from a 1-D logit vector with the given technique."""
        if sampling_technique == "inverse_transform":
            probs = self._tempered_probs(logits, temperature)
            cumulative_probs = torch.cumsum(probs, dim=-1)
            random_prob = random.random()
            # cumsum of non-negative values is non-decreasing, so the number of
            # entries strictly below random_prob is the first index with
            # cumulative mass >= random_prob. Clamp because float rounding can leave
            # cumulative_probs[-1] slightly below 1.0.
            index = int((cumulative_probs < random_prob).sum().item())
            return min(index, cumulative_probs.numel() - 1)

        if sampling_technique == "exponential_minimum":
            probs = self._tempered_probs(logits, temperature)
            # Exponential race: with E_i ~ Exp(1), argmin_i E_i / p_i is distributed
            # according to p. Using E_i = -log(U_i), U_i ~ Uniform[0, 1), this equals
            # argmax_i log(U_i) / p_i; working in log space avoids underflow.
            uniforms = torch.rand_like(probs)
            return torch.argmax(torch.log(uniforms) / probs).item()

        if sampling_technique == "temperature":
            # Clamp so +/-inf logits do not turn into NaN probabilities.
            clamped = torch.clamp(logits, min=-1e8, max=1e8)
            probs = self._tempered_probs(clamped, temperature)
            # Floor every probability at 1e-8 and renormalise before multinomial.
            probs = torch.clamp(probs, min=1e-8)
            probs = probs / torch.sum(probs)
            if probs.numel() > 1:
                return torch.multinomial(probs, 1).item()
            return torch.argmax(probs).item()

        if sampling_technique == "greedy":
            # Temperature is irrelevant for argmax.
            return torch.argmax(logits).item()

        raise ValueError(f"Unknown sampling technique: {sampling_technique}")

    def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
        """
        Fill each mask in the masked sentence using the specified sampling technique.

        Args:
            original_sentence (str): The masked sentence to fill.
            mask_logits (dict): One logit vector per mask token. Values are paired
                with the masks in iteration (insertion) order, so they must be
                ordered left-to-right like the masks. Positions in each vector are
                treated as vocabulary ids.
            sampling_technique (str): One of "inverse_transform",
                "exponential_minimum", "temperature" or "greedy".
            temperature (float): Softmax temperature for the non-greedy
                techniques; must be > 0. Ignored by "greedy".

        Returns:
            str: Sentence with the masks filled.

        Raises:
            ValueError: if the number of masks and logit vectors differ, a logit
                vector is empty, the technique is unknown, the temperature is not
                positive, or the computed probabilities contain NaN/inf.
        """
        sentence_tokens = self.tokenizer.tokenize(original_sentence)
        mask_token_indices = [
            i for i, token in enumerate(sentence_tokens) if token == self.tokenizer.mask_token
        ]
        if len(mask_token_indices) != len(mask_logits):
            raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")

        for mask_idx, raw_logits in zip(mask_token_indices, mask_logits.values()):
            logits = self._to_logit_vector(raw_logits)
            sampled_index = self._sample_index(logits, sampling_technique, temperature)
            sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
            sentence_tokens[mask_idx] = sampled_token
        return self.tokenizer.convert_tokens_to_string(sentence_tokens)

    def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
        """
        Fill the masks of several sentences using the specified sampling technique.

        Args:
            masked_sentences (list): List of masked sentences.
            mask_logits (sequence): Per-sentence mask-logit dicts (the format
                ``fill_masked_sentence`` expects), aligned with
                ``masked_sentences``. The two are paired with ``zip``, so extra
                items on either side are ignored; passing a dict here would
                iterate its keys, not its values.
            sampling_technique (str): Sampling technique to use.
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            list: Sentences with their masks filled, in input order.
        """
        return [
            self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
            for sentence, logits in zip(masked_sentences, mask_logits)
        ]
# Example usage: mask a few demo sentences, then refill the masks with every
# sampling technique and print the intermediate and final results.
if __name__ == "__main__":
    from transformers import BertTokenizer

    # Alternative checkpoint that was tried: "bert-large-cased-whole-word-masking".
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    sampler = SamplingProcessor(bert_tokenizer)

    # Word/phrase spans to mask, keyed by the sentence they belong to.
    # NOTE(review): the (start, end) pairs appear to index words after stopword
    # removal, yet remove_stopwords=False is passed below — confirm against
    # MaskingProcessor.process_sentences.
    spans_by_sentence = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
    }
    # Same sentences, same order, as the dictionary keys.
    demo_sentences = list(spans_by_sentence)

    masker = MaskingProcessor()
    masked_outputs = masker.process_sentences(
        demo_sentences, spans_by_sentence, method="random", remove_stopwords=False
    )

    TECHNIQUES = ("inverse_transform", "exponential_minimum", "temperature", "greedy")

    for source_sentence, outcome in masked_outputs.items():
        print(f"Original Sentence (Random): {source_sentence}")
        print(f"Masked Sentence (Random): {outcome['masked_sentence']}")

        logits_per_mask = outcome["mask_logits"]
        print(f' type(result["mask_logits"]) : {type(logits_per_mask)}')
        print(f' length of result["mask_logits"] : {len(logits_per_mask)}')
        print(f' result["mask_logits"].keys() : {logits_per_mask.keys()}')

        sentence_with_masks = outcome["masked_sentence"]
        print(f"Original Masked Sentence: {sentence_with_masks}")

        for technique in TECHNIQUES:
            print(f"Sampling Technique: {technique}")
            refilled = sampler.fill_masked_sentence(
                original_sentence=sentence_with_masks,
                mask_logits=logits_per_mask,
                sampling_technique=technique,
                temperature=1.0,  # tweak to change how peaked the sampling distribution is
            )
            print(f"Filled Sentence: {refilled}\n")
        print('--------------------------------')