File size: 7,026 Bytes
060ac52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import torch
import random
from masking_methods import MaskingProcessor
class SamplingProcessor:
    """Fill ``[MASK]`` tokens in text by sampling from per-mask logits.

    Only four tokenizer members are used: ``tokenize``, ``mask_token``,
    ``convert_ids_to_tokens`` and ``convert_tokens_to_string``.
    """

    # Techniques that divide the logits by ``temperature`` before sampling.
    _TEMPERATURE_TECHNIQUES = ("inverse_transform", "exponential_minimum", "temperature")

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @staticmethod
    def _to_logits_tensor(raw_logits):
        """Return *raw_logits* as a flat floating-point tensor."""
        # as_tensor does not copy (or warn about copy-constructing) tensor inputs.
        logits = torch.as_tensor(raw_logits)
        if not torch.is_floating_point(logits):
            logits = logits.float()  # softmax is undefined for integer dtypes
        # Flatten so a (1, vocab) row and a (vocab,) vector give the same indices.
        return logits.flatten()

    @classmethod
    def _sample_index(cls, logits, sampling_technique, temperature):
        """Draw one position from a 1-D logits tensor with the given technique.

        Returns:
            int: The sampled position in ``logits``.

        Raises:
            ValueError: For an unknown technique, a non-positive temperature
                with a temperature-based technique, or NaN/inf probabilities.
        """
        if sampling_technique == "greedy":
            return torch.argmax(logits).item()
        if sampling_technique not in cls._TEMPERATURE_TECHNIQUES:
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        if sampling_technique == "temperature":
            # Keep extreme logits finite before scaling.
            logits = torch.clamp(logits, min=-1e8, max=1e8)
        probs = torch.softmax(logits / temperature, dim=-1)
        if not torch.all(torch.isfinite(probs)):
            raise ValueError("The computed probabilities contain NaN or inf values.")

        if sampling_technique == "inverse_transform":
            cumulative_probs = torch.cumsum(probs, dim=-1)
            random_prob = random.random()
            hits = torch.where(cumulative_probs >= random_prob)[0]
            # Rounding can leave cumulative_probs[-1] slightly below 1.0; a draw
            # above it then has no hit, so fall back to the last position.
            return hits[0].item() if hits.numel() > 0 else probs.numel() - 1

        if sampling_technique == "exponential_minimum":
            # With E_i ~ Exp(1) i.i.d., argmin_i E_i / p_i is distributed as p.
            exp_noise = torch.empty_like(probs).exponential_()
            # Zero-probability positions must never win (and 0/0 would be NaN).
            scores = (exp_noise / probs).masked_fill(probs == 0, float("inf"))
            return torch.argmin(scores).item()

        # "temperature": multinomial draw from the tempered softmax, with every
        # probability floored at 1e-8 and then renormalized.
        probs = torch.clamp(probs, min=1e-8)
        probs = probs / torch.sum(probs)
        if probs.numel() > 1:
            return torch.multinomial(probs, 1).item()
        return torch.argmax(probs).item()

    def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
        """Fill each mask in the masked sentence using the specified sampling technique.

        Args:
            original_sentence (str): The masked sentence.
            mask_logits (dict): One logits vector per mask token. Its values are
                paired with the masks in iteration order (left to right in the
                sentence); keys are not consulted. NOTE(review): the sampled
                position is passed straight to ``convert_ids_to_tokens``, so each
                vector is assumed to be indexed by vocabulary id — confirm with
                the producer of ``mask_logits``.
            sampling_technique (str): One of "inverse_transform",
                "exponential_minimum", "temperature" or "greedy".
            temperature (float): Softmax temperature (must be > 0 for every
                technique except "greedy", which ignores it).

        Returns:
            str: Sentence with the masks filled.

        Raises:
            ValueError: If the number of masks and logits differ, the technique
                is unknown, the temperature is not positive, or the
                probabilities are not finite.
        """
        sentence_tokens = self.tokenizer.tokenize(original_sentence)
        mask_token = self.tokenizer.mask_token
        mask_token_indices = [i for i, token in enumerate(sentence_tokens) if token == mask_token]
        if len(mask_token_indices) != len(mask_logits):
            raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")

        for mask_idx, raw_logits in zip(mask_token_indices, mask_logits.values()):
            logits = self._to_logits_tensor(raw_logits)
            sampled_index = self._sample_index(logits, sampling_technique, temperature)
            sentence_tokens[mask_idx] = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
        return self.tokenizer.convert_tokens_to_string(sentence_tokens)

    def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
        """Fill the masks of several sentences with the given sampling technique.

        Args:
            masked_sentences (list): Masked sentences to fill.
            mask_logits (list): One ``fill_masked_sentence``-style logits dict per
                sentence, in the same order as ``masked_sentences``. It is
                iterated positionally (iterating a dict here would yield its keys).
            sampling_technique (str): Sampling technique to use.
            temperature (float): Temperature parameter for sampling methods.

        Returns:
            list: Sentences with masks filled, in input order.

        Raises:
            ValueError: If the two inputs differ in length, or as raised by
                ``fill_masked_sentence``.
        """
        sentences = list(masked_sentences)
        per_sentence_logits = list(mask_logits)
        if len(sentences) != len(per_sentence_logits):
            raise ValueError("Mismatch between number of sentences and mask logits provided.")
        return [
            self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
            for sentence, logits in zip(sentences, per_sentence_logits)
        ]
# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # "bert-large-cased-whole-word-masking" is an alternative checkpoint.
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    sampler = SamplingProcessor(bert_tokenizer)

    demo_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox.",
    ]
    # Per sentence: phrase -> list of index spans, as passed to process_sentences.
    phrase_spans = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
    }

    masker = MaskingProcessor()
    masked_outputs = masker.process_sentences(
        demo_sentences, phrase_spans, method="random", remove_stopwords=False
    )

    techniques = ("inverse_transform", "exponential_minimum", "temperature", "greedy")

    # Fill every masked sentence once per sampling technique.
    for source_text, outcome in masked_outputs.items():
        masked_text = outcome["masked_sentence"]
        logits_by_mask = outcome["mask_logits"]

        print(f"Original Sentence (Random): {source_text}")
        print(f"Masked Sentence (Random): {masked_text}")
        # Debug output describing the container holding the per-mask logits.
        print(f' type(result["mask_logits"]) : {type(logits_by_mask)}')
        print(f' length of result["mask_logits"] : {len(logits_by_mask)}')
        print(f' result["mask_logits"].keys() : {logits_by_mask.keys()}')
        print(f"Original Masked Sentence: {masked_text}")

        for technique_name in techniques:
            print(f"Sampling Technique: {technique_name}")
            completed = sampler.fill_masked_sentence(
                original_sentence=masked_text,
                mask_logits=logits_by_mask,
                sampling_technique=technique_name,
                temperature=1.0,  # raise to flatten, lower to sharpen the distribution
            )
            print(f"Filled Sentence: {completed}\n")
        print('--------------------------------')