# Viewer header: file size 1,131 bytes; identifier 7472a45 (presumably a commit hash).
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Checkpoint name passed to both from_pretrained() calls below, so the
# tokenizer and the model always come from the same checkpoint.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
# Created once at module level; get_relevance_score() reuses these two
# objects on every call instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Function to compute relevance score
def get_relevance_score(query, paragraph):
    """Return the cross-encoder relevance score for a query/paragraph pair.

    The two texts are tokenized together as one sequence pair (with
    truncation and padding enabled) and passed through the module-level
    ``model``.

    Args:
        query: The search query text.
        paragraph: The document paragraph to score against the query.

    Returns:
        The model's logit for the pair as a Python float, rounded to four
        decimal places.
    """
    encoded_pair = tokenizer(
        query,
        paragraph,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        logits = model(**encoded_pair).logits

    # .item() only works on a single-element tensor, so this assumes the model
    # emits exactly one logit for the pair.
    raw_score = logits.squeeze().item()
    return round(raw_score, 4)
# Gradio interface
# Two free-text inputs (query, paragraph) are wired to get_relevance_score,
# and its numeric result is shown in a single Number output.
interface = gr.Interface(
    title="Cross-Encoder Relevance Scoring",
    description=(
        "Enter a query and a document paragraph to get a relevance score "
        "using the MS MARCO MiniLM L-12 v2 model."
    ),
    fn=get_relevance_score,
    inputs=[
        gr.Textbox(
            label="Query",
            placeholder="Enter your search query...",
        ),
        gr.Textbox(
            label="Document Paragraph",
            placeholder="Enter a paragraph to match...",
        ),
    ],
    outputs=gr.Number(label="Relevance Score"),
)
# Launch the Gradio app only when this file is run as a script, so importing
# the module does not start the UI.
if __name__ == "__main__":
    interface.launch()