# Author: jgyasu
# Commit: 060ac52 — "Add entire pipeline"
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
def _ensure_stopwords_corpus():
    """Download the NLTK stop-word corpus unless it is already installed locally."""
    try:
        nltk.data.find('corpora/stopwords')
    except LookupError:
        # nltk.data.find signals a missing resource with LookupError; fetch it once.
        nltk.download('stopwords')


# Run at import time so MaskingProcessor can read stopwords.words('english').
_ensure_stopwords_corpus()
class MaskingProcessor:
    """Mask words around shared n-grams and score the masks with a BERT masked LM.

    Given a sentence and the word positions of n-grams it shares with other
    sentences, the processor picks one word to mask in every stretch of words
    lying outside those n-grams: before the first one, between consecutive
    ones, and after the last one. The word is chosen either at random or by
    highest masked-LM prediction entropy. BERT logits are then computed for the
    resulting [MASK] positions.
    """

    def __init__(self, model_name="bert-base-uncased"):
        """Load the tokenizer, the masked-LM model and the English stop words.

        Args:
            model_name (str): Hugging Face checkpoint used for both the
                tokenizer and the model (default "bert-base-uncased").
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # inference only: make sure dropout is disabled
        self.stop_words = set(stopwords.words('english'))

    def _non_stop_word_indices(self, words, remove_stopwords):
        """Return the original indices of the words that are eligible for masking."""
        if not remove_stopwords:
            return list(range(len(words)))
        return [i for i, word in enumerate(words) if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Map n-gram word positions onto the stop-word-free word sequence.

        Args:
            words (list): Words of the original sentence (whitespace split).
            common_ngrams (dict): N-gram -> list of inclusive (start, end) word
                indices, expressed in ``words`` coordinates.
            remove_stopwords (bool): When False, ``common_ngrams`` is returned
                unchanged (same object).
        Returns:
            dict: N-gram -> list of (start, end) indices into the list of
            non-stop words. A span whose start or end word is a stop word
            cannot be mapped and is dropped, so a value may be an empty list.
        """
        if not remove_stopwords:
            return common_ngrams
        # original word index -> index among the non-stop words (O(1) lookups)
        new_index = {
            orig_idx: kept_idx
            for kept_idx, orig_idx in enumerate(self._non_stop_word_indices(words, True))
        }
        return {
            ngram: [
                (new_index[start], new_index[end])
                for start, end in positions
                if start in new_index and end in new_index
            ]
            for ngram, positions in common_ngrams.items()
        }

    @staticmethod
    def _mask_gaps(adjusted_ngrams, length):
        """Return the index ranges in ``range(length)`` not covered by any n-gram span.

        Spans from all n-grams are sorted by position, so the dict's key order
        does not matter. The result holds the non-empty gaps before the first
        span, between consecutive spans (overlaps are merged), and after the
        last span. It is [] when there are no spans.
        """
        spans = sorted(span for positions in adjusted_ngrams.values() for span in positions)
        if not spans:
            return []
        gaps = [range(0, spans[0][0])]
        covered_until = spans[0][1]
        for start, end in spans[1:]:
            gaps.append(range(covered_until + 1, start))  # empty if spans touch/overlap
            covered_until = max(covered_until, end)
        gaps.append(range(covered_until + 1, length))
        clipped = (range(gap.start, min(gap.stop, length)) for gap in gaps)
        return [gap for gap in clipped if len(gap) > 0]

    def _apply_masks(self, words, kept_indices, mask_indices):
        """Replace the chosen words (indices into ``kept_indices``) with the mask token.

        Stop words are never removed from the output; only the chosen
        positions are replaced before the words are re-joined with spaces.
        """
        masked_words = list(words)
        for idx in mask_indices:
            masked_words[kept_indices[idx]] = self.tokenizer.mask_token
        return " ".join(masked_words)

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one random word before the first common n-gram, one in each gap
        between n-grams, and one after the last common n-gram.

        Args:
            original_sentence (str): Original sentence.
            common_ngrams (dict): N-gram -> list of (start, end) word indices.
            remove_stopwords (bool): If True, stop words are never masked and
                n-gram positions are remapped via ``adjust_ngram_indices``.
        Returns:
            str: Masked sentence with original stop words retained. If no
            n-gram span is usable, the sentence is returned unmasked
            (re-joined with single spaces).
        """
        words = original_sentence.split()
        kept = self._non_stop_word_indices(words, remove_stopwords)
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        gaps = self._mask_gaps(adjusted_ngrams, len(kept))
        mask_indices = [random.choice(gap) for gap in gaps]
        return self._apply_masks(words, kept, mask_indices)

    def _mask_entropy(self, context_words, idx):
        """Entropy (nats) of the model's prediction when ``context_words[idx]`` is masked."""
        masked_words = context_words[:idx] + [self.tokenizer.mask_token] + context_words[idx + 1:]
        input_ids = self.tokenizer(" ".join(masked_words), return_tensors="pt")["input_ids"]
        mask_positions = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            logits = self.model(input_ids).logits
        # log_softmax avoids the log(0) problem without an epsilon fudge.
        log_probs = torch.log_softmax(logits[0, mask_positions, :], dim=-1)
        return -(log_probs.exp() * log_probs).sum().item()

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask, in each gap around the common n-grams (before the first, between
        consecutive ones, after the last), the word whose masked prediction
        has the highest entropy.

        Args:
            original_sentence (str): Original sentence.
            common_ngrams (dict): N-gram -> list of (start, end) word indices.
            remove_stopwords (bool): If True, entropy is computed on the
                sentence with stop words removed and stop words are never masked.
        Returns:
            str: Masked sentence with original stop words retained. If no
            n-gram span is usable, the sentence is returned unmasked.
        """
        words = original_sentence.split()
        kept = self._non_stop_word_indices(words, remove_stopwords)
        context_words = [words[i] for i in kept]
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        mask_indices = [
            max(gap, key=lambda idx: self._mask_entropy(context_words, idx))
            for gap in self._mask_gaps(adjusted_ngrams, len(context_words))
        ]
        return self._apply_masks(words, kept, mask_indices)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence containing mask tokens.
        Returns:
            dict: Position of each mask token in the tokenizer's ``input_ids``
            (which include any special tokens the tokenizer adds) -> list of
            vocabulary logits at that position.
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            logits = self.model(input_ids).logits
        return {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Mask every sentence in ``result_dict`` and compute logits for its masks.

        Args:
            original_sentences (list): Currently unused; the sentences processed
                are the keys of ``result_dict``. Kept for interface compatibility.
            result_dict (dict): Sentence -> {n-gram: [(start, end), ...]}.
            method (str): Masking method, "random" or "entropy".
            remove_stopwords (bool): Forwarded to the masking method.
        Returns:
            dict: Sentence -> {"masked_sentence": str, "mask_logits": dict}.
        Raises:
            ValueError: If ``method`` is not "random" or "entropy" (checked
                before any sentence is processed).
        """
        maskers = {"random": self.mask_sentence_random, "entropy": self.mask_sentence_entropy}
        if method not in maskers:
            raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
        mask_fn = maskers[method]

        results = {}
        for sentence, ngrams in result_dict.items():
            masked_sentence = mask_fn(sentence, ngrams, remove_stopwords)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": self.calculate_mask_logits(masked_sentence),
            }
        return results
# Example usage
if __name__ == "__main__":
    # Demo: three paraphrases sharing "brown fox" and "dog". Positions are
    # inclusive (start, end) word indices in the whitespace-split sentence.
    # The demo runs with remove_stopwords=True; False works as well.
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog.",
    ]
    result_dict = {
        sentence: {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
        for sentence in sentences
    }

    processor = MaskingProcessor()
    # Pass method="entropy" instead to mask the highest-entropy word per gap.
    results_random = processor.process_sentences(
        sentences, result_dict, method="random", remove_stopwords=True
    )

    for original, output in results_random.items():
        print(f"Original Sentence (Random): {original}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        print('--------------------------------')