import logging
import random

import nltk
import torch
from nltk.corpus import stopwords
from tqdm import tqdm
from transformers import (
    BertForMaskedLM,
    BertTokenizer,
    RobertaForMaskedLM,
    RobertaTokenizer,
)

# Set logging to WARNING for a cleaner terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self, tokenizer, model):
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.stop_words = set(stopwords.words('english'))
        tqdm.write(f"[MaskingProcessor] Initialized on device: {self.device}")

    def remove_stopwords(self, words):
        """Return the input words with English stopwords removed."""
        return [word for word in words if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """Remap n-gram (start, end) positions from original-word indices to
        stopword-free indices."""
        logger.info("Adjusting n-gram indices.")
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0
        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1
        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    # Skip n-grams whose boundaries fall on stopwords.
                    continue
            adjusted_ngrams[ngram] = adjusted_positions
        return adjusted_ngrams

    def mask_sentence_random(self, sentence, common_ngrams):
        """Mask one randomly chosen non-stopword in each gap before, between,
        and after the common n-grams."""
        tqdm.write(f"[MaskingProcessor] Masking (random) sentence: {sentence}")
        original_words = sentence.split()
        has_punctuation = False
        punctuation = ''
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        logger.info(f"Masked sentence (random): {' '.join(masked_words)}")
        return " ".join(masked_words), original_mask_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """Same gap-masking scheme as mask_sentence_random, but with a fixed
        random seed so the choice of masked words is reproducible."""
        logger.info(f"Masking sentence using pseudorandom strategy: {sentence}")
        random.seed(3)
        original_words = sentence.split()
        has_punctuation = False
        punctuation = ''
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        logger.info(f"Masked sentence (pseudorandom): {' '.join(masked_words)}")
        return " ".join(masked_words), original_mask_indices

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """Mask, in each gap around the common n-grams, the non-stopword whose
        masked-LM prediction distribution has the highest entropy."""
        logger.info(f"Masking sentence using entropy strategy: {sentence}")
        original_words = sentence.split()
        has_punctuation = False
        punctuation = ''
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]
        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)
        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1
        if ngram_positions:
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if has_punctuation:
            masked_words.append(punctuation)
        logger.info(f"Masked sentence (entropy): {' '.join(masked_words)}")
        return " ".join(masked_words), original_mask_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """For each masked position, return up to 50 distinct whole-word
        candidates (and their logits) predicted by the masked LM."""
        logger.info(f"Calculating mask logits for sentence: {original_sentence}")
        words = original_sentence.split()
        mask_logits = {}
        for idx in original_mask_indices:
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits
            mask_logits_tensor = logits[0, mask_token_index, :]
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)
            top_tokens = []
            top_logits = []
            seen_words = set()
            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                if token.startswith('##'):
                    # Skip subword continuations; keep only whole-word candidates.
                    continue
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                if len(top_tokens) == 50:
                    break
            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }
        logger.info("Completed calculating mask logits.")
        return mask_logits

    def calculate_word_entropy(self, sentence, word_position):
        """Mask the word at word_position and return the entropy of the
        masked-LM distribution over the vocabulary at that position."""
        logger.info(f"Calculating word entropy for position {word_position} in sentence: {sentence}")
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"].to(self.device)
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        logger.info(f"Computed entropy: {entropy.item()}")
        return entropy.item()

    def process_sentences(self, sentences_list, common_grams, method="random"):
        """Mask every sentence in common_grams with the chosen strategy and
        compute candidate logits for the masked positions.

        Note: masking is driven by the keys of common_grams; sentences_list is
        kept for interface compatibility but is not used here.
        """
        tqdm.write(f"[MaskingProcessor] Processing sentences using method: {method}")
        results = {}
        for sentence, ngrams in tqdm(common_grams.items(), desc="Masking Sentences"):
            words = sentence.split()
            last_word = words[-1]
            if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
                words[-1] = last_word[:-1]
                punctuation = last_word[-1]
                processed_sentence = " ".join(words) + " " + punctuation
            else:
                processed_sentence = sentence
            if method == "random":
                masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
            else:  # entropy
                masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)
            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }
            logger.info(f"Processed sentence: {sentence}")
        tqdm.write("[MaskingProcessor] Completed processing sentences.")
        return results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
    ]
    result_dict = {
        'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {
            'brown fox': [(2, 3)],
            'cat': [(7, 7)],
            'dog': [(10, 10)]
        }
    }
    processor = MaskingProcessor(
        BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking"),
        BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")
    )
    results_random = processor.process_sentences(sentences, result_dict, method="random")
    for sentence, output in results_random.items():
        tqdm.write(f"Original Sentence (Random): {sentence}")
        tqdm.write(f"Masked Sentence (Random): {output['masked_sentence']}")
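    # A minimal sketch, assuming the "roberta-base" checkpoint, of how the
    # RoBERTa imports above could drive the same processor; it is disabled by
    # default. One caveat: the '##' filter in calculate_mask_logits is
    # BERT-specific, while RoBERTa marks word boundaries with a leading 'Ġ',
    # so candidate de-duplication may behave differently with this model.
    USE_ROBERTA = False
    if USE_ROBERTA:
        roberta_processor = MaskingProcessor(
            RobertaTokenizer.from_pretrained("roberta-base"),
            RobertaForMaskedLM.from_pretrained("roberta-base")
        )
        results_roberta = roberta_processor.process_sentences(sentences, result_dict, method="entropy")
        for sentence, output in results_roberta.items():
            tqdm.write(f"Original Sentence (RoBERTa, entropy): {sentence}")
            tqdm.write(f"Masked Sentence (RoBERTa, entropy): {output['masked_sentence']}")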