import random

import torch
import nltk
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertForMaskedLM

# Ensure the NLTK stopword list is available
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether stop words are removed before masking.

        Returns:
            dict: Adjusted common n-grams and their indices.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [
            i for i, word in enumerate(words) if word.lower() not in self.stop_words
        ]
        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip n-grams whose indices cannot be mapped (e.g. stop words)
            adjusted_ngrams[ngram] = adjusted_positions
        return adjusted_ngrams

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether stop words are removed before masking.

        Returns:
            str: Masked sentence with original stop words retained.
        """
        words = original_sentence.split()
        non_stop_words = (
            [word for word in words if word.lower() not in self.stop_words]
            if remove_stopwords else words
        )
        # Map positions in the stop-word-free sequence back to positions in the full sentence
        non_stop_word_indices = (
            [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
            if remove_stopwords else list(range(len(words)))
        )
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        ngram_index_set = {
            index
            for ngram_indices in adjusted_ngrams.values()
            for start, end in ngram_indices
            for index in range(start, end + 1)
        }

        mask_indices = []
        if adjusted_ngrams:
            ngram_positions = list(adjusted_ngrams.values())

            # Before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Between consecutive common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # After the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))

        # Mask the chosen indices, skipping any that fall inside a common n-gram
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if idx in ngram_index_set:
                continue
            non_stop_words[idx] = self.tokenizer.mask_token
            original_masked_sentence[non_stop_word_indices[idx]] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between two n-grams,
        and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether stop words are removed before masking.

        Returns:
            str: Masked sentence with original stop words retained.
        """
        words = original_sentence.split()
        non_stop_words = (
            [word for word in words if word.lower() not in self.stop_words]
            if remove_stopwords else words
        )
        # Map positions in the stop-word-free sequence back to positions in the full sentence
        non_stop_word_indices = (
            [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
            if remove_stopwords else list(range(len(words)))
        )
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        ngram_index_set = {
            index
            for ngram_indices in adjusted_ngrams.values()
            for start, end in ngram_indices
            for index in range(start, end + 1)
        }

        # Score every candidate position by the entropy of BERT's predictive
        # distribution when that position is masked
        entropy_scores = {}
        for idx, word in enumerate(non_stop_words):
            if idx in ngram_index_set:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Epsilon prevents log(0)
            entropy_scores[idx] = entropy

        mask_indices = []
        if adjusted_ngrams:
            ngram_positions = list(adjusted_ngrams.values())

            # Before the first common n-gram
            first_ngram_start = ngram_positions[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Between consecutive common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # After the last common n-gram
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            non_stop_words[idx] = self.tokenizer.mask_token
            original_masked_sentence[non_stop_word_indices[idx]] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.

        Returns:
            dict: Masked token indices and their logits.
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process a list of sentences and calculate logits for masked tokens
        using the specified method.

        Args:
            original_sentences (list): List of original sentences.
            result_dict (dict): Common n-grams and their indices for each sentence.
            method (str): Masking method ("random" or "entropy").
            remove_stopwords (bool): Whether stop words are removed before masking.

        Returns:
            dict: Masked sentences and their logits for each sentence.
        """
        results = {}
        # Sentences are taken from the result_dict keys
        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits,
            }
        return results


# Example usage
if __name__ == "__main__":
    # Works in both cases, whether or not stop words are removed
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }
    # Alternative: the same n-grams indexed on the stop-word-free sentences
    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")
        print(f'type(output["mask_logits"]): {type(output["mask_logits"])}')
        print(f'length of output["mask_logits"]: {len(output["mask_logits"])}')
        print(f'output["mask_logits"].keys(): {output["mask_logits"].keys()}')
        print('--------------------------------')
        for mask_idx, logits in output["mask_logits"].items():
            print(f"Logits for [MASK] at position {mask_idx}:")
            print(f'  logits: {logits[:5]}')  # First few entries of the vocabulary-sized logit list

    # print('--------------------------------')
    # for sentence, output in results_entropy.items():
    #     print(f"Original Sentence (Entropy): {sentence}")
    #     print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
    #     # print(f"Mask Logits (Entropy): {output['mask_logits']}")
    #     print(f'type(output["mask_logits"]): {type(output["mask_logits"])}')
    #     print(f'length of output["mask_logits"]: {len(output["mask_logits"])}')
    #     print(f'output["mask_logits"].keys(): {output["mask_logits"].keys()}')
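
    # Minimal sketch (an addition, not part of the original example): each entry in
    # output["mask_logits"] is a vocabulary-sized list of logits, so the most likely
    # replacement tokens for a [MASK] can be recovered with torch.topk and the tokenizer.
    for sentence, output in results_random.items():
        for mask_idx, logit_list in output["mask_logits"].items():
            top = torch.topk(torch.tensor(logit_list), k=5)
            tokens = processor.tokenizer.convert_ids_to_tokens(top.indices.tolist())
            print(f"Top-5 predictions for [MASK] at token position {mask_idx}: {tokens}")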