"""Entailment analysis utilities built on a Hugging Face NLI pipeline."""

import sys
import os

# Make the project root importable so `utils.config` resolves when this file
# is run directly as a script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from transformers import pipeline
from typing import List

from utils.config import load_config


class EntailmentAnalyzer:

    def __init__(self, config):
        """
        Initialize the EntailmentAnalyzer.

        Args:
            config: Entailment configuration (e.g. the 'Entailment' section of
                the loaded config) providing the 'task' and 'model' keys used
                to build the Hugging Face pipeline.
        """
        self.config = config
        self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])

    def check_entailment(self, premise: str, hypothesis: str) -> float:
        """
        Check entailment between the premise and hypothesis.

        Args:
            premise: The premise sentence.
            hypothesis: The hypothesis sentence.

        Returns:
            float: The score of the 'entailment' label (0.0 if the model does
                not return such a label).
        """
        # The premise/hypothesis pair is passed as a single string joined by
        # [SEP], which assumes a BERT-style separator token; top_k=None makes
        # the pipeline return scores for every label.
        results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)
        # Label casing varies between NLI checkpoints, so match case-insensitively
        # and fall back to 0.0 instead of raising StopIteration.
        entailment_score = next(
            (item['score'] for item in results if item['label'].lower() == 'entailment'),
            0.0,
        )
        return entailment_score

    def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
        """
        Analyze entailment scores for paraphrased sentences. If no sentence meets
        the threshold, the threshold is lowered by 0.1 and the check is rerun.

        Args:
            original_sentence: The original sentence.
            paraphrased_sentences: List of paraphrased sentences.
            threshold: Minimum entailment score required to select a sentence.

        Returns:
            tuple: Three dictionaries mapping sentences to scores: all scored
                sentences, selected sentences, and discarded sentences.
        """
        all_sentences = {}
        selected_sentences = {}
        discarded_sentences = {}

        # Score every paraphrase once; the loop below only re-partitions the
        # cached scores, so the model is not re-run when the threshold drops.
        for paraphrased_sentence in paraphrased_sentences:
            all_sentences[paraphrased_sentence] = self.check_entailment(original_sentence, paraphrased_sentence)

        while not selected_sentences:
            selected_sentences = {s: score for s, score in all_sentences.items() if score >= threshold}
            discarded_sentences = {s: score for s, score in all_sentences.items() if score < threshold}

            if not selected_sentences:
                print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold:.2f} to {threshold - 0.1:.2f}).")
                threshold -= 0.1
                if threshold <= 0:
                    print("Threshold has reached 0. No sentences meet the criteria.")
                    break

        return all_sentences, selected_sentences, discarded_sentences


if __name__ == "__main__":
    config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
    config = load_config(config_path)
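
    # Assumed config layout, inferred from the lookup below: the YAML is expected
    # to hold a PECCAVI_TEXT section with an Entailment block that defines the
    # 'task' and 'model' keys consumed by EntailmentAnalyzer.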
    entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])

    all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
        "The weather is nice today",
        [
            "The climate is pleasant today",
            "It's a good day weather-wise",
            "Today, the weather is terrible",
            "What a beautiful day it is",
            "The sky is clear and the weather is perfect",
            "It's pouring rain outside today",
            "The weather isn't bad today",
            "A lovely day for outdoor activities"
        ],
        0.7  # initial threshold; analyze_entailment lowers it in 0.1 steps if nothing qualifies
    )

print("----------------------- All Sentences -----------------------") |
|
print(all_sentences) |
|
print("----------------------- Discarded Sentences -----------------------") |
|
print(discarded_sentences) |
|
print("----------------------- Selected Sentences -----------------------") |
|
print(selected_sentences) |
|