import random

import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether stop words are removed before masking.

        Returns:
            dict: Adjusted common n-grams and their indices.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams
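    # Illustrative example of the index remapping (hypothetical values, assuming
    # "the" is the only stop word in the input):
    #     words = ["the", "quick", "brown", "fox"]
    #     adjust_ngram_indices(words, {"brown fox": [(2, 3)]}, remove_stopwords=True)
    #     -> {"brown fox": [(1, 2)]}
    # After dropping "the", the remaining words sit at non-stop-word positions
    # [1, 2, 3] of the original list, so "brown fox" shifts from (2, 3) to (1, 2).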
    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether to remove stop words before masking.

        Returns:
            str: Masked sentence with original stop words retained.
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        # Collect all indices covered by common n-grams
        common_ngram_indices = {
            idx
            for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        mask_indices = []
        if adjusted_ngrams:
            ngram_positions = list(adjusted_ngrams.values())

            # Handle before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            if first_ngram_start > 0:
                potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # Handle between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                potential_indices = [j for j in range(end_prev + 1, start_next) if j not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                potential_indices = [
                    i for i in range(last_ngram_end + 1, len(non_stop_words)) if i not in common_ngram_indices
                ]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

        # Mask the chosen indices, mapping back to positions in the original sentence
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).
        Args:
            original_sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether to remove stop words before masking.

        Returns:
            str: Masked sentence with original stop words retained.
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        # Collect all indices covered by common n-grams
        common_ngram_indices = {
            idx
            for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        # Score every candidate position by the entropy of BERT's predictive
        # distribution when that position is masked
        entropy_scores = {}
        for idx, word in enumerate(non_stop_words):
            if idx in common_ngram_indices:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Add epsilon to prevent log(0)
            entropy_scores[idx] = entropy

        mask_indices = []
        if adjusted_ngrams:
            ngram_positions = list(adjusted_ngrams.values())

            # Handle before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle between common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Handle after the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices, mapping back to positions in the original sentence
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.
        Args:
            masked_sentence (str): Sentence with [MASK] tokens.

        Returns:
            dict: Masked token indices and their logits.
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences (kept for API symmetry;
                the sentences that are processed come from the keys of result_dict).
            result_dict (dict): Common n-grams and their indices for each sentence.
            method (str): Masking method ("random" or "entropy").
            remove_stopwords (bool): Whether to remove stop words before masking.

        Returns:
            dict: Masked sentences and their logits for each sentence.
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results


# Example usage
if __name__ == "__main__":
    # Works in both cases, whether stop words are removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        print('--------------------------------')

    # for sentence, output in results_entropy.items():
    #     print(f"Original Sentence (Entropy): {sentence}")
    #     print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
output["mask_logits"] : {len(output["mask_logits"])}') # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')