|
import torch
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
|
# Cross-encoder reranker checkpoint trained on MS MARCO: it scores a
# (query, passage) pair jointly rather than embedding each text separately.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Loaded once at module import so every request reuses the same weights.
# NOTE(review): from_pretrained returns the model in eval mode, so no
# explicit model.eval() is needed here.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
|
|
|
def get_relevance_score(query: str, paragraph: str) -> float:
    """Score how relevant *paragraph* is to *query* using the cross-encoder.

    Args:
        query: The search query text.
        paragraph: A candidate document passage to score against the query.

    Returns:
        The model's raw relevance logit rounded to 4 decimal places.
        Higher means more relevant; the scale is unbounded (not a
        probability).
    """
    # Tokenize the pair jointly — a cross-encoder attends over both texts
    # at once. truncation=True guards against inputs longer than the
    # model's maximum sequence length.
    inputs = tokenizer(
        query, paragraph, return_tensors="pt", truncation=True, padding=True
    )
    # Inference only: disable gradient tracking to save memory and time.
    with torch.no_grad():
        score = model(**inputs).logits.squeeze().item()
    return round(score, 4)
|
|
|
|
|
# Minimal Gradio UI: two free-text inputs (query + passage) mapped to a
# single numeric relevance score.
interface = gr.Interface(
    fn=get_relevance_score,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
    ],
    outputs=gr.Number(label="Relevance Score"),
    title="Cross-Encoder Relevance Scoring",
    description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model.",
)
|
|
|
# Start the web server only when run as a script, so the module can be
# imported (e.g. for testing) without launching the UI.
if __name__ == "__main__":
    interface.launch()