X-encoder / app.py
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
# Load model and tokenizer
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
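# The cross-encoder scores a (query, paragraph) pair jointly and outputs a single
# relevance logit; a higher value means the paragraph is more relevant to the query.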
# Function to compute relevance score
def get_relevance_score(query, paragraph):
    # Encode the query and paragraph together as one input pair
    inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        # The model returns one logit per pair; squeeze it down to a scalar
        scores = model(**inputs).logits.squeeze().item()
    return round(scores, 4)
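# Illustrative extra (not part of the original app): a batched variant that scores
# several candidate paragraphs against one query in a single forward pass.
# A minimal sketch; the helper name and its use here are assumptions.
def get_relevance_scores(query, paragraphs):
    inputs = tokenizer([query] * len(paragraphs), paragraphs,
                       return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        # logits has shape (num_paragraphs, 1); flatten to a list of scores
        scores = model(**inputs).logits.squeeze(-1).tolist()
    return [round(s, 4) for s in scores]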
# Gradio interface
interface = gr.Interface(
    fn=get_relevance_score,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match...")
    ],
    outputs=gr.Number(label="Relevance Score"),
    title="Cross-Encoder Relevance Scoring",
    description="Enter a query and a document paragraph to get a relevance score using the MS MARCO MiniLM L-12 v2 model."
)
if __name__ == "__main__":
    interface.launch()
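# Quick local check (a sketch, not part of the Space itself): because launch() is
# guarded by the __main__ check above, the scoring function can be imported and
# called directly without starting the UI, e.g.
#   from app import get_relevance_score
#   get_relevance_score("what is the capital of France", "Paris is the capital of France.")
# A higher returned logit indicates a more relevant paragraph for the query.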