import random

import nltk
import torch
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertForMaskedLM

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        # Any masked-LM backend with a mask token works here, e.g. "bert-base-uncased"
        # or a RoBERTa checkpoint via RobertaTokenizer / RobertaForMaskedLM.
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
        self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stop words.
        """
        return [word for word in words if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust indices of common n-grams after removing stopwords.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their (start, end) indices
                into the original word list.

        Returns:
            dict: Common n-grams with indices remapped to the stopword-free word list.
        """
        original_to_non_stop = []
        non_stop_idx = 0
        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    continue  # Skip n-grams whose boundary words are stopwords and cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions
        return adjusted_ngrams
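    # Illustrative example of the index remapping (based on the sample sentence used in
    # the __main__ block below; exact indices assume NLTK's English stopword list):
    #   original: ["The", "quick", "brown", "fox", "jumps", "over", "small", "cat",
    #              "the", "lazy", "dog", "everyday", "again", "and", "again"]
    #   non-stop: ["quick", "brown", "fox", "jumps", "small", "cat", "lazy", "dog", "everyday"]
    #   {'brown fox': [(2, 3)], 'cat': [(7, 7)], 'dog': [(10, 10)]}
    #       -> {'brown fox': [(1, 2)], 'cat': [(5, 5)], 'dog': [(7, 7)]}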
    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Mask one randomly chosen non-stopword before the first common n-gram, one in
        each gap between consecutive common n-grams, and one after the last common
        n-gram. Returns the masked sentence and the masked indices in the original
        word order.
        """
        original_words = sentence.split()

        # Detach trailing punctuation (expected to arrive as a separate final token,
        # as produced by process_sentences) so it is never selected for masking
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            # Mask a word before the first common n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # Mask words between consecutive common n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # Mask a word after the last common n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))

        # Map non-stop-word positions back to indices in the original word list
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        # Re-attach the trailing punctuation if it was detached
        if has_punctuation:
            masked_words.append(punctuation)

        return " ".join(masked_words), original_mask_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Same masking rules as mask_sentence_random, but with a fixed random seed so
        the selected positions are reproducible across runs.
        """
        random.seed(3)
        return self.mask_sentence_random(sentence, common_ngrams)
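    # Illustrative example of the gap regions for the sample sentence in __main__
    # (non-stopword indices, assuming NLTK's English stopword list):
    #   n-gram spans: 'brown fox' -> (1, 2), 'cat' -> (5, 5), 'dog' -> (7, 7)
    #   before first span : index 0        -> "quick"
    #   between spans     : indices 3-4, 6 -> "jumps"/"small", "lazy"
    #   after last span   : index 8        -> "everyday"
    # One word is masked from each of these regions.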
    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate the masked-LM entropy for a specific word position in the sentence.

        Args:
            sentence (str): The input sentence.
            word_position (int): Position of the word to calculate entropy for.

        Returns:
            float: Entropy of the model's distribution over the masked position.
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Probabilities over the vocabulary for the masked position
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # Entropy: -sum(p * log(p)); the small constant guards against log(0)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        return entropy.item()
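    # Interpretation note: the value above is the Shannon entropy
    # H = -sum_v p(v) * log p(v) of the model's predictive distribution at the masked
    # slot, so higher values mean the context constrains the word less. A minimal
    # equivalent sketch (an alternative formulation, not what the code above uses):
    #     torch.distributions.Categorical(logits=logits[0, mask_token_index]).entropy()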
    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask words following the same n-gram positioning rules as mask_sentence_random,
        but within each gap pick the word with the highest masked-LM entropy instead
        of a random one.

        Args:
            sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            tuple: (masked sentence, list of masked indices in the original word order)
        """
        original_words = sentence.split()

        # Detach trailing punctuation (expected as a separate final token)
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Mappings between non-stop-word positions and original positions
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Highest-entropy word before the first n-gram
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Highest-entropy word in each gap between consecutive n-grams
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Highest-entropy word after the last n-gram
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        # Map back to original positions and apply the masks
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        # Re-attach the trailing punctuation if it was detached
        if has_punctuation:
            masked_words.append(punctuation)

        return " ".join(masked_words), original_mask_indices
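    # Note on cost (implied by the implementation above): every candidate word in a
    # gap triggers one full forward pass through the masked LM via
    # calculate_word_entropy, so entropy-based masking is considerably slower than
    # the random variants on long sentences.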
    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """
        Calculate logits for masked tokens in the sentence using the masked LM.

        Args:
            original_sentence (str): Original sentence without masks.
            original_mask_indices (list): List of word indices to mask, one at a time.

        Returns:
            dict: Mapping from masked word index to its top candidate tokens and logits.
        """
        words = original_sentence.split()
        mask_logits = {}

        for idx in original_mask_indices:
            # Mask only the current position; use the tokenizer's mask token so this
            # also works for backends with a different mask string (e.g. RoBERTa's "<mask>")
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)

            # Forward pass for the current mask
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Logits for the masked position
            mask_logits_tensor = logits[0, mask_token_index, :]

            # Request extra candidates so that filtering subwords still leaves 50 whole words
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)

            # Convert token IDs to words, skipping subword pieces and duplicates
            top_tokens = []
            top_logits = []
            seen_words = set()
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                if token.startswith('##'):
                    continue  # Skip WordPiece continuation tokens
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                # Only add new, non-empty words
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                if len(top_tokens) == 50:
                    break  # Stop once 50 unique complete words are collected

            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits,
            }

        return mask_logits
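    # Illustrative shape of the returned value (keys and numbers here are made up):
    #   {4: {"tokens": ["runs", "leaps", ...up to 50 words...],
    #        "logits": [11.2, 10.8, ...]},
    #    9: {...}}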
    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Mask each sentence with the chosen method ("random", "pseudorandom", or
        "entropy") and calculate logits for the masked positions.
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            # Detach trailing punctuation from the last word so the masking functions
            # treat it as a separate token
            words = sentence.split()
            last_word = words[-1]
            if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
                words[-1] = last_word[:-1]
                punctuation = last_word[-1]
                processed_sentence = " ".join(words) + " " + punctuation
            else:
                processed_sentence = sentence

            if method == "random":
                masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
            else:  # entropy
                masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)

            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits,
            }

        return results


if __name__ == "__main__":
    # Works regardless of whether the words around a mask are stopwords or not
    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
        # "A speedy brown fox jumps over a lazy dog.",
        # "A swift brown fox leaps over the lethargic dog.",
    ]
    result_dict = {
        'The quick brown fox jumps over small cat the lazy dog everyday again and again .':
            {'brown fox': [(2, 3)], 'cat': [(7, 7)], 'dog': [(10, 10)]},
        # 'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        # 'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random")
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy")

    # Structure of the returned results:
    # {
    #     "<original sentence>": {
    #         "masked_sentence": str,          # sentence with mask tokens inserted
    #         "mask_logits": {                 # per masked word position
    #             <original word index>: {
    #                 "tokens": list,          # up to 50 predicted complete words
    #                 "logits": list,          # corresponding logits
    #             },
    #             ...
    #         },
    #     },
    # }

    for sentence, output in results_random.items():
        print(f"Original sentence (random): {sentence}")
        print(f"Masked sentence (random):   {output['masked_sentence']}")
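    # A minimal sketch of inspecting the top candidates per masked position (the
    # position keys and candidate words depend on the random selection above):
    for sentence, output in results_random.items():
        for mask_idx, candidates in output["mask_logits"].items():
            top5 = list(zip(candidates["tokens"][:5], candidates["logits"][:5]))
            print(f"Top-5 candidates for the mask at word index {mask_idx}: {top5}")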
# """ # original_words = sentence.split() # # print(f' ---- original_words : {original_words} ----- ') # non_stop_words = self.remove_stopwords(original_words) # # print(f' ---- non_stop_words : {non_stop_words} ----- ') # adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams) # # print(f' ---- common_ngrams : {common_ngrams} ----- ') # # print(f' ---- adjusted_ngrams : {adjusted_ngrams} ----- ') # mask_indices = [] # # Extract n-gram positions in non-stop words # ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions] # # Mask a word before the first common n-gram # if ngram_positions: # # print(f' ---- ngram_positions : {ngram_positions} ----- ') # first_ngram_start = ngram_positions[0][0] # # print(f' ---- first_ngram_start : {first_ngram_start} ----- ') # if first_ngram_start > 0: # mask_index_before_ngram = random.randint(0, first_ngram_start-1) # # print(f' ---- mask_index_before_ngram : {mask_index_before_ngram} ----- ') # mask_indices.append(mask_index_before_ngram) # # Mask words between common n-grams # for i in range(len(ngram_positions) - 1): # end_prev = ngram_positions[i][1] # # print(f' ---- end_prev : {end_prev} ----- ') # start_next = ngram_positions[i + 1][0] # # print(f' ---- start_next : {start_next} ----- ') # if start_next > end_prev + 1: # mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1) # # print(f' ---- mask_index_between_ngrams : {mask_index_between_ngrams} ----- ') # mask_indices.append(mask_index_between_ngrams) # # Mask a word after the last common n-gram # last_ngram_end = ngram_positions[-1][1] # if last_ngram_end < len(non_stop_words) - 1: # # print(f' ---- last_ngram_end : {last_ngram_end} ----- ') # mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1) # # print(f' ---- mask_index_after_ngram : {mask_index_after_ngram} ----- ') # mask_indices.append(mask_index_after_ngram) # # Create mapping from non-stop words to original indices # non_stop_to_original = {} # non_stop_idx = 0 # for orig_idx, word in enumerate(original_words): # if word.lower() not in self.stop_words: # non_stop_to_original[non_stop_idx] = orig_idx # non_stop_idx += 1 # # Map mask indices from non-stop word positions to original positions # # print(f' ---- non_stop_to_original : {non_stop_to_original} ----- ') # original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices] # # print(f' ---- original_mask_indices : {original_mask_indices} ----- ') # # Apply masks to the original sentence # masked_words = original_words.copy() # for idx in original_mask_indices: # masked_words[idx] = self.tokenizer.mask_token # return " ".join(masked_words), original_mask_indices