import random

import torch
from transformers import BertTokenizer, BertForMaskedLM

from masking_methods import MaskingProcessor


class SamplingProcessorWithModel:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.model.eval()  # Set the model to evaluation mode

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fill each [MASK] in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence containing [MASK] tokens.
            sampling_technique (str): One of "inverse_transform", "exponential_minimum",
                "temperature", or "greedy".
            temperature (float): Temperature parameter for the probabilistic techniques.

        Returns:
            str: Sentence with all masks filled.
        """
        input_ids = self.tokenizer.encode(masked_sentence, return_tensors="pt")

        while self.tokenizer.mask_token_id in input_ids[0]:
            # Find indices of all remaining [MASK] tokens
            mask_indices = torch.where(input_ids == self.tokenizer.mask_token_id)[1]
            # Process the first [MASK] token in the sequence
            mask_index = mask_indices[0].item()

            # Get logits from the model
            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            # Extract logits for the [MASK] position
            mask_logits = logits[0, mask_index]
if sampling_technique == "inverse_transform":
probs = torch.softmax(mask_logits / temperature, dim=-1)
cumulative_probs = torch.cumsum(probs, dim=-1)
random_prob = random.random()
sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
elif sampling_technique == "exponential_minimum":
probs = torch.softmax(mask_logits / temperature, dim=-1)
exp_probs = torch.exp(-torch.log(probs))
random_probs = torch.rand_like(exp_probs)
sampled_index = torch.argmax(random_probs * exp_probs).item()
elif sampling_technique == "temperature":
mask_logits = torch.clamp(mask_logits, min=-1e8, max=1e8)
probs = torch.softmax(mask_logits / temperature, dim=-1)
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
raise ValueError("The computed probabilities contain NaN or inf values.")
probs = torch.max(probs, torch.tensor(1e-8, device=mask_logits.device))
probs = probs / torch.sum(probs)
probs = probs.flatten()
if probs.size(0) > 1:
sampled_index = torch.multinomial(probs, 1).item()
else:
sampled_index = torch.argmax(probs).item()
            elif sampling_technique == 'greedy':
                sampled_index = torch.argmax(mask_logits).item()
            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

            # Replace the first [MASK] with the selected token
            input_ids[0, mask_index] = sampled_index

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
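

# Minimal standalone sketch (the sentence is illustrative; any string containing
# [MASK] works). With bert-base-uncased and greedy decoding this typically fills
# "paris"; note the uncased tokenizer decodes everything to lowercase.
#
#     processor = SamplingProcessorWithModel()
#     print(processor.fill_masked_sentence(
#         "The capital of France is [MASK].", sampling_technique="greedy"
#     ))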


# Example usage
if __name__ == "__main__":
    # Define sentences and their keyword-span annotations
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
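    # masking_results is assumed (per MaskingProcessor.process_sentences) to map
    # each original sentence to a dict carrying at least a "masked_sentence" key,
    # e.g. {"The quick brown fox ...": {"masked_sentence": "The [MASK] [MASK] fox ..."}}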
    # Use SamplingProcessor
    sampling_processor = SamplingProcessorWithModel()

    # Iterate through masking results and apply each sampling technique
    for sentence, result in masking_results.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {result['masked_sentence']}")
        masked_sentence = result["masked_sentence"]

        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
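    # A quick extra check: repeated draws at a higher temperature should vary
    # more across runs than greedy decoding does. The sentence is illustrative only.
    demo_sentence = "The quick brown [MASK] jumps over the lazy dog."
    for _ in range(3):
        print(sampling_processor.fill_masked_sentence(
            demo_sentence, sampling_technique="temperature", temperature=1.5
        ))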

# ---------------------------------------------------------------------------
# Alternative implementation using the transformers fill-mask pipeline,
# kept commented out for reference.
# ---------------------------------------------------------------------------
# from transformers import pipeline
# import torch
# import random
# from masking_methods import MaskingProcessor
#
#
# class SamplingProcessorWithPipeline:
#     def __init__(self, model_name='bert-base-uncased'):
#         self.unmasker = pipeline('fill-mask', model=model_name)
#         self.tokenizer = self.unmasker.tokenizer
#
#     def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
#         """
#         Fill each [MASK] in the masked sentence using the specified sampling technique.
#
#         Args:
#             masked_sentence (str): Sentence containing [MASK] tokens.
#             sampling_technique (str): One of "inverse_transform", "exponential_minimum",
#                 "temperature", or "greedy".
#             temperature (float): Temperature parameter for the probabilistic techniques.
#
#         Returns:
#             str: Sentence with all masks filled.
#         """
#         while '[MASK]' in masked_sentence:
#             # Get predictions for the first [MASK]
#             predictions = self.unmasker(masked_sentence)
#             print(f' predictions : {predictions}')
#             print(f' type of predictions : {type(predictions)}')
#             # Ensure predictions is a list of dictionaries for the first [MASK]
#             if not isinstance(predictions, list) or not all(isinstance(pred, dict) for pred in predictions):
#                 raise ValueError("Unexpected structure in predictions from the pipeline.")
#             # Extract logits (scores) from the predictions
#             logits = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)
#             if sampling_technique == "inverse_transform":
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 cumulative_probs = torch.cumsum(probs, dim=-1)
#                 random_prob = random.random()
#                 sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()
#             elif sampling_technique == "exponential_minimum":
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 exp_probs = torch.exp(-torch.log(probs))
#                 random_probs = torch.rand_like(exp_probs)
#                 sampled_index = torch.argmax(random_probs * exp_probs).item()
#             elif sampling_technique == "temperature":
#                 logits = torch.clamp(logits, min=-1e8, max=1e8)
#                 probs = torch.softmax(logits / temperature, dim=-1)
#                 if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
#                     raise ValueError("The computed probabilities contain NaN or inf values.")
#                 probs = torch.max(probs, torch.tensor(1e-8, device=logits.device))
#                 probs = probs / torch.sum(probs)
#                 probs = probs.flatten()
#                 if probs.size(0) > 1:
#                     sampled_index = torch.multinomial(probs, 1).item()
#                 else:
#                     sampled_index = torch.argmax(probs).item()
#             elif sampling_technique == 'greedy':
#                 sampled_index = torch.argmax(logits).item()
#             else:
#                 raise ValueError(f"Unknown sampling technique: {sampling_technique}")
#             # Replace the first [MASK] with the selected word
#             sampled_token = predictions[sampled_index]['token_str']
#             masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)
#         return masked_sentence
#
#
# # Example usage
# if __name__ == "__main__":
#     # Define sentences and their keyword-span annotations
#     sentences = [
#         "The quick brown fox jumps over the lazy dog.",
#         "A quick brown dog outpaces a lazy fox.",
#         "Quick brown animals leap over lazy obstacles."
#     ]
#     result_dict = {
#         "The quick brown fox jumps over the lazy dog.": {"quick brown": [(1, 2)], "lazy": [(7, 7)]},
#         "A quick brown dog outpaces a lazy fox.": {"quick brown": [(1, 2)], "lazy": [(6, 6)]},
#         "Quick brown animals leap over lazy obstacles.": {"quick brown": [(0, 1)], "lazy": [(5, 5)]}
#     }
#
#     masking_processor = MaskingProcessor()
#     masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
#
#     # Use SamplingProcessor
#     sampling_processor = SamplingProcessorWithPipeline()
#
#     # Iterate through masking results to apply sampling
#     for sentence, result in masking_results.items():
#         print(f"Original Sentence (Random): {sentence}")
#         print(f"Masked Sentence (Random): {result['masked_sentence']}")
#         masked_sentence = result["masked_sentence"]
#         for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
#             print(f"Sampling Technique: {technique}")
#             filled_sentence = sampling_processor.fill_masked_sentence(
#                 masked_sentence=masked_sentence,
#                 sampling_technique=technique,
#                 temperature=1.0  # Adjust temperature as needed
#             )
#             print(f"Filled Sentence: {filled_sentence}\n")
#         print('--------------------------------')