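"""Mask selected words in sentences with BERT's [MASK] token.

Words inside pre-computed common n-grams are never masked. One word is masked
before the first n-gram, one between consecutive n-grams, and one after the
last n-gram, chosen either at random or by highest predictive entropy under BERT.
"""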
import random
import torch
from transformers import BertTokenizer, BertForMaskedLM
from nltk.corpus import stopwords
import nltk

# Ensure stopwords are downloaded
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')

class MaskingProcessor:
    def __init__(self):
        # Pretrained BERT tokenizer and masked-LM head used for masking and logit scoring
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        self.stop_words = set(stopwords.words('english'))

    def adjust_ngram_indices(self, words, common_ngrams, remove_stopwords):
        """
        Adjust indices of common n-grams after removing stop words.

        Args:
            words (list): List of words in the original sentence.
            common_ngrams (dict): Common n-grams and their indices.
            remove_stopwords (bool): Whether stop words were removed from the sentence.

        Returns:
            dict: Common n-grams with indices remapped to the stop-word-free word list.
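
        Example (illustrative; exact indices depend on NLTK's English stopword list):
            words = "The quick brown fox jumps over the lazy dog.".split()
            adjust_ngram_indices(words, {"brown fox": [(2, 3)]}, remove_stopwords=True)
            # -> {"brown fox": [(1, 2)]}, because the stop word "The" is dropped before the n-gram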
        """
        if not remove_stopwords:
            return common_ngrams

        non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        adjusted_ngrams = {}

        for ngram, positions in common_ngrams.items():
            adjusted_positions = []
            for start, end in positions:
                try:
                    new_start = non_stop_word_indices.index(start)
                    new_end = non_stop_word_indices.index(end)
                    adjusted_positions.append((new_start, new_end))
                except ValueError:
                    continue  # Skip if indices cannot be mapped
            adjusted_ngrams[ngram] = adjusted_positions

        return adjusted_ngrams

    def mask_sentence_random(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between each pair of
        consecutive n-grams, and one after the last common n-gram (random selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to remove stop words

        Returns:
            str: Masked sentence with original stop words retained
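
        Example (illustrative, remove_stopwords=False):
            mask_sentence_random("the quick brown fox jumps", {"brown fox": [(2, 3)]}, False)
            # -> e.g. "the [MASK] brown fox [MASK]" (one mask before and one after the n-gram)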
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))

        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        # Collect all indices corresponding to common n-grams
        common_ngram_indices = {
            idx for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        mask_indices = []
        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            if first_ngram_start > 0:
                potential_indices = [i for i in range(first_ngram_start) if i not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

        # Handle between consecutive common n-grams
        ngram_positions = list(adjusted_ngrams.values())
        for i in range(len(ngram_positions) - 1):
            end_prev = ngram_positions[i][-1][1]
            start_next = ngram_positions[i + 1][0][0]
            potential_indices = [j for j in range(end_prev + 1, start_next) if j not in common_ngram_indices]
            if potential_indices:
                mask_indices.append(random.choice(potential_indices))

        # Handle after the last common n-gram (guard against an empty n-gram dict)
        if ngram_positions:
            last_ngram_end = ngram_positions[-1][-1][1]
            if last_ngram_end < len(non_stop_words) - 1:
                potential_indices = [j for j in range(last_ngram_end + 1, len(non_stop_words)) if j not in common_ngram_indices]
                if potential_indices:
                    mask_indices.append(random.choice(potential_indices))

        # Mask the chosen indices
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token

        return " ".join(original_masked_sentence)

    def mask_sentence_entropy(self, original_sentence, common_ngrams, remove_stopwords):
        """
        Mask one word before the first common n-gram, one between each pair of
        consecutive n-grams, and one after the last common n-gram (highest entropy selection).

        Args:
            original_sentence (str): Original sentence
            common_ngrams (dict): Common n-grams and their indices
            remove_stopwords (bool): Whether to remove stop words

        Returns:
            str: Masked sentence with original stop words retained
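
        Candidate words are scored by the Shannon entropy of BERT's predictive
        distribution at their position, H = -sum_i p_i * log(p_i); in each gap the
        highest-entropy candidate is masked.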
        """
        words = original_sentence.split()
        if remove_stopwords:
            non_stop_words = [word for word in words if word.lower() not in self.stop_words]
            non_stop_word_indices = [i for i, word in enumerate(words) if word.lower() not in self.stop_words]
        else:
            non_stop_words = words
            non_stop_word_indices = list(range(len(words)))
        adjusted_ngrams = self.adjust_ngram_indices(words, common_ngrams, remove_stopwords)

        # Indices covered by common n-grams (computed once, reused below)
        common_ngram_indices = {
            idx for ngram_positions in adjusted_ngrams.values()
            for start, end in ngram_positions
            for idx in range(start, end + 1)
        }

        entropy_scores = {}
        for idx, word in enumerate(non_stop_words):
            if idx in common_ngram_indices:
                continue  # Skip words in common n-grams

            masked_sentence = non_stop_words[:idx] + [self.tokenizer.mask_token] + non_stop_words[idx + 1:]
            masked_sentence = " ".join(masked_sentence)
            input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
            mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

            with torch.no_grad():
                outputs = self.model(input_ids)
                logits = outputs.logits

            filtered_logits = logits[0, mask_token_index, :]
            probs = torch.softmax(filtered_logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()  # Epsilon prevents log(0)
            entropy_scores[idx] = entropy

        mask_indices = []

        # Handle before the first common n-gram
        if adjusted_ngrams:
            first_ngram_start = list(adjusted_ngrams.values())[0][0][0]
            candidates = [i for i in range(first_ngram_start) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Handle between consecutive common n-grams
        ngram_positions = list(adjusted_ngrams.values())
        for i in range(len(ngram_positions) - 1):
            end_prev = ngram_positions[i][-1][1]
            start_next = ngram_positions[i + 1][0][0]
            candidates = [j for j in range(end_prev + 1, start_next) if j in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Handle after the last common n-gram (guard against an empty n-gram dict)
        if ngram_positions:
            last_ngram_end = ngram_positions[-1][-1][1]
            candidates = [i for i in range(last_ngram_end + 1, len(non_stop_words)) if i in entropy_scores]
            if candidates:
                mask_indices.append(max(candidates, key=lambda x: entropy_scores[x]))

        # Mask the chosen indices, mapping back to positions in the original sentence
        # (indices inside common n-grams are already excluded from entropy_scores)
        original_masked_sentence = words[:]
        for idx in mask_indices:
            if remove_stopwords:
                original_idx = non_stop_word_indices[idx]  # Map back to original indices
                original_masked_sentence[original_idx] = self.tokenizer.mask_token
            else:
                original_masked_sentence[idx] = self.tokenizer.mask_token


        return " ".join(original_masked_sentence)

    def calculate_mask_logits(self, masked_sentence):
        """
        Calculate logits for masked tokens in the sentence using BERT.

        Args:
            masked_sentence (str): Sentence with [MASK] tokens

        Returns:
            dict: Mapping from each [MASK] token's position in the tokenized input
                (not its word index in the sentence) to its list of vocabulary logits.
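
        Example (illustrative):
            calculate_mask_logits("The [MASK] brown fox")
            # -> {2: [...30522 floats...]} for bert-base-uncased ([CLS] is position 0)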
        """
        input_ids = self.tokenizer(masked_sentence, return_tensors="pt")["input_ids"]
        mask_token_index = torch.where(input_ids == self.tokenizer.mask_token_id)[1]

        with torch.no_grad():
            outputs = self.model(input_ids)
            logits = outputs.logits

        mask_logits = {idx.item(): logits[0, idx].tolist() for idx in mask_token_index}
        return mask_logits

    def process_sentences(self, original_sentences, result_dict, method="random", remove_stopwords=False):
        """
        Process a list of sentences and calculate logits for masked tokens using the specified method.

        Args:
            original_sentences (list): List of original sentences (sentences are read from the result_dict keys)
            result_dict (dict): Common n-grams and their indices for each sentence
            method (str): Masking method ("random" or "entropy")
            remove_stopwords (bool): Whether to remove stop words before masking

        Returns:
            dict: Masked sentences and their logits for each sentence
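
        Example (illustrative):
            result_dict = {"the quick brown fox": {"brown fox": [(2, 3)]}}
            process_sentences(["the quick brown fox"], result_dict, method="random")
            # -> {"the quick brown fox": {"masked_sentence": ..., "mask_logits": ...}}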
        """
        results = {}

        for sentence, ngrams in result_dict.items():
            if method == "random":
                masked_sentence = self.mask_sentence_random(sentence, ngrams, remove_stopwords)
            elif method == "entropy":
                masked_sentence = self.mask_sentence_entropy(sentence, ngrams, remove_stopwords)
            else:
                raise ValueError("Invalid method. Choose 'random' or 'entropy'.")

            logits = self.calculate_mask_logits(masked_sentence)
            results[sentence] = {
                "masked_sentence": masked_sentence,
                "mask_logits": logits
            }

        return results

# Example usage
if __name__ == "__main__":
    # Works in both cases, whether stop words are removed or not
    sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "A speedy brown fox jumps over a lazy dog.",
        "A swift brown fox leaps over the lethargic dog."
    ]
    result_dict = {
        'The quick brown fox jumps over the lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}, 
        'A speedy brown fox jumps over a lazy dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}, 
        'A swift brown fox leaps over the lethargic dog.': {'brown fox': [(2, 3)], 'dog': [(8, 8)]}
    }


    processor = MaskingProcessor()
    results_random = processor.process_sentences(sentences, result_dict, method="random", remove_stopwords=True)
    # results_entropy = processor.process_sentences(sentences, result_dict, method="entropy", remove_stopwords=False)

    for sentence, output in results_random.items():
        print(f"Original Sentence (Random): {sentence}")
        print(f"Masked Sentence (Random): {output['masked_sentence']}")
        # # print(f"Mask Logits (Random): {output['mask_logits']}")
        # print(f' type(output["mask_logits"]) : {type(output["mask_logits"])}')
        # print(f' length of output["mask_logits"] : {len(output["mask_logits"])}')
        # print(f' output["mask_logits"].keys() : {output["mask_logits"].keys()}')
        print('--------------------------------')
        # for mask_idx, logits in output["mask_logits"].items():
        #     print(f"Logits for [MASK] at position {mask_idx}:")
        #     print(f' logits : {logits[:5]}')  # List of logits for all vocabulary tokens




    # To compare with entropy-based masking, uncomment the results_entropy call above and:
    # for sentence, output in results_entropy.items():
    #     print(f"Original Sentence (Entropy): {sentence}")
    #     print(f"Masked Sentence (Entropy): {output['masked_sentence']}")
    #     print('--------------------------------')