import random

import torch
import nltk
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertForMaskedLM

try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their (start, end) word indices.
            remove_stopwords (bool): Whether stop words are removed before masking.

        Returns:
            dict: Common n-grams with indices remapped to the stop-word-free word list.
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    # Skip occurrences whose boundary word is itself a stop word.
                    continue
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one word between each pair
        of consecutive n-grams, and one word after the last common n-gram
        (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to remove stop words

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        common_ngram_indices = {
            idx for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        mask_indices = []

        if adjusted_ngrams:
            # One word before the first common n-gram.
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # One word in each gap between consecutive common n-grams.
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                potential_indices = [j for j in range(end_prev + 1, start_next) if j not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

            # One word after the last common n-gram.
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                potential_indices = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

        # Map mask positions back to the original (stop-word-inclusive) sentence.
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one word between each pair
        of consecutive n-grams, and one word after the last common n-gram,
        choosing the candidate with the highest masked-LM entropy.

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to remove stop words

        Returns:
            str: Masked sentence with original stop words retained
        """
        words = original_sentence.split()

        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
        ngram_word_indices = {
            index for ngram_indices in adjusted_ngrams.values()
            for start, end in ngram_indices
            for index in range(start, end + 1)
        }
        entropy_scores = {}

        # Score every maskable position; words inside common n-grams are skipped.
        for idx in range(len(non_stop_words)):
            if idx in ngram_word_indices:
                continue

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
            entropy_scores[idx] = entropy

        mask_indices = []

        if adjusted_ngrams:
            # Highest-entropy word before the first common n-gram.
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Highest-entropy word in each gap between consecutive common n-grams.
            ngram_positions = list(adjusted_ngrams.values())
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Highest-entropy word after the last common n-gram.
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Map mask positions back to the original (stop-word-inclusive) sentence.
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if idx in ngram_word_indices:
                continue
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process sentences and calculate logits for masked tokens using the specified method.
        Sentences are taken from the keys of result_dict.

        Args:
            original_sentences (list): List of original sentences (currently unused)
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")
            remove_stopwords (bool): Whether to remove stop words before masking

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")

    print('--------------------------------')
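    # Illustrative addition (not in the original script): run the entropy-based
    # masking as well and decode the single highest-scoring replacement token
    # at each masked position from the returned logits.
    results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=True)

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
        for token_pos, token_logits in output["mask_logits"].items():
            top_id = int(torch.tensor(token_logits).argmax())
            top_token = processor.tokenizer.convert_ids_to_tokens([top_id])[0]
            print(f"  Top prediction at token position {token_pos}: {top_token}")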