""" This file contains the code to generate paraphrases of sentences. """ import os import sys import logging from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from tqdm import tqdm # for progress bars sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from utils.config import load_config # config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml') # config = load_config(config_path)['PECCAVI_TEXT']['Paraphrase'] # Configure logging to show only warnings or above on the terminal. logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) class Paraphraser: """ Paraphraser class to generate paraphrases of sentences. """ def __init__(self, config): self.config = config import torch self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tqdm.write(f"[Paraphraser] Initializing on device: {self.device}") self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']) self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device) self.num_beams = config['num_beams'] self.num_beam_groups = config['num_beam_groups'] self.num_return_sequences = config['num_return_sequences'] self.repetition_penalty = config['repetition_penalty'] self.diversity_penalty = config['diversity_penalty'] self.no_repeat_ngram_size = config['no_repeat_ngram_size'] self.temperature = config['temperature'] self.max_length = config['max_length'] def paraphrase(self, sentence: str, num_return_sequences: int=None, num_beams: int=None, num_beam_groups: int=None): tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}") if num_return_sequences is None: num_return_sequences = self.num_return_sequences if num_beams is None: num_beams = self.num_beams if num_beam_groups is None: num_beam_groups = self.num_beam_groups inputs = self.tokenizer.encode("paraphrase: " + sentence, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device) outputs = self.model.generate( inputs, max_length=self.max_length, num_beams=num_beams, num_beam_groups=num_beam_groups, num_return_sequences=num_return_sequences, repetition_penalty=self.repetition_penalty, diversity_penalty=self.diversity_penalty, no_repeat_ngram_size=self.no_repeat_ngram_size, temperature=self.temperature ) paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True) for output in tqdm(outputs, desc="Decoding Paraphrases")] tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.") return paraphrases if __name__ == "__main__": config_path = '/home/jigyasu/PECCAVI-Text/utils/config.yaml' config = load_config(config_path) paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase']) sentence = "The quick brown fox jumps over the lazy dog." paraphrases = paraphraser.paraphrase(sentence) for paraphrase in paraphrases: print(paraphrase)