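"""Sampling utilities for filling [MASK] tokens in masked sentences.

Implements several decoding strategies (inverse transform, exponential minimum,
temperature, greedy) over per-mask logits.
"""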
import torch
import random
from masking_methods import MaskingProcessor

class SamplingProcessor:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def fill_masked_sentence(self, original_sentence, mask_logits, sampling_technique, temperature=1.0):
        """
        Fills each mask in the masked sentence using the specified sampling technique.

        Args:
            original_sentence (str): The masked sentence whose [MASK] tokens should be filled.
            mask_logits (dict): Per-mask logits over the tokenizer vocabulary, one entry
                per [MASK] token, in order of appearance.
            sampling_technique (str): One of "inverse_transform", "exponential_minimum",
                "temperature", or "greedy".
            temperature (float): Temperature used to rescale logits before sampling.

        Returns:
            str: Sentence with the masks filled.
        """
        sentence_tokens = self.tokenizer.tokenize(original_sentence)
        mask_token_indices = [i for i, token in enumerate(sentence_tokens) if token == self.tokenizer.mask_token]

        if len(mask_token_indices) != len(mask_logits):
            raise ValueError("Mismatch between number of [MASK] tokens and logits provided.")

        for mask_idx, filtered_logits in zip(mask_token_indices, mask_logits.values()):
            # Accept lists or tensors; as_tensor avoids an extra copy when a
            # tensor is passed in.
            filtered_logits = torch.as_tensor(filtered_logits, dtype=torch.float)

            if sampling_technique == "inverse_transform":
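                # Inverse-transform sampling: draw u ~ U(0, 1) and take the first
                # index whose cumulative probability reaches u.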
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                random_prob = random.random()
                sampled_index = torch.where(cumulative_probs >= random_prob)[0][0].item()

            elif sampling_technique == "exponential_minimum":
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                exp_probs = torch.exp(-torch.log(probs))
                random_probs = torch.rand_like(exp_probs)
                sampled_index = torch.argmax(random_probs * exp_probs).item()

            elif sampling_technique == "temperature":
                filtered_logits = torch.clamp(filtered_logits, min=-1e8, max=1e8)
                probs = torch.softmax(filtered_logits / temperature, dim=-1)
                if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)):
                    raise ValueError("The computed probabilities contain NaN or inf values.")
                probs = torch.max(probs, torch.tensor(1e-8, device=filtered_logits.device))
                probs = probs / torch.sum(probs)
                probs = probs.flatten()
                if probs.size(0) > 1:
                    sampled_index = torch.multinomial(probs, 1).item()
                else:
                    sampled_index = torch.argmax(probs).item()

            elif sampling_technique == 'greedy':
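                # Greedy decoding: always take the highest-scoring token.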
                sampled_index = torch.argmax(filtered_logits).item()

            else:
                raise ValueError(f"Unknown sampling technique: {sampling_technique}")

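            # The sampled index is interpreted as an id over the full tokenizer
            # vocabulary, so the logits are assumed to span that vocabulary.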
            sampled_token = self.tokenizer.convert_ids_to_tokens([sampled_index])[0]
            sentence_tokens[mask_idx] = sampled_token

        return self.tokenizer.convert_tokens_to_string(sentence_tokens)



    def process_samples(self, masked_sentences, mask_logits, sampling_technique, temperature=1.0):
        """
        Process multiple masked sentences and fill their masks using the specified sampling technique.

        Args:
            masked_sentences (list): Masked sentences to fill.
            mask_logits (list): Per-sentence logit dicts, aligned with masked_sentences
                (each dict maps one [MASK] to its logits).
            sampling_technique (str): Sampling technique to use.
            temperature (float): Temperature used to rescale logits before sampling.

        Returns:
            list: List of sentences with masks filled.
        """
        filled_sentences = []
        for sentence, logits in zip(masked_sentences, mask_logits):
            filled_sentence = self.fill_masked_sentence(sentence, logits, sampling_technique, temperature)
            filled_sentences.append(filled_sentence)
        return filled_sentences

# Example usage
if __name__ == "__main__":
    from transformers import BertTokenizer

    # tokenizer = BertTokenizer.from_pretrained("bert-large-cased-whole-word-masking")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    processor = SamplingProcessor(tokenizer)

    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A quick brown dog outpaces a lazy fox.",
        "Quick brown dog leaps over lazy the fox."
    ]
    result_dict = {
        "The quick brown fox jumps over the lazy dog.": {'quick brown': [(0, 1)], 'fox': [(2, 2)], 'lazy': [(4, 4)], 'dog': [(5, 5)]}, 
        "A quick brown dog outpaces a lazy fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}, 
        "Quick brown dog leaps over lazy the fox.": {'quick brown': [(0, 1)], 'fox': [(5, 5)], 'lazy': [(4, 4)], 'dog': [(2, 2)]}
    }


    masking_processor = MaskingProcessor()
    masking_results = masking_processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=False)
    # Standalone smoke test: random logits over the full vocabulary, with no
    # dependency on MaskingProcessor.
    demo_masked = "The [MASK] brown fox jumps [MASK] the lazy dog."
    demo_logits = {
        1: torch.randn(len(tokenizer)),  # logits for the first [MASK]
        5: torch.randn(len(tokenizer)),  # logits for the second [MASK]
    }
    print(f"Standalone demo: {processor.fill_masked_sentence(demo_masked, demo_logits, 'greedy')}\n")
    
    # Iterate through masking results and apply each sampling technique
    for sentence, result in masking_results.items():
        masked_sentence = result["masked_sentence"]
        mask_logits = result["mask_logits"]
        print(f"Original Sentence: {sentence}")
        print(f"Masked Sentence: {masked_sentence}")
        
        # Apply different sampling techniques
        for technique in ["inverse_transform", "exponential_minimum", "temperature", "greedy"]:
            print(f"Sampling Technique: {technique}")
                
            # Fill the masks using the sampling processor
            filled_sentence = processor.fill_masked_sentence(
                original_sentence=masked_sentence,
                mask_logits=mask_logits,
                sampling_technique=technique,
                temperature=1.0  # Adjust temperature as needed
            )
                
            print(f"Filled Sentence: {filled_sentence}\n")
        print('--------------------------------')
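
    # Hedged sketch (an assumption, not part of this repo's pipeline): produce
    # real mask logits with a Hugging Face masked-LM head, in the per-mask dict
    # layout that fill_masked_sentence expects.
    from transformers import BertForMaskedLM

    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval()
    demo = "The [MASK] brown fox jumps over the [MASK] dog."
    inputs = tokenizer(demo, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # (seq_len, vocab_size)
    mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    mlm_logits = {i: logits[pos] for i, pos in enumerate(mask_positions.tolist())}
    print(f"MLM-backed demo: {processor.fill_masked_sentence(demo, mlm_logits, 'greedy')}")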