"""
This file contains the code to generate paraphrases of sentences.
"""
import logging
import os
import sys

import torch
from tqdm import tqdm  # for progress bars
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.config import load_config

# Configure logging to show only warnings or above on the terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class Paraphraser:
    """
    Paraphraser class to generate paraphrases of sentences.

    Wraps a Hugging Face seq2seq model and produces several candidate
    paraphrases per input sentence using diverse beam search
    (beam groups + diversity penalty) so the candidates differ
    from one another.
    """

    def __init__(self, config: dict):
        """
        Load the tokenizer/model and cache generation hyper-parameters.

        Args:
            config: Mapping with keys 'tokenizer', 'model', 'num_beams',
                'num_beam_groups', 'num_return_sequences',
                'repetition_penalty', 'diversity_penalty',
                'no_repeat_ngram_size', 'temperature', 'max_length'.

        Raises:
            KeyError: If any of the expected keys is missing from ``config``.
        """
        self.config = config
        # Prefer GPU when available; the model is moved to this device below.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
        # Generation hyper-parameters (fail fast on a missing key).
        self.num_beams: int = config['num_beams']
        self.num_beam_groups: int = config['num_beam_groups']
        self.num_return_sequences: int = config['num_return_sequences']
        self.repetition_penalty: float = config['repetition_penalty']
        self.diversity_penalty: float = config['diversity_penalty']
        self.no_repeat_ngram_size: int = config['no_repeat_ngram_size']
        self.temperature: float = config['temperature']
        self.max_length: int = config['max_length']

    def paraphrase(self,
                   sentence: str,
                   num_return_sequences: int | None = None,
                   num_beams: int | None = None,
                   num_beam_groups: int | None = None) -> list[str]:
        """
        Generate paraphrases of ``sentence``.

        Args:
            sentence: The input sentence to paraphrase.
            num_return_sequences: Override for the configured number of
                paraphrases to return. ``None`` uses the config value.
            num_beams: Override for the configured beam count.
                ``None`` uses the config value.
            num_beam_groups: Override for the configured beam-group count.
                ``None`` uses the config value.

        Returns:
            A list of decoded paraphrase strings
            (length == ``num_return_sequences``).
        """
        tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
        # Fall back to the values cached from config; explicit `is None`
        # checks so a caller-supplied 0 would not be silently replaced.
        if num_return_sequences is None:
            num_return_sequences = self.num_return_sequences
        if num_beams is None:
            num_beams = self.num_beams
        if num_beam_groups is None:
            num_beam_groups = self.num_beam_groups

        # The task prefix follows the T5-style "paraphrase:" convention.
        inputs = self.tokenizer.encode("paraphrase: " + sentence,
                                       return_tensors="pt",
                                       max_length=self.max_length,
                                       truncation=True).to(self.device)
        # NOTE(review): `temperature` only affects sampling; with beam
        # search and no `do_sample=True`, transformers ignores (and may
        # warn about) it — confirm whether sampling was intended.
        outputs = self.model.generate(
            inputs,
            max_length=self.max_length,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            num_return_sequences=num_return_sequences,
            repetition_penalty=self.repetition_penalty,
            diversity_penalty=self.diversity_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature
        )
        paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
                       for output in tqdm(outputs, desc="Decoding Paraphrases")]
        tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
        return paraphrases
    
if __name__ == "__main__":
    # Demo entry point: load the project config and print a few
    # paraphrases of a sample sentence.
    config_path = '/home/jigyasu/PECCAVI-Text/utils/config.yaml'
    full_config = load_config(config_path)
    engine = Paraphraser(full_config['PECCAVI_TEXT']['Paraphrase'])
    demo_sentence = "The quick brown fox jumps over the lazy dog."
    for candidate in engine.paraphrase(demo_sentence):
        print(candidate)