"""Mask selected words in sentences with BERT's [MASK] token.

Words that lie outside the supplied common n-grams are masked either at
random positions or at the positions where BERT's predictive entropy is
highest, and the logits for each masked slot are collected.
"""

import random

import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Make sure the NLTK stopword list is available.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.model.eval()  # inference only; disable dropout
        self.stop_words = set(stopwords.words('english'))

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one word in each gap
        between consecutive n-grams, and one word after the last common
        n-gram, choosing the word at random.

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): Drop stopwords before masking. The indices in
                common_ngrams are assumed to refer to the resulting word list.

        Returns:
            str: Masked sentence
        """
        words = original_sentence.split()
        if remove_stopwords:
            words = [word for word in words if word.lower() not in self.stop_words]

        mask_indices = []
        if common_ngrams:
            ngram_positions = list(common_ngrams.values())

            # One word before the first n-gram.
            first_ngram_start = ngram_positions[0][0][0]
            if first_ngram_start > 0:
                mask_indices.append(random.randint(0, first_ngram_start - 1))

            # One word in each gap between consecutive n-grams.
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                if start_next > end_prev + 1:
                    mask_indices.append(random.randint(end_prev + 1, start_next - 1))

            # One word after the last n-gram.
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(words) - 1))

        # Never mask inside a protected n-gram, and skip indices that fall
        # outside the (possibly stopword-filtered) word list.
        protected = {index
                     for ngram_indices in common_ngrams.values()
                     for start, end in ngram_indices
                     for index in range(start, end + 1)}
        for idx in mask_indices:
            if idx < len(words) and idx not in protected:
                words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one word in each gap
        between consecutive n-grams, and one word after the last common
        n-gram, choosing the word whose masked prediction has the highest
        entropy under BERT.

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): Drop stopwords before masking. The indices in
                common_ngrams are assumed to refer to the resulting word list.

        Returns:
            str: Masked sentence
        """
        words = original_sentence.split()
        if remove_stopwords:
            words = [word for word in words if word.lower() not in self.stop_words]

        protected = {index
                     for ngram_indices in common_ngrams.values()
                     for start, end in ngram_indices
                     for index in range(start, end + 1)}

        # Score every unprotected word by the entropy of BERT's distribution
        # over the vocabulary when that word is masked.
        entropy_scores = {}
        for idx, word in enumerate(words):
            if idx in protected:
                continue

            masked_sentence = " ".join(words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:])
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
            entropy_scores[idx] = entropy

        mask_indices = []
        if common_ngrams:
            ngram_positions = list(common_ngrams.values())

            # Highest-entropy word before the first n-gram.
            first_ngram_start = ngram_positions[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Highest-entropy word in each gap between consecutive n-grams.
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][-1][1]
                start_next = ngram_positions[i + 1][0][0]
                candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
                if candidates:
                    mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

            # Highest-entropy word after the last n-gram.
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        for idx in mask_indices:
            words[idx] = self.tokenizer.mask_token

        return " ".join(words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Masked token indices and their logits
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits
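
    # The keys of the returned dict are positions in BERT's WordPiece input
    # (which includes the leading [CLS] token), not word positions in the
    # original sentence; each value is a vocabulary-sized list of logits.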

    def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
        """
        Process sentences and calculate logits for masked tokens using the
        specified method. Sentences are taken from the keys of result_dict.

        Args:
            original_sentences (list): List of original sentences
            result_dict (dict): Common n-grams and their indices for each sentence
            remove_stopwords (bool): Drop stopwords before masking
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: Masked sentences and their logits for each sentence
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results


if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]

    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }
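    # The (start, end) pairs above are word positions in the whitespace-split
    # sentence, e.g. in the first sentence "quick" is word 1, "brown" is word 2
    # and "lazy" is word 7.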

    processor = MaskingProcessor()
    # remove_stopwords=False keeps the word indices in result_dict aligned with
    # the whitespace-split sentences above.
    results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="random")
    results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="entropy")

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")


# Alternative index layout kept from the original script; these positions
# appear to correspond to the sentences after stopword removal:
'''
result_dict = {
    "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
}
'''