# Source: ai-text-watermarking-model / utils/old/masking_methods_final_copy.py
# Commit 060ac52 by jgyasu: "Add entire pipeline"
import logging
import random

import nltk
import torch
from nltk.corpus import stopwords
from transformers import BertTokenizer, BertForMaskedLM
from transformers import RobertaTokenizer, RobertaForMaskedLM
def _ensure_nltk_data(resource_path, package):
    """Download the NLTK data ``package`` unless ``resource_path`` is already installed."""
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package)


# The English stopword list is read when MaskingProcessor is constructed.
_ensure_nltk_data('corpora/stopwords', 'stopwords')
class MaskingProcessor:
    """Mask words of a sentence around its common n-grams and score the masks with BERT.

    Mask positions are chosen among non-stopword words only: at most one word
    before the first common n-gram, one in every gap between n-grams, and one
    after the last n-gram. Chosen positions are reported as indices into the
    whitespace-split sentence.
    """

    # Trailing characters treated as sentence punctuation; they are split off
    # before masking and re-appended afterwards, so they are never masked.
    _PUNCTUATION = ('.', ',', '!', '?', ';', ':')

    # Debug output goes through logging instead of stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self, model_name="bert-large-cased-whole-word-masking"):
        """Load the tokenizer, the masked-LM model and the English stopword set.

        Args:
            model_name (str): Hugging Face checkpoint loaded with
                ``BertTokenizer`` / ``BertForMaskedLM``. Defaults to the
                checkpoint this class has always used.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.stop_words = set(stopwords.words('english'))

    def remove_stopwords(self, words):
        """
        Remove stopwords from the given list of words.

        Args:
            words (list): List of words.

        Returns:
            list: Words whose lowercase form is not in ``self.stop_words``,
            in their original order.
        """
        return [word for word in words if word.lower() not in self.stop_words]

    def _non_stop_positions(self, words):
        """Return the index in ``words`` of every non-stopword, in order.

        Entry ``i`` of the result is the original index of the i-th
        non-stopword, i.e. the list maps non-stopword positions back to
        positions in ``words``.
        """
        return [idx for idx, word in enumerate(words) if word.lower() not in self.stop_words]

    def adjust_ngram_indices(self, original_words, common_ngrams):
        """
        Re-express n-gram word spans in terms of non-stopword positions.

        Args:
            original_words (list): Original list of words.
            common_ngrams (dict): Maps each n-gram to a list of inclusive
                ``(start, end)`` indices into ``original_words``.

        Returns:
            dict: Same keys, with each span converted to non-stopword indices.
            Spans whose start or end word is a stopword (or out of range) are
            dropped, because they have no non-stopword position.
        """
        original_to_non_stop = {
            orig_idx: non_stop_idx
            for non_stop_idx, orig_idx in enumerate(self._non_stop_positions(original_words))
        }
        adjusted_ngrams = {}
        for ngram, positions in common_ngrams.items():
            adjusted_ngrams[ngram] = [
                (original_to_non_stop[start], original_to_non_stop[end])
                for start, end in positions
                if start in original_to_non_stop and end in original_to_non_stop
            ]
        return adjusted_ngrams

    def _split_trailing_punctuation(self, sentence):
        """Split ``sentence`` on whitespace and detach its final punctuation mark.

        Returns:
            tuple: ``(words, punctuation)``. ``punctuation`` is the last
            character of the last word when that word ends with one of
            ``_PUNCTUATION``, otherwise ``None``. A word the mark was glued to
            (``"dog."``) is kept as ``"dog"`` rather than dropped.
        """
        words = sentence.split()
        if not words or not words[-1].endswith(self._PUNCTUATION):
            return words, None
        last_word = words.pop()
        if len(last_word) > 1:
            # Punctuation attached to a real word: keep the word itself.
            words.append(last_word[:-1])
        return words, last_word[-1]

    @staticmethod
    def _select_mask_indices(ngram_positions, num_non_stop, pick):
        """Choose one mask position in every free gap around the n-gram spans.

        Args:
            ngram_positions (list): Inclusive ``(start, end)`` spans in
                non-stopword positions, in any order.
            num_non_stop (int): Number of non-stopword words in the sentence.
            pick (callable): Receives a non-empty ``range`` of candidate
                positions and returns one of them.

        Returns:
            list: Chosen non-stopword positions, in sentence order. Empty when
            there are no n-gram spans.
        """
        if not ngram_positions:
            return []
        spans = sorted(ngram_positions)
        gaps = [range(0, spans[0][0])]
        # Track the furthest covered word so overlapping spans never leave a
        # "gap" inside an n-gram.
        covered_until = spans[0][1]
        for start, end in spans[1:]:
            gaps.append(range(covered_until + 1, start))
            covered_until = max(covered_until, end)
        gaps.append(range(covered_until + 1, num_non_stop))
        # Gaps are visited in sentence order, so ``pick`` is called in the
        # same order as before: before, between (left to right), after.
        return [pick(gap) for gap in gaps if gap]

    def _mask_sentence(self, sentence, common_ngrams, pick):
        """Shared implementation of the ``mask_sentence_*`` strategies.

        Args:
            sentence (str): Sentence to mask.
            common_ngrams (dict): Common n-grams and their word spans.
            pick (callable): ``pick(candidates, context_sentence, non_stop_to_original)``
                returns one non-stopword position from ``candidates``;
                ``context_sentence`` is the sentence with its punctuation
                split off as a separate token, and ``non_stop_to_original``
                maps non-stopword positions to word indices in it.

        Returns:
            tuple: ``(masked_sentence, original_mask_indices)``.
        """
        words, punctuation = self._split_trailing_punctuation(sentence)
        if punctuation is not None:
            context_words = words + [punctuation]
        else:
            context_words = words
        context_sentence = " ".join(context_words)

        non_stop_to_original = self._non_stop_positions(words)
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams)
        ngram_positions = [pos for positions in adjusted_ngrams.values() for pos in positions]

        mask_indices = self._select_mask_indices(
            ngram_positions,
            len(non_stop_to_original),
            lambda candidates: pick(candidates, context_sentence, non_stop_to_original),
        )
        original_mask_indices = [non_stop_to_original[idx] for idx in mask_indices]

        masked_words = list(words)
        for idx in original_mask_indices:
            masked_words[idx] = self.tokenizer.mask_token
        if punctuation is not None:
            masked_words.append(punctuation)
        masked_sentence = " ".join(masked_words)

        self._logger.debug(
            "masked sentence: %s (mask indices: %s)", masked_sentence, original_mask_indices
        )
        return masked_sentence, original_mask_indices

    def mask_sentence_random(self, sentence, common_ngrams):
        """
        Mask one uniformly random non-stopword in every gap around the common n-grams.

        Uses the module-level ``random`` generator, so results differ between
        calls unless the caller seeds it.

        Args:
            sentence (str): Sentence to mask.
            common_ngrams (dict): Common n-grams and their word spans.

        Returns:
            tuple: ``(masked_sentence, original_mask_indices)`` where the
            indices refer to words of the whitespace-split sentence.
        """
        return self._mask_sentence(
            sentence,
            common_ngrams,
            lambda candidates, *_: random.randint(candidates[0], candidates[-1]),
        )

    def mask_sentence_pseudorandom(self, sentence, common_ngrams, seed=3):
        """
        Like ``mask_sentence_random`` but reproducible: draws from a private
        ``random.Random(seed)``.

        The private generator produces the same choices as re-seeding the
        global RNG would, without altering global ``random`` state.

        Args:
            sentence (str): Sentence to mask.
            common_ngrams (dict): Common n-grams and their word spans.
            seed (int): Seed for the private generator (default 3).

        Returns:
            tuple: ``(masked_sentence, original_mask_indices)``.
        """
        rng = random.Random(seed)
        return self._mask_sentence(
            sentence,
            common_ngrams,
            lambda candidates, *_: rng.randint(candidates[0], candidates[-1]),
        )

    def calculate_word_entropy(self, sentence, word_position):
        """
        Entropy of the model's prediction when one word of ``sentence`` is masked.

        Args:
            sentence (str): The input sentence.
            word_position (int): Index (in ``sentence.split()``) of the word to mask.

        Returns:
            float: ``-sum(p * log p)`` over the vocabulary distribution at the mask.
        """
        words = sentence.split()
        words[word_position] = self.tokenizer.mask_token
        input_ids = self.tokenizer(" ".join(words), return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
        with torch.no_grad():
            logits = self.model(input_ids).logits
        probs = torch.nn.functional.softmax(logits[0, mask_token_index], dim=-1)
        # The epsilon keeps log() finite for zero-probability tokens.
        entropy = -torch.sum(probs * torch.log(probs + 1e-9))
        return entropy.item()

    def mask_sentence_entropy(self, sentence, common_ngrams):
        """
        Mask, in every gap around the common n-grams, the word with the highest
        prediction entropy.

        Args:
            sentence (str): Original sentence.
            common_ngrams (dict): Common n-grams and their word spans.

        Returns:
            tuple: ``(masked_sentence, original_mask_indices)``.
        """
        def pick_highest_entropy(candidates, context_sentence, non_stop_to_original):
            # max() keeps the first candidate on ties.
            return max(
                candidates,
                key=lambda pos: self.calculate_word_entropy(
                    context_sentence, non_stop_to_original[pos]
                ),
            )

        return self._mask_sentence(sentence, common_ngrams, pick_highest_entropy)

    def calculate_mask_logits(self, original_sentence, original_mask_indices,
                              num_candidates=100, num_words=50):
        """
        Score replacement words for each masked position, one position at a time.

        Args:
            original_sentence (str): Sentence without masks.
            original_mask_indices (list): Word indices (in
                ``original_sentence.split()``) to mask.
            num_candidates (int): How many top-logit tokens to inspect.
            num_words (int): Maximum number of distinct whole words to keep.

        Returns:
            dict: ``{index: {"tokens": [word, ...], "logits": [float, ...]}}``.
            WordPiece continuation tokens (``##...``) and duplicates are skipped.
        """
        words = original_sentence.split()
        mask_logits = {}
        for idx in original_mask_indices:
            masked_words = list(words)
            masked_words[idx] = self.tokenizer.mask_token
            masked_sentence = " ".join(masked_words)
            self._logger.debug("scoring mask at %d: %s", idx, masked_sentence)

            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            with torch.no_grad():
                logits = self.model(input_ids).logits
            mask_logits_tensor = logits[0, mask_token_index, :]

            k = min(num_candidates, mask_logits_tensor.size(-1))
            top_values, top_ids = torch.topk(mask_logits_tensor, k, dim=-1)

            top_tokens, top_logits = [], []
            seen_words = set()
            for token_id, logit in zip(top_ids[0], top_values[0]):
                token = self.tokenizer.convert_ids_to_tokens(token_id.item())
                if token.startswith('##'):
                    continue  # subword continuation, not a whole word
                word = self.tokenizer.convert_tokens_to_string([token]).strip()
                if word and word not in seen_words:
                    seen_words.add(word)
                    top_tokens.append(word)
                    top_logits.append(logit.item())
                    if len(top_tokens) == num_words:
                        break

            mask_logits[idx] = {
                "tokens": top_tokens,
                "logits": top_logits,
            }
        return mask_logits

    def process_sentences(self, sentences, result_dict, method="random"):
        """
        Mask every sentence in ``result_dict`` and score its masked positions.

        Args:
            sentences (list): Unused; kept for backward compatibility. The
                sentences processed are the keys of ``result_dict``.
            result_dict (dict): Maps each sentence to its common n-grams.
            method (str): ``"random"``, ``"pseudorandom"``; any other value
                selects entropy-based masking.

        Returns:
            dict: ``{sentence: {"masked_sentence": str, "mask_logits": dict}}``.
        """
        if method == "random":
            mask_fn = self.mask_sentence_random
        elif method == "pseudorandom":
            mask_fn = self.mask_sentence_pseudorandom
        else:  # entropy
            mask_fn = self.mask_sentence_entropy

        results = {}
        for sentence, ngrams in result_dict.items():
            # Detach trailing punctuation into its own token so indices are
            # consistent between masking and logit scoring.
            words, punctuation = self._split_trailing_punctuation(sentence)
            if punctuation is not None:
                words.append(punctuation)
            processed_sentence = " ".join(words)

            masked_sentence, original_mask_indices = mask_fn(processed_sentence, ngrams)
            logits = self.calculate_mask_logits(processed_sentence, original_mask_indices)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits,
            }
        return results
if __name__ == "__main__":
    # Demo: mask one sentence with the "random" strategy.
    # Author's note: this works whether or not stopwords are removed.
    sentences = [
        "The quick brown fox jumps over small cat the lazy dog everyday again and again .",
    ]
    # Common n-grams of each sentence with their inclusive (start, end) word indices.
    result_dict = {
        sentences[0]: {'brown fox': [(2, 3)], 'cat': [(7, 7)], 'dog': [(10, 10)]},
    }

    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random")

    # Layout of the returned dict:
    # {
    #     original_sentence: {
    #         "masked_sentence": str,        # sentence with [MASK] tokens
    #         "mask_logits": {
    #             mask_position: {           # word index of the mask in the sentence
    #                 "tokens": list,        # top predicted whole words (up to 50)
    #                 "logits": list,        # model logits for those words
    #             },
    #             ...
    #         },
    #     },
    # }
    for original_sentence, output in results_random.items():
        print(f"Original Sentence (Random): {original_sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")