import random

import torch
from transformers import BertTokenizer, BertForMaskedLM, pipeline

from masking_methods import MaskingProcessor

class SamplingProcessorWithModel:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # Set the model to evaluation mode

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use
                ("inverse_transform", "exponential_minimum", "temperature", or "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")

        while self.tokenizer.mask_token_id in input_ids[0]:
            # Find indices of all [MASK] tokens and process the first one
            mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            mask_index = mask_indices[0].item()

            # Get logits from the model
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the [MASK] token
            mask_logits = logits[0, mask_index]

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))  # equivalent to 1 / probs
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()
            elif sampling_technique == "temperature":
                mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
                probs = torch.softmax(mask_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()
            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(mask_logits).item()
            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the sampled token
            input_ids[0, mask_index] = sampled_index

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
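
# Toy check (illustrative only, not used by the pipeline): inverse transform
# sampling, as used above, draws from the same categorical distribution as
# torch.multinomial, because a uniform draw in [0, 1) falls into token i's
# cumulative-probability bucket with probability probs[i]. The logits here
# are hand-picked for the sketch.
def _inverse_transform_demo(temperature=1.0, n_samples=10_000):
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    probs = torch.softmax(logits / temperature, dim=-1)
    cumulative_probs = torch.cumsum(probs, dim=-1)
    counts = torch.zeros_like(probs)
    for _ in range(n_samples):
        random_prob = random.random()
        sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
        counts[sampled_index] += 1
    # Empirical frequencies should approach probs as n_samples grows
    print("probs:      ", probs)
    print("frequencies:", counts / counts.sum())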

class SamplingProcessorWithPipeline:
    def __init__(self, model_name='bert-base-uncased'):
        self.unmasker = pipeline('fill-mask', model=model_name)
        self.tokenizer = self.unmasker.tokenizer

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): Sampling technique to use
                ("inverse_transform", "exponential_minimum", "temperature", or "greedy").
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            str: Sentence with the masks filled.
        """
        while '[MASK]' in masked_sentence:
            # Get predictions for the first [MASK]; with several masks the
            # pipeline returns one list of candidates per mask
            predictions = self.unmasker(masked_sentence)
            if predictions and isinstance(predictions[0], list):
                predictions = predictions[0]
            if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
                raise ValueError("Unexpected structure in predictions from the pipeline.")

            # The pipeline's 'score' values are already softmax probabilities;
            # they are treated as logits below to mirror SamplingProcessorWithModel
            logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

            if sampling_technique == "inverse_transform":
                probs = torch.softmax(logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))  # equivalent to 1 / probs
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()
            elif sampling_technique == "temperature":
                logits = torch.clamp(logits, min=-1e8, max=1e8)
                probs = torch.softmax(logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()
            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(logits).item()
            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected word
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

        return masked_sentence
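
# Note on the pipeline variant (illustrative, based on the default behavior of
# transformers' fill-mask pipeline): the unmasker returns only its top-k
# candidates per mask (k defaults to 5), so SamplingProcessorWithPipeline
# samples from at most that many tokens, while SamplingProcessorWithModel
# samples over the full vocabulary. A larger candidate pool can be requested
# when constructing the pipeline, for example:
#
#     unmasker = pipeline('fill-mask', model='bert-base-uncased', top_k=50)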

# Example usage
if __name__ == "__main__":
    # Define sentences and result_dict
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)

    # Use the model-based sampling processor
    sampling_processor = SamplingProcessorWithModel()

    # Iterate through masking results and apply each sampling technique
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
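
        # Illustrative extra (a sketch; the values are arbitrary): sweep the
        # temperature for the "temperature" technique to see how larger values
        # flatten the distribution and diversify the fills
        for temp in [0.5, 1.0, 2.0]:
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique="temperature",
                temperature=temp
            )
            print(f"temperature={temp}: {filled_sentence}")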