File size: 5,228 Bytes
060ac52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from transformers import pipeline
import torch
import random
from masking_methods import MaskingProcessor


class SamplingProcessorWithPipeline:
    """Fill ``[MASK]`` placeholders with tokens sampled from a fill-mask pipeline.

    Masks are filled one at a time, left to right. For each one the pipeline is
    queried, a candidate for the first remaining mask is chosen with the
    requested sampling technique, and that candidate's ``token_str`` replaces
    the mask.
    """

    # Techniques that draw from a temperature-scaled distribution (as opposed
    # to "greedy", which ignores the temperature).
    _STOCHASTIC_TECHNIQUES = ("inverse_transform", "exponential_minimum", "temperature")

    def __init__(self, model_name='bert-base-uncased'):
        """Create the fill-mask pipeline for ``model_name`` and keep its tokenizer.

        Args:
            model_name (str): Model identifier forwarded to ``transformers.pipeline``.
        """
        self.unmasker = pipeline('fill-mask', model=model_name)
        self.tokenizer = self.unmasker.tokenizer

    def fill_masked_sentence(self, masked_sentence, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens.
            sampling_technique (str): One of "inverse_transform",
                "exponential_minimum", "temperature" or "greedy".
            temperature (float): Temperature for the stochastic techniques; must
                be > 0 for them. Ignored by "greedy".

        Returns:
            str: Sentence with every [MASK] replaced by a sampled token.

        Raises:
            ValueError: For an unknown technique, a non-positive temperature
                with a stochastic technique, an unexpected pipeline output, or
                NaN/inf probabilities.
        """
        while '[MASK]' in masked_sentence:
            predictions = self._first_mask_predictions(masked_sentence)
            sampled_index = self._sample_index(predictions, sampling_technique, temperature)

            # Replace only the first [MASK]; the next iteration re-queries the
            # pipeline with this token already in place.
            sampled_token = predictions[sampled_index]['token_str']
            masked_sentence = masked_sentence.replace('[MASK]', sampled_token, 1)

        return masked_sentence

    def _first_mask_predictions(self, masked_sentence):
        """Return the pipeline's candidate list (dicts) for the first mask."""
        predictions = self.unmasker(masked_sentence)

        # With several masks the fill-mask pipeline returns one candidate list
        # per mask instead of a flat list; only the first mask is filled per
        # iteration, so keep just its candidates.
        if isinstance(predictions, list) and predictions and isinstance(predictions[0], list):
            predictions = predictions[0]

        if (
            not isinstance(predictions, list)
            or not predictions
            or not all(isinstance(pred, dict) for pred in predictions)
        ):
            raise ValueError("Unexpected structure in predictions from the pipeline.")
        return predictions

    def _sample_index(self, predictions, sampling_technique, temperature):
        """Pick an index into ``predictions`` according to ``sampling_technique``."""
        scores = torch.tensor([pred['score'] for pred in predictions], dtype=torch.float32)

        if sampling_technique == 'greedy':
            return torch.argmax(scores).item()
        if sampling_technique not in self._STOCHASTIC_TECHNIQUES:
            raise ValueError(f"Unknown sampling technique: {sampling_technique}")
        if temperature <= 0:
            raise ValueError(
                f"temperature must be > 0 for '{sampling_technique}' sampling, got {temperature}"
            )

        # The pipeline scores are probabilities, not logits: take the log so
        # that dividing by the temperature yields p ** (1 / T) (renormalised)
        # instead of flattening the distribution. The clamp avoids log(0).
        logits = torch.log(scores.clamp_min(1e-12))
        probs = torch.softmax(logits / temperature, dim=-1)
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            raise ValueError("The computed probabilities contain NaN or inf values.")

        if sampling_technique == "inverse_transform":
            cumulative_probs = torch.cumsum(probs, dim=-1)
            random_prob = random.random()
            # First index whose cumulative probability is >= random_prob.
            # Clamped because rounding can leave the final cumsum just below 1.
            first_at_or_above = int((cumulative_probs < random_prob).sum().item())
            return min(first_at_or_above, probs.numel() - 1)

        if sampling_technique == "exponential_minimum":
            # Exponential race: with E_i ~ Exp(1), argmin_i E_i / p_i selects
            # index i with probability p_i.
            exp_samples = -torch.log(torch.rand_like(probs))
            return torch.argmin(exp_samples / probs).item()

        # "temperature": plain categorical draw from the scaled distribution.
        return torch.multinomial(probs, 1).item()


# Example usage
if __name__ == "__main__":
    # Demo: mask each sentence with MaskingProcessor, then refill the masks with
    # every sampling technique and print the results.
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    # Per-sentence mapping handed to MaskingProcessor.process_sentences.
    # NOTE(review): values look like lists of (start, end) word-index spans —
    # confirm against masking_methods.
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]},
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]},
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }

    # Single source of truth for the masking method and its printed label.
    masking_method = "random"
    masking_label = masking_method.capitalize()

    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(
        sentences, result_dict, method=masking_method, remove_stopwords=False
    )

    sampling_processor = SamplingProcessorWithPipeline()

    # Fill every masked sentence once per sampling technique.
    for sentence, result in masking_results.items():
        masked_sentence = result["masked_sentence"]
        print(f"Original Sentence ({masking_label}): {sentence}")
        print(f"Masked Sentence ({masking_label}): {masked_sentence}")

        for technique in ("inverse_transform", "exponential_minimum", "temperature", "greedy"):
            print(f"Sampling Technique: {technique}")
            filled_sentence = sampling_processor.fill_masked_sentence(
                masked_sentence=masked_sentence,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')