|
import torch |
|
import random |
|
import logging |
|
from utils.masking_methods import MaskingProcessor |
|
from tqdm import tqdm |
|
|
|
|
|
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.WARNING,
)

# Module-level logger named after this module; records propagate to the root
# logger configured by basicConfig above.
logger = logging.getLogger(__name__)
|
|
|
class SamplingProcessor:
    """
    Replace ``[MASK]`` words in masked sentences with tokens sampled from
    per-mask candidate lists.

    Supported sampling techniques: ``"temperature"``, ``"greedy"``,
    ``"inverse_transform"`` and ``"exponential_minimum"``.
    """

    # Sampling techniques accepted by ``sample_tokens`` / ``process_masked_sentences``.
    SAMPLING_TECHNIQUES = ("temperature", "greedy", "inverse_transform", "exponential_minimum")

    def __init__(self, tokenizer):
        """
        Initialize the SamplingProcessor.

        Args:
            tokenizer: Tokenizer instance (presumably a BERT tokenizer). It is
                stored on ``self.tokenizer``; the sampling methods of this
                class do not use it.
        """
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[SamplingProcessor] Initialized on device: {self.device}")

    @classmethod
    def _validate_settings(cls, sampling_technique, temperature):
        """Raise ValueError for an unknown technique or a non-positive temperature."""
        if sampling_technique not in cls.SAMPLING_TECHNIQUES:
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")
        # Greedy ignores temperature; every other technique divides by it, and a
        # zero/negative value would yield NaN or a silently inverted distribution.
        if sampling_technique != "greedy" and temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

    def _to_logits_tensor(self, logits):
        """Convert ``logits`` (list, array or tensor) to a flat floating tensor on ``self.device``."""
        tensor = torch.as_tensor(logits, device=self.device)
        if not torch.is_floating_point(tensor):
            tensor = tensor.float()
        # Flatten so every technique indexes candidates the same way, even if
        # the logits arrive with a leading singleton dimension.
        return tensor.flatten()

    @staticmethod
    def _probabilities(logits, temperature):
        """
        Return ``softmax(logits / temperature)`` for a 1-D logits tensor.

        Logits are clamped to [-1e8, 1e8] first so that +inf logits do not turn
        the softmax into NaN.

        Raises:
            ValueError: if the computed probabilities contain NaN or inf.
        """
        clamped = torch.clamp(logits, min=-1e8, max=1e8)
        probs = torch.softmax(clamped / temperature, dim=-1)
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            raise ValueError("The computed probabilities contain NaN or inf values.")
        return probs

    @staticmethod
    def _sample_index(logits, sampling_technique, temperature):
        """
        Draw one index into the 1-D ``logits`` tensor using ``sampling_technique``.

        Returns:
            int: index of the chosen candidate.

        Raises:
            ValueError: for an unknown technique or NaN/inf probabilities.
        """
        if sampling_technique == "greedy":
            return int(torch.argmax(logits).item())

        probs = SamplingProcessor._probabilities(logits, temperature)

        if sampling_technique == "inverse_transform":
            cumulative_probs = torch.cumsum(probs, dim=-1)
            random_prob = random.random()
            # cumulative_probs is non-decreasing, so the number of entries below
            # random_prob is the first index whose cumulative mass reaches it.
            # min() guards against the total rounding to slightly below 1.0.
            sampled_index = int((cumulative_probs < random_prob).sum().item())
            return min(sampled_index, probs.numel() - 1)

        if sampling_technique == "exponential_minimum":
            # Exponential race: E_i = Exp(1) / p_i is exponential with rate p_i,
            # and argmin_i E_i is distributed according to p. Zero-probability
            # candidates get E_i = inf and are never selected.
            exponential_noise = -torch.log(torch.rand_like(probs))
            return int(torch.argmin(exponential_noise / probs).item())

        if sampling_technique == "temperature":
            # Floor every probability at 1e-8 and renormalize so multinomial
            # always receives a strictly positive distribution.
            probs = torch.clamp(probs, min=1e-8)
            probs = probs / probs.sum()
            return int(torch.multinomial(probs, 1).item())

        raise ValueError(f"Unknown sampling technique: {sampling_technique}")

    def sample_tokens(self, mask_logits_dict, masked_sentence, sampling_technique="temperature", temperature=1.0):
        """
        Sample tokens for each mask in the sentence using the specified sampling technique.

        Args:
            mask_logits_dict (dict): Maps a word position (an index into
                ``masked_sentence.split()``) to a dict with ``'logits'`` (one
                score per candidate) and ``'tokens'`` (the candidate strings,
                in the same order).
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): One of ``SAMPLING_TECHNIQUES``.
            temperature (float): Positive temperature; ignored by ``"greedy"``.

        Returns:
            str: Sentence with sampled tokens replacing masks. A position whose
            sampling fails is logged and left unchanged.

        Raises:
            ValueError: if ``sampling_technique`` is unknown or ``temperature``
                is not positive for a stochastic technique.
        """
        self._validate_settings(sampling_technique, temperature)
        tqdm.write(f"[SamplingProcessor] Sampling tokens for: {masked_sentence}")

        words = masked_sentence.split()
        mask_positions = sorted(mask_logits_dict)
        logger.debug("words: %s", words)
        logger.debug("mask_positions: %s", mask_positions)

        for mask_pos in mask_positions:
            mask_data = mask_logits_dict[mask_pos]
            mask_logits = self._to_logits_tensor(mask_data["logits"])
            candidate_tokens = mask_data["tokens"]

            try:
                sampled_index = self._sample_index(mask_logits, sampling_technique, temperature)
                # Drop the WordPiece continuation marker so the token reads as a word.
                sampled_token = candidate_tokens[sampled_index].replace("##", "")
                words[mask_pos] = sampled_token
            except Exception as e:
                # Best effort: keep the [MASK] at this position and carry on.
                logger.error("Error sampling for position %s: %s", mask_pos, e)
            else:
                logger.info("Sampled token '%s' for mask position %s.", sampled_token, mask_pos)

        sampled_sentence = " ".join(words)
        tqdm.write(f"[SamplingProcessor] Sampled sentence: {sampled_sentence}")
        return sampled_sentence

    def process_masked_sentences(self, results_dict, sampling_technique="temperature", temperature=1.0):
        """
        Process all masked sentences in the results dictionary.

        Args:
            results_dict (dict): Maps each original sentence to a dict with
                ``"masked_sentence"`` and ``"mask_logits"`` (the latter in the
                format accepted by ``sample_tokens``).
            sampling_technique (str): One of ``SAMPLING_TECHNIQUES``.
            temperature (float): Positive temperature; ignored by ``"greedy"``.

        Returns:
            dict: Maps each original sentence to
            ``{"masked_sentence": ..., "sampled_sentence": ...}``.
        """
        tqdm.write("[SamplingProcessor] Starting sampling for masked sentences.")
        processed_results = {}

        for original_sentence, data in tqdm(results_dict.items(), desc="Sampling Masked Sentences"):
            masked_sentence = data["masked_sentence"]
            mask_logits = data["mask_logits"]

            sampled_sentence = self.sample_tokens(
                mask_logits,
                masked_sentence,
                sampling_technique,
                temperature,
            )
            processed_results[original_sentence] = {
                "masked_sentence": masked_sentence,
                "sampled_sentence": sampled_sentence,
            }
            logger.info("Processed sampling for sentence: %s", original_sentence)

        tqdm.write("[SamplingProcessor] Completed sampling for all sentences.")
        return processed_results
|
|
|
|
|
if __name__ == "__main__":
    sentences = [
        "The quick brown fox jumps over the lazy dog everyday.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog.",
    ]
    # Per-sentence phrase -> index spans, passed to MaskingProcessor.process_sentences
    # together with `sentences` (exact span semantics are defined by MaskingProcessor).
    result_dict = {
        'The quick brown fox jumps over the lazy dog everyday.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]},
    }

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict)

    sampling_processor = SamplingProcessor(masking_processor.tokenizer)

    # basicConfig sets the root level to WARNING, which would hide every INFO
    # record below; raise this module's logger so the demo actually reports.
    logger.setLevel(logging.INFO)

    sampling_techniques = ["temperature", "greedy", "inverse_transform", "exponential_minimum"]

    for technique in sampling_techniques:
        logger.info("Sampling using technique: %s", technique)

        # process_masked_sentences returns:
        # {
        #     "original sentence": {
        #         "masked_sentence": "sentence with [MASK] tokens",
        #         "sampled_sentence": "sentence with sampled tokens",
        #     },
        #     ...  # one entry per input sentence
        # }
        sampled_results = sampling_processor.process_masked_sentences(
            masking_results,
            sampling_technique=technique,
            temperature=1.0,
        )

        # Report inside the technique loop so every technique's output is shown,
        # not only the last one's.
        for original_sentence, result in sampled_results.items():
            logger.info("Original: %s", original_sentence)
            logger.info("Masked: %s", result["masked_sentence"])
            logger.info("Sampled: %s", result["sampled_sentence"])
            logger.info("---")
|
|
|
|