File size: 843 Bytes
d9acf37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import pipeline

# Tải model sẵn để khỏi load nhiều lần
qa_pipeline = pipeline(
    "question-answering",
    model="marshmellow77/roberta-base-cuad",
    tokenizer="marshmellow77/roberta-base-cuad"
)

def run_prediction(questions, context, model_name=None, n_best_size=5):
    """
    - questions: list các câu hỏi (ví dụ ['What is the payment term?'])
    - context: đoạn văn bản (hợp đồng) để tìm câu trả lời
    - model_name: không cần, để giữ nguyên cho tương thích
    - n_best_size: không cần, giữ nguyên để gọi
    """
    answers = {}
    for idx, question in enumerate(questions):
        result = qa_pipeline({
            'context': context,
            'question': question
        })
        answers[str(idx)] = result['answer']
    return answers