# File size: 4,192 Bytes (revision 060ac52)
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import numpy as np
from transformers import pipeline
from typing import List
from utils.config import load_config
class EntailmentAnalyzer:
    """Score paraphrases against an original sentence with an entailment model.

    The model is a ``transformers`` pipeline created from the ``task`` and
    ``model`` entries of the config mapping passed to the constructor.
    """

    def __init__(self, config):
        """
        Initialize the EntailmentAnalyzer from an already-loaded config section.

        Args:
            config: Mapping providing the keys ``'task'`` and ``'model'``;
                both are forwarded to ``transformers.pipeline``.
        """
        self.config = config
        self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])

    def check_entailment(self, premise: str, hypothesis: str) -> float:
        """
        Check entailment between the premise and hypothesis.

        Args:
            premise: The premise sentence.
            hypothesis: The hypothesis sentence.

        Returns:
            float: The score the pipeline reports for the entailment label.

        Raises:
            ValueError: If the pipeline output contains no entailment label.
        """
        # NOTE(review): the pair is joined with a literal "[SEP]" marker; whether
        # the configured model's tokenizer treats it as a separator should be
        # confirmed against the model in use.
        results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)

        # Defensive: accept a result nested one level deep (a list whose first
        # element is itself the per-label list).
        if results and isinstance(results[0], list):
            results = results[0]

        # Match the label case-insensitively so checkpoints that emit
        # "ENTAILMENT" work the same as ones that emit "entailment".
        for item in results:
            if str(item['label']).lower() == 'entailment':
                return item['score']

        # An explicit error replaces the bare StopIteration that next() raised.
        found_labels = [item['label'] for item in results]
        raise ValueError(
            f"No 'entailment' label in pipeline output; labels found: {found_labels}"
        )

    def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
        """
        Analyze entailment scores for paraphrased sentences. If no sentence
        reaches the threshold, the threshold is lowered in steps of 0.1 until
        at least one sentence is selected or the threshold drops to 0 or below.

        Args:
            original_sentence: The original sentence (used as the premise).
            paraphrased_sentences: List of paraphrased sentences (hypotheses).
            threshold: Minimum score to select a sentence.

        Returns:
            tuple: ``(all_sentences, selected_sentences, discarded_sentences)``,
            each a dict mapping sentence -> entailment score. Selected and
            discarded are disjoint and together cover ``all_sentences``.
        """
        # Score every paraphrase exactly once; lowering the threshold does not
        # change the scores, so the model never needs to be re-run.
        all_sentences = {
            sentence: self.check_entailment(original_sentence, sentence)
            for sentence in paraphrased_sentences
        }

        # Nothing to select from: skip the threshold-lowering loop entirely.
        if not all_sentences:
            return {}, {}, {}

        while True:
            selected_sentences = {
                sentence: score
                for sentence, score in all_sentences.items()
                if score >= threshold
            }
            if selected_sentences:
                break

            # If no sentences are selected, lower the threshold
            print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold} to {threshold - 0.1}).")
            threshold -= 0.1
            if threshold <= 0:
                print("Threshold has reached 0. No sentences meet the criteria.")
                break

        # Derive the discarded set from the final selection so that a sentence
        # can never be reported as both selected and discarded.
        discarded_sentences = {
            sentence: score
            for sentence, score in all_sentences.items()
            if sentence not in selected_sentences
        }
        return all_sentences, selected_sentences, discarded_sentences
if __name__ == "__main__":
    # Demo entry point: load the config, build the analyzer and print the
    # entailment scores for a handful of paraphrases of one sentence.
    #
    # Config path resolution, in order of priority:
    #   1. a path given as the first command-line argument;
    #   2. the original hard-coded absolute path, if it exists (keeps the
    #      behaviour unchanged on the machine this script was written on);
    #   3. the config file located relative to this source file.
    hardcoded_config_path = '/home/ashhar21137/text_wm/scratch/utils/config/config.yaml'
    relative_config_path = os.path.join(os.path.dirname(__file__), '..', 'config', 'config.yaml')
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
    elif os.path.exists(hardcoded_config_path):
        config_path = hardcoded_config_path
    else:
        config_path = relative_config_path

    config = load_config(config_path)
    entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])

    all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
        "The weather is nice today",
        [
            "The climate is pleasant today",
            "It's a good day weather-wise",
            "Today, the weather is terrible",
            "What a beautiful day it is",
            "The sky is clear and the weather is perfect",
            "It's pouring rain outside today",
            "The weather isn't bad today",
            "A lovely day for outdoor activities"
        ],
        0.7
    )

    print("----------------------- All Sentences -----------------------")
    print(all_sentences)
    print("----------------------- Discarded Sentences -----------------------")
    print(discarded_sentences)
    print("----------------------- Selected Sentences -----------------------")
    print(selected_sentences)
|