# Commit 060ac52 — "Add entire pipeline" (scrubbed GitHub page residue into a comment)
"""
This file contains the code to generate paraphrases of sentences.
"""
import os
import sys
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm # for progress bars
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.config import load_config
# config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
# config = load_config(config_path)['PECCAVI_TEXT']['Paraphrase']
# Configure logging to show only warnings or above on the terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class Paraphraser:
    """
    Generate paraphrases of sentences with a seq2seq model via diverse
    beam search.

    Expected config keys: 'tokenizer', 'model', 'num_beams',
    'num_beam_groups', 'num_return_sequences', 'repetition_penalty',
    'diversity_penalty', 'no_repeat_ngram_size', 'temperature',
    'max_length'.
    """

    def __init__(self, config: dict):
        """
        Load the tokenizer and model onto the best available device.

        Args:
            config: Mapping containing the keys listed in the class
                docstring.

        Raises:
            KeyError: If a required config key is missing.
        """
        self.config = config
        # Imported lazily so that merely importing this module does not
        # require torch to be installed.
        import torch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
        self.num_beams = config['num_beams']
        self.num_beam_groups = config['num_beam_groups']
        self.num_return_sequences = config['num_return_sequences']
        self.repetition_penalty = config['repetition_penalty']
        self.diversity_penalty = config['diversity_penalty']
        self.no_repeat_ngram_size = config['no_repeat_ngram_size']
        self.temperature = config['temperature']
        self.max_length = config['max_length']

    def paraphrase(self, sentence: str, num_return_sequences: int = None,
                   num_beams: int = None, num_beam_groups: int = None) -> list:
        """
        Generate paraphrases of a single sentence.

        Args:
            sentence: Input sentence to paraphrase.
            num_return_sequences: Override for the configured value.
            num_beams: Override for the configured value.
            num_beam_groups: Override for the configured value.

        Returns:
            List of decoded paraphrase strings.

        Raises:
            ValueError: If the effective beam-search parameters are
                inconsistent.
        """
        tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
        if num_return_sequences is None:
            num_return_sequences = self.num_return_sequences
        if num_beams is None:
            num_beams = self.num_beams
        if num_beam_groups is None:
            num_beam_groups = self.num_beam_groups
        # Diverse beam search requires these invariants; fail early with a
        # clear message instead of an opaque error from generate().
        if num_beam_groups and num_beams % num_beam_groups != 0:
            raise ValueError(
                f"num_beams ({num_beams}) must be divisible by "
                f"num_beam_groups ({num_beam_groups})")
        if num_return_sequences > num_beams:
            raise ValueError(
                f"num_return_sequences ({num_return_sequences}) cannot "
                f"exceed num_beams ({num_beams})")
        inputs = self.tokenizer.encode("paraphrase: " + sentence,
                                       return_tensors="pt",
                                       max_length=self.max_length,
                                       truncation=True).to(self.device)
        outputs = self.model.generate(
            inputs,
            max_length=self.max_length,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            num_return_sequences=num_return_sequences,
            repetition_penalty=self.repetition_penalty,
            diversity_penalty=self.diversity_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            # NOTE(review): temperature has no effect without do_sample=True;
            # pure beam search ignores it (newer transformers warn). Kept for
            # backward compatibility — confirm whether sampling was intended.
            temperature=self.temperature
        )
        paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
                       for output in tqdm(outputs, desc="Decoding Paraphrases")]
        tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
        return paraphrases
if __name__ == "__main__":
    # Resolve the config path relative to this file instead of the previous
    # hard-coded, machine-specific absolute path, so the script is portable.
    # The package root is one directory above this file (see the sys.path
    # setup at the top of the module).
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               '..', 'utils', 'config.yaml')
    config = load_config(config_path)
    paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase'])
    sentence = "The quick brown fox jumps over the lazy dog."
    # Print each generated paraphrase on its own line.
    for candidate in paraphraser.paraphrase(sentence):
        print(candidate)