compare_docs / predict.py
gnaw05's picture
init
d9acf37
raw
history blame contribute delete
843 Bytes
from transformers import pipeline
# Tải model sẵn để khỏi load nhiều lần
qa_pipeline = pipeline(
"question-answering",
model="marshmellow77/roberta-base-cuad",
tokenizer="marshmellow77/roberta-base-cuad"
)
def run_prediction(questions, context, model_name=None, n_best_size=5):
"""
- questions: list các câu hỏi (ví dụ ['What is the payment term?'])
- context: đoạn văn bản (hợp đồng) để tìm câu trả lời
- model_name: không cần, để giữ nguyên cho tương thích
- n_best_size: không cần, giữ nguyên để gọi
"""
answers = {}
for idx, question in enumerate(questions):
result = qa_pipeline({
'context': context,
'question': question
})
answers[str(idx)] = result['answer']
return answers