|
""" |
|
This file contains the code to generate paraphrases of sentences. |
|
""" |
|
import logging
import os
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Make the package root importable before pulling in local modules.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.config import load_config
|
|
|
|
|
|
|
|
|
# Module-wide logging: WARNING level with timestamped messages.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")

# Module-level logger; currently unused — runtime output goes through tqdm.write.
logger = logging.getLogger(__name__)
|
|
|
class Paraphraser:
    """Generate paraphrases of sentences with a seq2seq model.

    Uses diverse beam search (``num_beams`` split into ``num_beam_groups``)
    so the returned paraphrases differ from one another.
    """

    def __init__(self, config):
        """Load the tokenizer and model described by *config*.

        Args:
            config: dict with keys 'tokenizer', 'model', 'num_beams',
                'num_beam_groups', 'num_return_sequences',
                'repetition_penalty', 'diversity_penalty',
                'no_repeat_ngram_size', 'temperature', 'max_length'.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
        # Inference only: disable dropout so generation is deterministic.
        self.model.eval()
        self.num_beams = config['num_beams']
        self.num_beam_groups = config['num_beam_groups']
        self.num_return_sequences = config['num_return_sequences']
        self.repetition_penalty = config['repetition_penalty']
        self.diversity_penalty = config['diversity_penalty']
        self.no_repeat_ngram_size = config['no_repeat_ngram_size']
        self.temperature = config['temperature']
        self.max_length = config['max_length']

    def paraphrase(self, sentence: str, num_return_sequences: int = None,
                   num_beams: int = None, num_beam_groups: int = None):
        """Return a list of paraphrases of *sentence*.

        Any keyword override left as None falls back to the value taken
        from the config in ``__init__``.

        Args:
            sentence: the input sentence to paraphrase.
            num_return_sequences: how many paraphrases to return.
            num_beams: total beam count for diverse beam search.
            num_beam_groups: number of groups the beams are split into.

        Returns:
            list[str]: decoded paraphrases (special tokens stripped).

        Raises:
            ValueError: if ``num_beams`` is not divisible by
                ``num_beam_groups`` — a requirement of diverse beam search.
        """
        tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
        if num_return_sequences is None:
            num_return_sequences = self.num_return_sequences
        if num_beams is None:
            num_beams = self.num_beams
        if num_beam_groups is None:
            num_beam_groups = self.num_beam_groups
        # Diverse beam search requires beams to split evenly into groups;
        # fail early with a clear message instead of deep inside generate().
        if num_beam_groups and num_beams % num_beam_groups != 0:
            raise ValueError(
                f"num_beams ({num_beams}) must be divisible by "
                f"num_beam_groups ({num_beam_groups})"
            )

        inputs = self.tokenizer.encode(
            "paraphrase: " + sentence,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        ).to(self.device)
        # No gradients are needed at inference time; skipping the autograd
        # graph cuts memory use during generation.
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_length=self.max_length,
                num_beams=num_beams,
                num_beam_groups=num_beam_groups,
                num_return_sequences=num_return_sequences,
                repetition_penalty=self.repetition_penalty,
                diversity_penalty=self.diversity_penalty,
                no_repeat_ngram_size=self.no_repeat_ngram_size,
                temperature=self.temperature
            )
        paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
                       for output in tqdm(outputs, desc="Decoding Paraphrases")]
        tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
        return paraphrases
|
|
|
if __name__ == "__main__": |
|
config_path = '/home/jigyasu/PECCAVI-Text/utils/config.yaml' |
|
config = load_config(config_path) |
|
paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase']) |
|
sentence = "The quick brown fox jumps over the lazy dog." |
|
paraphrases = paraphraser.paraphrase(sentence) |
|
for paraphrase in paraphrases: |
|
print(paraphrase) |