File size: 3,488 Bytes
060ac52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
"""
This file contains the code to generate paraphrases of sentences.
"""
import logging
import os
import sys
from typing import Optional

import torch
from tqdm import tqdm  # for progress bars
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Make the package root importable when this file is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.config import load_config
# NOTE(review): module-level config loading kept for reference; the
# __main__ entry point below loads its own config instead.
# config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
# config = load_config(config_path)['PECCAVI_TEXT']['Paraphrase']
# Configure logging to show only warnings or above on the terminal.
logging.basicConfig(level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger; currently unused in this file but available to callers.
logger = logging.getLogger(__name__)
class Paraphraser:
    """
    Generate paraphrases of sentences with a seq2seq model via
    (diverse) beam search.

    ``config`` must provide the keys: ``tokenizer``, ``model``,
    ``num_beams``, ``num_beam_groups``, ``num_return_sequences``,
    ``repetition_penalty``, ``diversity_penalty``,
    ``no_repeat_ngram_size``, ``temperature``, ``max_length``.
    """

    def __init__(self, config: dict):
        """Load tokenizer/model onto the best device and cache generation settings.

        Args:
            config: Mapping with the keys listed in the class docstring.
                ``tokenizer`` and ``model`` are Hugging Face model ids or
                local paths.
        """
        self.config = config
        # Prefer GPU when available; the model is moved once here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tqdm.write(f"[Paraphraser] Initializing on device: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(config['tokenizer'])
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config['model']).to(self.device)
        # Default generation hyper-parameters; paraphrase() can override
        # the beam-related ones per call.
        self.num_beams = config['num_beams']
        self.num_beam_groups = config['num_beam_groups']
        self.num_return_sequences = config['num_return_sequences']
        self.repetition_penalty = config['repetition_penalty']
        self.diversity_penalty = config['diversity_penalty']
        self.no_repeat_ngram_size = config['no_repeat_ngram_size']
        self.temperature = config['temperature']
        self.max_length = config['max_length']

    def paraphrase(self, sentence: str,
                   num_return_sequences: Optional[int] = None,
                   num_beams: Optional[int] = None,
                   num_beam_groups: Optional[int] = None) -> list:
        """Return a list of paraphrases of ``sentence``.

        Args:
            sentence: Input sentence to paraphrase.
            num_return_sequences: Override for the configured value; must
                not exceed ``num_beams``.
            num_beams: Override for the configured beam count.
            num_beam_groups: Override for the configured beam-group count
                (diverse beam search when > 1).

        Returns:
            List of decoded paraphrase strings (special tokens stripped).
        """
        tqdm.write(f"[Paraphraser] Starting paraphrase for sentence: {sentence}")
        if num_return_sequences is None:
            num_return_sequences = self.num_return_sequences
        if num_beams is None:
            num_beams = self.num_beams
        if num_beam_groups is None:
            num_beam_groups = self.num_beam_groups
        # Model expects a task prefix; input is truncated to max_length tokens.
        inputs = self.tokenizer.encode("paraphrase: " + sentence,
                                       return_tensors="pt",
                                       max_length=self.max_length,
                                       truncation=True).to(self.device)
        # NOTE(review): temperature has no effect under beam search unless
        # sampling is enabled (do_sample defaults to False) — confirm intent.
        outputs = self.model.generate(
            inputs,
            max_length=self.max_length,
            num_beams=num_beams,
            num_beam_groups=num_beam_groups,
            num_return_sequences=num_return_sequences,
            repetition_penalty=self.repetition_penalty,
            diversity_penalty=self.diversity_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature
        )
        paraphrases = [self.tokenizer.decode(output, skip_special_tokens=True)
                       for output in tqdm(outputs, desc="Decoding Paraphrases")]
        tqdm.write(f"[Paraphraser] Paraphrase completed. {len(paraphrases)} outputs generated.")
        return paraphrases
if __name__ == "__main__":
    # Allow the config path to be supplied on the command line; the
    # original hard-coded path remains the backward-compatible default.
    default_config_path = '/home/jigyasu/PECCAVI-Text/utils/config.yaml'
    config_path = sys.argv[1] if len(sys.argv) > 1 else default_config_path
    config = load_config(config_path)
    paraphraser = Paraphraser(config['PECCAVI_TEXT']['Paraphrase'])
    sentence = "The quick brown fox jumps over the lazy dog."
    # Avoid shadowing the Paraphraser.paraphrase method name.
    for text in paraphraser.paraphrase(sentence):
        print(text)