import random

import torch
import nltk
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertForMaskedLM, RobertaTokenizer, RobertaForMaskedLM

# Download the NLTK stopword list on first use.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class MaskingProcessor:

    def __init__(self):
        # Whole-word-masking BERT is used both to pick high-entropy words to mask
        # and to score candidate replacement tokens.
        self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
        self.model = BertForMaskedLM.from_pretrained("bert-large-cased-whole-word-masking")

        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: List of non-stopword words.
        """
        return [word for word in words if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Adjust the indices of common n-grams after stopwords are removed.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Common n-grams and their (start, end) indices
                in the original word list.

        Returns:
            dict: Common n-grams with indices remapped to the stopword-free list.
        """
        non_stop_words = self.remove_stopwords(original_words)
        original_to_non_stop = []
        non_stop_idx = 0

        # Map each non-stopword position in the original list to its position
        # in the stopword-free list.
        for original_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                original_to_non_stop.append((original_idx, non_stop_idx))
                non_stop_idx += 1

        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = next(non_stop for orig, non_stop in original_to_non_stop if orig == start)
                    new_end = next(non_stop for orig, non_stop in original_to_non_stop if orig == end)
                    adjusted_positions.append((new_start, new_end))
                except StopIteration:
                    # An n-gram boundary falls on a stopword; skip this span.
                    continue
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams
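    # Example: with the default English stopword list,
    #   adjust_ngram_indices(["The", "quick", "brown", "fox"], {"brown fox": [(2, 3)]})
    # returns {"brown fox": [(1, 2)]}, because dropping the stopword "The"
    # shifts every later index left by one.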

    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Randomly mask one non-stopword word in each region of the sentence that
        is not covered by a common n-gram: before the first n-gram, between
        consecutive n-grams, and after the last one. Stopwords and the n-grams
        themselves are never masked.
        """
        original_words = sentence.split()

        # Detach trailing punctuation; it is re-appended after masking.
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        mask_indices = []
        # Flatten the n-gram spans; they are assumed to be listed in
        # left-to-right order.
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            # Mask one word before the first n-gram.
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)

            # Mask one word between each pair of consecutive n-grams.
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask one word after the last n-gram.
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Map the chosen stopword-free indices back to positions in the
        # original word list.
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')

        return " ".join(masked_words), original_mask_indices

    def mask_sentence_pseudorandom(self, sentence, common_ngrams):
        """
        Same masking rule as mask_sentence_random, but with a fixed random seed
        so the masked positions are reproducible across runs.
        """
        random.seed(3)
        original_words = sentence.split()

        # Detach trailing punctuation; it is re-appended after masking.
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        print(f' ---- original_words : {original_words} ----- ')

        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        mask_indices = []
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        if ngram_positions:
            # Mask one word before the first n-gram.
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                mask_index_before_ngram = random.randint(0, first_ngram_start - 1)
                mask_indices.append(mask_index_before_ngram)

            # Mask one word between each pair of consecutive n-grams.
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    mask_index_between_ngrams = random.randint(end_prev + 1, start_next - 1)
                    mask_indices.append(mask_index_between_ngrams)

            # Mask one word after the last n-gram.
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_index_after_ngram = random.randint(last_ngram_end + 1, len(non_stop_words) - 1)
                mask_indices.append(mask_index_after_ngram)

        # Map the chosen stopword-free indices back to the original word list.
        non_stop_to_original = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                non_stop_idx += 1

        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        if has_punctuation:
            masked_words.append(punctuation)

        print(f' ***** masked_words at end : {masked_words} ***** ')
        print(f' ***** original_mask_indices : {original_mask_indices} ***** ')
        print(f' ***** TESTING : {" ".join(masked_words)} ***** ')

        return " ".join(masked_words), original_mask_indices

    def calculate_word_entropy(self, sentence, word_position):
        """
        Calculate the entropy of the model's prediction for a word position.

        Args:
            sentence (str): The input sentence.
            word_position (int): Index (in sentence.split()) of the word to score.

        Returns:
            float: Entropy of the predicted distribution at that position.
        """
        words = sentence.split()
        masked_words = words.copy()
        masked_words[word_position] = self.tokenizer.mask_token
        masked_sentence = " ".join(masked_words)

        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        # Shannon entropy of the softmax distribution over the vocabulary at
        # the masked position: H = -sum(p * log p).
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))

        return entropy.item()
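    # Intuition: positions where the model's distribution is sharply peaked
    # (few plausible fillers) get low entropy, while positions with many
    # plausible fillers get high entropy; mask_sentence_entropy below masks the
    # highest-entropy candidate in each region outside the common n-grams.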

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask the highest-entropy non-stopword in each region of the sentence
        that is not covered by a common n-gram: before the first n-gram,
        between consecutive n-grams, and after the last one.

        Args:
            sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their indices.

        Returns:
            tuple: (masked sentence, list of masked word indices in the
            original sentence).
        """
        original_words = sentence.split()

        # Detach trailing punctuation; it is re-appended after masking.
        has_punctuation = False
        punctuation = None
        if original_words and any(original_words[-1].endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
            has_punctuation = True
            punctuation = original_words[-1][-1]
            original_words = original_words[:-1]

        non_stop_words = self.remove_stopwords(original_words)
        adjusted_ngrams = self.adjust_ngram_indices(original_words, common_ngrams)

        # Two-way mapping between positions in the original word list and
        # positions in the stopword-free list.
        non_stop_to_original = {}
        original_to_non_stop = {}
        non_stop_idx = 0
        for orig_idx, word in enumerate(original_words):
            if word.lower() not in self.stop_words:
                non_stop_to_original[non_stop_idx] = orig_idx
                original_to_non_stop[orig_idx] = non_stop_idx
                non_stop_idx += 1

        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]
        mask_indices = []

        if ngram_positions:
            # Mask the highest-entropy word before the first n-gram.
            first_ngram_start = ngram_positions[0][0]
            if first_ngram_start > 0:
                candidate_positions = range(0, first_ngram_start)
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Mask the highest-entropy word between each pair of consecutive n-grams.
            for i in range(len(ngram_positions) - 1):
                end_prev = ngram_positions[i][1]
                start_next = ngram_positions[i + 1][0]
                if start_next > end_prev + 1:
                    candidate_positions = range(end_prev + 1, start_next)
                    entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                                 for pos in candidate_positions]
                    mask_indices.append(max(entropies, key=lambda x: x[1])[0])

            # Mask the highest-entropy word after the last n-gram.
            last_ngram_end = ngram_positions[-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                candidate_positions = range(last_ngram_end + 1, len(non_stop_words))
                entropies = [(pos, self.calculate_word_entropy(sentence, non_stop_to_original[pos]))
                             for pos in candidate_positions]
                mask_indices.append(max(entropies, key=lambda x: x[1])[0])

        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]
        masked_words = original_words.copy()
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token

        if has_punctuation:
            masked_words.append(punctuation)

        return " ".join(masked_words), original_mask_indices

    def calculate_mask_logits(self, original_sentence, original_mask_indices):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            original_sentence (str): Original sentence without masks.
            original_mask_indices (list): Word indices to mask, one at a time.

        Returns:
            dict: For each masked word index, the top candidate whole-word
            tokens and their logits.
        """
        print('==========================================================================================================')
        words = original_sentence.split()
        print(f' ##### calculate_mask_logits >> words : {words} ##### ')
        mask_logits = {}

        for idx in original_mask_indices:
            print(f' ---- idx : {idx} ----- ')
            # Mask exactly one word per forward pass.
            masked_words = words.copy()
            masked_words[idx] = self.tokenizer.mask_token

            masked_sentence = " ".join(masked_words)
            print(f' ---- masked_sentence : {masked_sentence} ----- ')

            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Logits over the vocabulary at the masked position.
            mask_logits_tensor = logits[0, mask_token_index, :]

            # Take the top 100 candidates, then keep up to 50 unique whole words
            # (subword pieces starting with '##' are skipped).
            top_mask_logits, top_mask_indices = torch.topk(mask_logits_tensor, 100, dim=-1)

            top_tokens = []
            top_logits = []
            seen_words = set()

            for token_id, logit in zip(top_mask_indices[0], top_mask_logits[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())

                if token.startswith('##'):
                    continue

                word = self.tokenizer.convert_tokens_to_string([token]).strip()

                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())

                if len(top_tokens) == 50:
                    break

            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits
            }

        return mask_logits

    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Mask each sentence with the chosen method ("random", "pseudorandom" or
        "entropy") and calculate logits for the masked positions.
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            # Split trailing punctuation off the last word so it is treated as
            # its own token.
            words = sentence.split()
            last_word = words[-1]
            if any(last_word.endswith(p) for p in ['.', ',', '!', '?', ';', ':']):
                words[-1] = last_word[:-1]
                punctuation = last_word[-1]
                processed_sentence = " ".join(words) + " " + punctuation
            else:
                processed_sentence = sentence

            if method == "random":
                masked_sentence, original_mask_indices = self.mask_sentence_random(processed_sentence, ngrams)
            elif method == "pseudorandom":
                masked_sentence, original_mask_indices = self.mask_sentence_pseudorandom(processed_sentence, ngrams)
            else:
                masked_sentence, original_mask_indices = self.mask_sentence_entropy(processed_sentence, ngrams)

            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results


if __name__ == "__main__":

    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
    ]
    result_dict = {
        'The quick brown fox jumps over small cat the lazy dog everyday again and again .': {'brown fox': [(2, 3)], 'cat': [(7, 7)], 'dog': [(10, 10)]},
    }

    processor = MaskingProcessor()

    results = processor.process_sentences(sentences, result_dict, method="random")

    '''
    results structure:
    results = {
        "The quick brown fox jumps over the lazy dog everyday.":  # original sentence as key
        {
            "masked_sentence": str,   # the sentence with [MASK] tokens
            "mask_logits":
            {                         # masked positions and their predictions
                1:
                {                     # word index of the mask in the sentence
                    "tokens": list,   # top (up to 50) predicted words
                    "logits": list    # corresponding raw logits (not probabilities)
                },
                7:
                {
                    "tokens": list,
                    "logits": list
                },
                10:
                {
                    "tokens": list,
                    "logits": list
                }
            }
        }
    }
    '''

    for sentence, output in results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")