# Streamlit app: arXiv subject-area classifier built on a fine-tuned DistilBERT.
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
from safetensors.torch import load_file as safe_load
# arXiv subject-area tag -> integer class index expected by the classifier head.
target_to_ind = {'cs': 0, 'econ': 1, 'eess': 2, 'math': 3, 'phys': 4, 'q-bio': 5, 'q-fin': 6, 'stat': 7}
# Tag -> human-readable label shown in the UI.
target_to_label = {'cs': 'Computer Science', 'econ': 'Economics', 'eess': 'Electrical Engineering and Systems Science', 'math': 'Mathematics', 'phys': 'Physics',
'q-bio': 'Quantitative Biology', 'q-fin': 'Quantitative Finance', 'stat': 'Statistics'}
# Inverse mapping: class index -> tag, used to decode model logits.
ind_to_target = {ind: target for target, ind in target_to_ind.items()}
@st.cache_resource
def load_model_and_tokenizer():
    """Build the DistilBERT classifier head and tokenizer, then load fine-tuned weights.

    The classification head is sized to the number of subject-area classes, and the
    fine-tuned parameters are read from a local safetensors file. Streamlit's
    ``st.cache_resource`` ensures this runs once per server process.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    base_checkpoint = 'distilbert/distilbert-base-cased'
    clf = AutoModelForSequenceClassification.from_pretrained(
        base_checkpoint, num_labels=len(target_to_ind)
    )
    # Overwrite the freshly initialised weights with the fine-tuned ones.
    clf.load_state_dict(safe_load("model.safetensors"))
    tok = AutoTokenizer.from_pretrained(base_checkpoint)
    return clf, tok
# Loaded once at import time; cached by Streamlit across app reruns.
model, tokenizer = load_model_and_tokenizer()
def get_predict(title: str, abstract: str) -> "list[tuple[float, str]]":
    """Classify a paper by title and abstract.

    The title and (truncated) abstract are joined with the tokenizer's separator
    token and fed through the fine-tuned classifier.

    Args:
        title: Paper title.
        abstract: Paper abstract; only the first 128 characters are used
            (NOTE(review): this is a character cut, not a token cut — the
            tokenizer truncates again anyway).

    Returns:
        All (probability, tag) pairs sorted by probability, highest first,
        where tag is a key of ``target_to_ind`` (e.g. 'cs', 'math').
    """
    # Original annotation claimed `(str, float, dict)`, which never matched the
    # actual return value; fixed to the real shape.
    text = [title + tokenizer.sep_token + abstract[:128]]
    tokens_info = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        out = model(**tokens_info)
    probs = torch.nn.functional.softmax(out.logits, dim=-1)[0].tolist()
    # sorted(..., reverse=True) replaces the old list(sorted(...))[::-1] dance.
    return sorted(((p, ind_to_target[i]) for i, p in enumerate(probs)), reverse=True)
# --- UI ---------------------------------------------------------------------
title = st.text_area("Title ", "", height=100)
abstract = st.text_area("Abstract ", "", height=150)
mode = st.radio("Mode: ", ("Best prediction", "Top 95%"))

if st.button("Get prediction", key="manual"):
    if not title:
        st.error("Please, provide paper's title")
    else:
        with st.spinner("Be patient, I'm doing my best"):
            predict = get_predict(title, abstract)
        # Greedily accumulate the most probable classes until the cumulative
        # probability clears the threshold; threshold 0 keeps only the top class.
        # BUG FIX: the original read an undefined variable `status` here
        # (NameError on every click) — the radio widget's value is `mode`.
        threshold = 0 if mode == "Best prediction" else 0.95
        tags = []
        sum_p = 0
        for p, tag in predict:
            sum_p += p
            tags.append(target_to_label[tag])
            if sum_p >= threshold:
                break
        # BUG FIX: `st.succes` (typo) would raise AttributeError.
        st.success('\n'.join(tags))