# ai-text-watermarking-model/utils/old/masking/masking_methods_v1_working.py
# Last change: commit 060ac52 by jgyasu, "Add entire pipeline"
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk
# NOTE: The n-gram coordinates given to MaskingProcessor are word indices into
# the whitespace-split sentence. This version is known to work when those
# coordinates are computed WITHOUT removing stopwords.
def _ensure_stopwords_corpus():
    """Download NLTK's 'stopwords' corpus unless it can already be found.

    MaskingProcessor.__init__ reads stopwords.words('english'), so the
    corpus has to be present before the class is instantiated.
    """
    try:
        nltk.data.find('corpora/stopwords')
    except LookupError:
        # Not installed locally yet: fetch it once.
        nltk.download('stopwords')


# Runs at import time, exactly like the original inline check.
_ensure_stopwords_corpus()
class MaskingProcessor:
    """Mask words around shared n-grams and score the masks with BERT.

    Word positions throughout this class are indices into the whitespace-split
    sentence (after optional stopword removal). The inclusive ``(start, end)``
    spans in ``common_ngrams`` must be computed on that same word list.
    """

    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def _split_words(self, sentence, remove_stopwords):
        """Whitespace-split *sentence*, optionally dropping stopwords.

        Each word is lower-cased before the membership test, so capitalised
        stopwords (e.g. a sentence-initial "The") are removed as well. This
        assumes ``self.stop_words`` holds lowercase entries.
        """
        words = sentence.split()
        if remove_stopwords:
            words = [word for word in words if word.lower() not in self.stop_words]
        return words

    @staticmethod
    def _uncovered_runs(num_words, common_ngrams):
        """Return maximal runs of word indices not covered by any n-gram span.

        The runs are the words before the first span, between neighbouring
        spans, and after the last span. Properties of the result:

        - Spans are considered in positional order, not dict order.
        - Every occurrence of every n-gram counts.
        - Overlapping spans are merged naturally.
        - Spans reaching past ``num_words`` are clipped.
        - When ``common_ngrams`` contains no spans at all, there is nothing
          to anchor masks to, so an empty list is returned.

        Args:
            num_words (int): Length of the word list being masked.
            common_ngrams (dict): n-gram -> list of inclusive (start, end) spans.

        Returns:
            list[list[int]]: Runs of consecutive uncovered indices, left to right.
        """
        covered = {
            index
            for spans in common_ngrams.values()
            for start, end in spans
            for index in range(start, end + 1)
        }
        if not covered:
            return []
        runs = []
        current = []
        for index in range(num_words):
            if index in covered:
                if current:
                    runs.append(current)
                    current = []
            else:
                current.append(index)
        if current:
            runs.append(current)
        return runs

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one between each pair of
        neighbouring n-gram spans, and one after the last (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): Drop stopwords before indexing/masking

        Returns:
            str: Masked sentence (unchanged words joined by single spaces when
            there are no n-gram spans to mask around)
        """
        words = self._split_words(original_sentence, remove_stopwords)
        # Candidates come only from uncovered, in-range positions, so a pick can
        # never land inside an n-gram or past the end of the sentence.
        for run in self._uncovered_runs(len(words), common_ngrams):
            words[random.choice(run)] = self.tokenizer.mask_token
        return " ".join(words)

    def _mask_entropy(self, words, idx):
        """Entropy of BERT's prediction when ``words[idx]`` is replaced by [MASK].

        If the resulting text contains several mask tokens, their entropies are
        summed (the sum runs over every mask position found).
        """
        masked_sentence = " ".join(words[:idx] + [self.tokenizer.mask_token] + words[idx + 1:])
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            logits = self.model(input_ids).logits
        probs = torch.softmax(logits[0, mask_token_index, :], dim=-1)
        # Epsilon keeps log() finite when a probability underflows to 0.
        return -torch.sum(probs * torch.log(probs + 1e-10)).item()

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords=False):
        """
        Mask one word before the first common n-gram, one between each pair of
        neighbouring n-gram spans, and one after the last (highest-entropy
        selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their (start, end) word indices
            remove_stopwords (bool): Drop stopwords before indexing/masking

        Returns:
            str: Masked sentence (unchanged words joined by single spaces when
            there are no n-gram spans to mask around)
        """
        words = self._split_words(original_sentence, remove_stopwords)
        runs = self._uncovered_runs(len(words), common_ngrams)
        # Score all candidates against the unmodified word list first, then mask,
        # so an earlier mask does not change the context for later runs.
        # max() keeps the first index on ties.
        chosen = [max(run, key=lambda idx: self._mask_entropy(words, idx)) for run in runs]
        for idx in chosen:
            words[idx] = self.tokenizer.mask_token
        return " ".join(words)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Position of each mask token in the tokenized input (this
            includes any special tokens the tokenizer adds) -> list of
            vocabulary logits at that position
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            logits = self.model(input_ids).logits
        return {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}

    def process_sentences(self, original_sentences, result_dict, remove_stopwords=False, method="random"):
        """
        Mask every sentence in ``result_dict`` and compute BERT logits for the masks.

        Args:
            original_sentences (list): Accepted for API compatibility; the
                sentences actually processed are the keys of ``result_dict``.
            result_dict (dict): sentence -> common n-grams and their indices
            remove_stopwords (bool): Forwarded to the masking method; the
                n-gram indices must match the resulting word list.
            method (str): Masking method ("random" or "entropy")

        Returns:
            dict: sentence -> {"masked_sentence": str, "mask_logits": dict}

        Raises:
            ValueError: If ``method`` is not "random" or "entropy".
        """
        maskers = {
            "random": self.mask_sentence_random,
            "entropy": self.mask_sentence_entropy,
        }
        if method not in maskers:
            raise ValueError("Invalid method. Choose 'random' or 'entropy'.")
        mask_sentence = maskers[method]

        results = {}
        for sentence, ngrams in result_dict.items():
            masked_sentence = mask_sentence(sentence, ngrams, remove_stopwords=remove_stopwords)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": self.calculate_mask_logits(masked_sentence),
            }
        return results
# Example usage
if __name__ == "__main__":
    # The (start, end) word indices below are computed on the full
    # whitespace-split sentences, i.e. WITHOUT stopword removal, so the
    # processor is run with remove_stopwords=False to keep them aligned.
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown animals leap over lazy obstacles."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(1, 2)], "lazy": [(5, 5)]}
    } if False else {
        "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
        "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
        "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
    }
    # Coordinates for stopword-removed sentences (these assume "The"/"A" are
    # dropped along with "over"/"the"/"a"); use them with remove_stopwords=True:
    # result_dict = {
    #     "The quick brown fox jumps over the lazy dog.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "A quick brown dog outpaces a lazy fox.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]},
    #     "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(4, 4)]}
    # }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="random")
    results_entropy = processor.process_sentences(sentences, result_dict, remove_stopwords=False, method="entropy")

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # print(f"Mask Logits (Random): {output['mask_logits']}")

    for sentence, output in results_entropy.items():
        print(f"Original Sentence (Entropy): {sentence}")
        print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
        # print(f"Mask Logits (Entropy): {output['mask_logits']}")