# ai-text-watermarking-model/utils/old/masking/masking_methods_ok_working.py
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
# Ensure stopwords are downloaded
try:
nltk.data.find('corpora/stopwords')
except LookupError:
nltk.download('stopwords')
class MaskingProcessor:
    def __init__(self):
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
self.stop_words = set(stopwords.words('english'))
def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
"""
Adjust indices of common n-grams after removing stop words.
Args:
words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether stop words were removed before computing indices.
Returns:
dict: Adjusted common n-grams and their indices.
"""
if not remove_stopwords:
return common_ngrams
non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
adjusted_ngrams = {}
for ngram, positions in common_ngrams.items():
adjusted_positions = []
for start, end in positions:
try:
new_start = non_stop_word_indices.index(start)
new_end = non_stop_word_indices.index(end)
adjusted_positions.append((new_start, new_end))
except ValueError:
continue # Skip if indices cannot be mapped
adjusted_ngrams[ngram] = adjusted_positions
return adjusted_ngrams
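    # Illustrative example (not part of the original file): for the sentence
    # "The quick brown fox", with "the" treated as a stop word, the non-stop-word
    # indices are [1, 2, 3]; an n-gram span recorded as (1, 2) over the full word
    # list therefore maps to (0, 1) in the stop-word-free sequence.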
def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
"""
Mask one word before the first common n-gram, one between two n-grams,
and one after the last common n-gram (random selection).
Args:
original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to compute mask positions over the stop-word-free sentence
Returns:
str: Masked sentence with original stop words retained
"""
words = original_sentence.split()
non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
mask_indices = []
# Handle before the first common n-gram
if adjusted_ngrams:
first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
if first_ngram_start > 0:
mask_indices.append(random.randint(0, first_ngram_start - 1))
# Handle between common n-grams
ngram_positions = list(adjusted_ngrams.values())
for i in range(len(ngram_positions) - 1):
end_prev = ngram_positions[i][-1][1]
start_next = ngram_positions[i + 1][0][0]
if start_next > end_prev + 1:
mask_indices.append(random.randint(end_prev + 1, start_next - 1))
        # Handle after the last common n-gram
        if adjusted_ngrams:
            last_ngram_end = list(adjusted_ngrams.values())[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                mask_indices.append(random.randint(last_ngram_end + 1, len(non_stop_words) - 1))
        # Mask the chosen indices, mapping each index back to its position in the
        # original sentence when stop words were removed for index computation.
        original_masked_sentence = words[:]
        index_map = ([i for i, word in enumerate(words) if word.lower() not in self.stop_words]
                     if remove_stopwords else list(range(len(words))))
        ngram_covered = {index for ngram_indices in adjusted_ngrams.values()
                         for start, end in ngram_indices
                         for index in range(start, end + 1)}
        for idx in mask_indices:
            if idx not in ngram_covered:
                non_stop_words[idx] = self.tokenizer.mask_token
                original_masked_sentence[index_map[idx]] = self.tokenizer.mask_token
        return " ".join(original_masked_sentence)
def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
"""
Mask one word before the first common n-gram, one between two n-grams,
and one after the last common n-gram (highest entropy selection).
Args:
original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to compute mask positions over the stop-word-free sentence
Returns:
str: Masked sentence with original stop words retained
"""
words = original_sentence.split()
non_stop_words = [word for word in words if word.lower() not in self.stop_words] if remove_stopwords else words
adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)
entropy_scores = {}
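        # For every candidate position (i.e. every word not covered by a common n-gram),
        # mask it, run BERT, and record the entropy H = -sum_v p(v) * log p(v) of the
        # predicted distribution over the vocabulary at the masked slot; higher entropy
        # means the model is less certain about what belongs there.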
for idx, word in enumerate(non_stop_words):
if idx in [index for ngram_indices in adjusted_ngrams.values() for start, end in ngram_indices for index in range(start, end + 1)]:
continue # Skip words in common n-grams
masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
masked_sentence = " ".join(masked_sentence)
input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
with torch.no_grad():
outputs = self.model(input_ids)
logits = outputs.logits
filtered_logits = logits[0, mask_token_index, :]
probs = torch.softmax(filtered_logits, dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item() # Add epsilon to prevent log(0)
entropy_scores[idx] = entropy
mask_indices = []
# Handle before the first common n-gram
if adjusted_ngrams:
first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
if candidates:
mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
# Handle between common n-grams
ngram_positions = list(adjusted_ngrams.values())
for i in range(len(ngram_positions) - 1):
end_prev = ngram_positions[i][-1][1]
start_next = ngram_positions[i + 1][0][0]
candidates = [i for i in range(end_prev + 1, start_next) if i in entropy_scores]
if candidates:
mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
        # Handle after the last common n-gram
        if adjusted_ngrams:
            last_ngram_end = list(adjusted_ngrams.values())[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))
        # Mask the chosen indices, mapping each index back to its position in the
        # original sentence when stop words were removed for index computation.
        original_masked_sentence = words[:]
        index_map = ([i for i, word in enumerate(words) if word.lower() not in self.stop_words]
                     if remove_stopwords else list(range(len(words))))
        for idx in mask_indices:
            non_stop_words[idx] = self.tokenizer.mask_token
            original_masked_sentence[index_map[idx]] = self.tokenizer.mask_token
        return " ".join(original_masked_sentence)
def calculate_mask_logits(self, masked_sentence):
"""
Calculate logits for masked tokens in the sentence using BERT.
Args:
masked_sentence (str): Sentence with [MASK] tokens
Returns:
dict: Masked token indices and their logits
"""
input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
with torch.no_grad():
outputs = self.model(input_ids)
logits = outputs.logits
mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
return mask_logits
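    # Note: the keys of mask_logits are token positions in BERT's tokenization of the
    # masked sentence (which includes the leading [CLS] token and any wordpiece splits),
    # not word indices in the original sentence. An illustrative way to consume them
    # (an assumption, not used elsewhere in this file) would be:
    #   logits_tensor = torch.tensor(mask_logits[position])
    #   top_ids = torch.topk(logits_tensor, k=5).indices.tolist()
    #   top_tokens = self.tokenizer.convert_ids_to_tokens(top_ids)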
def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
"""
Process a list of sentences and calculate logits for masked tokens using the specified method.
Args:
            original_sentences (list): List of original sentences (currently unused; the
                sentences are taken from the keys of result_dict)
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")
            remove_stopwords (bool): Whether to compute mask positions over the stop-word-free sentences
Returns:
dict: Masked sentences and their logits for each sentence
"""
results = {}
for sentence, ngrams in result_dict.items():
if method == "random":
masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
elif method == "entropy":
masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
else:
raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
logits = self.calculate_mask_logits(masked_sentence)
results[sentence] = {
"masked_sentence": masked_sentence,
"mask_logits": logits
}
return results
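# Expected input shape for process_sentences (see the example below): result_dict maps
# each sentence to a dict of {ngram_text: [(start, end), ...]}, where start and end are
# word indices into the whitespace-split sentence.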
# Example usage
if __name__ == "__main__":
    # Note: works in both cases, whether stop words are removed or not
sentences = [
"The quick brown fox jumps over the lazy dog.",
"A quick brown dog outpaces a lazy fox.",
"Quick brown animals leap over lazy obstacles."
]
result_dict = {
"The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
"A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
"Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
}
# result_dict = {
# "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
# "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
# "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
# }
processor = MaskingProcessor()
results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
# results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)
for sentence, output in results_random.items():
print(f"Original Sentence (Random): {sentence}")
print(f"Masked Sentence (Random): {output['masked_sentence']}")
# print(f"Mask Logits (Random): {output['mask_logits']}")
print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
print('--------------------------------')
for mask_idx, logits in output["mask_logits"].items():
print(f"Logits for [MASK] at position {mask_idx}:")
print(f' logits : {logits[:5]}') # List of logits for all vocabulary tokens
# print('--------------------------------')
# for sentence, output in results_entropy.items():
# print(f"Original Sentence (Entropy): {sentence}")
# print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
# # print(f"Mask Logits (Entropy): {output['mask_logits']}")
# print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
# print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
# print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')