Spaces:
Sleeping
Sleeping
File size: 1,863 Bytes
bba0e26 283aad9 bba0e26 283aad9 bba0e26 283aad9 bba0e26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import pandas as pd
import streamlit as st
import torch
from torch import sort
from torch.nn import Softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification
@st.cache_resource  # cached by Streamlit: the same object is reused across reruns
def load_model():
    """Load the fine-tuned ``zhdantim/mydeberta-v3-small`` classifier in eval mode."""
    classifier = AutoModelForSequenceClassification.from_pretrained(
        'zhdantim/mydeberta-v3-small'
    )
    # eval() switches the module to inference mode in place (and returns it).
    classifier.eval()
    return classifier
@st.cache_resource  # cached by Streamlit: the file is parsed once and reused across reruns
def load_id2classes():
    """Read ``classes.tsv`` and return its ``classes`` column as a dict.

    The file's first column becomes the dict key (``index_col=0``); presumably
    these are the integer class ids produced by the model — TODO confirm.
    """
    table = pd.read_csv('classes.tsv', sep='\t', index_col=0)
    per_column = table.to_dict()
    return per_column['classes']
@st.cache_resource
def load_tokenizer():
    """Load the tokenizer of the base ``microsoft/deberta-v3-small`` checkpoint."""
    base_checkpoint = 'microsoft/deberta-v3-small'
    return AutoTokenizer.from_pretrained(base_checkpoint)
# Runs on every Streamlit script run; the cache_resource decorators on the
# loaders make repeated runs return the same objects instead of reloading.
model, id2classes, tokenizer = load_model(), load_id2classes(), load_tokenizer()
def get_top_classes(text, threshold=0.95):
    """Return the smallest set of top classes whose probabilities reach *threshold*.

    The text is tokenized with the module-level ``tokenizer`` and scored by the
    module-level ``model``. Class probabilities are sorted in descending order,
    and the shortest prefix whose cumulative probability is >= ``threshold`` is
    kept. At least one class is always returned.

    Args:
        text: Article text (title, optionally followed by the abstract).
        threshold: Cumulative-probability cutoff; defaults to 0.95.

    Returns:
        A pair ``(classes, probs)``:
            - ``classes``: a list of class names looked up in ``id2classes``.
            - ``probs``: a 1-D tensor of the matching probabilities.
        Both are in descending order of probability.
    """
    tokenized_text = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    # Inference only: don't build an autograd graph that would be discarded.
    with torch.no_grad():
        logits = model(**tokenized_text).logits
    # logits are indexed as [0, ...] below, i.e. (batch, num_classes).
    # An explicit dim avoids PyTorch's implicit-dimension deprecation warning.
    probs = Softmax(dim=-1)(logits)
    probs_sorted, indices = sort(probs, descending=True)
    row = probs_sorted[0]
    # Probabilities are non-negative, so the cumulative sum never decreases.
    # Counting entries still below the threshold gives the prefix that has not
    # reached it yet; +1 includes the class that crosses it. The clamp keeps k
    # in range even if float round-off leaves the total just below threshold.
    below = int((row.cumsum(dim=0) < threshold).sum().item())
    k = min(below + 1, row.numel())
    return [id2classes[idx.item()] for idx in indices[0, :k]], row[:k]
# Text pre-filled into both inputs. A field that still holds it (compared
# case-insensitively) or is blank is treated as "not provided".
PLACEHOLDER = 'Type Here ...'


def _user_text(raw):
    """Return *raw* stripped of surrounding whitespace, or '' if blank or still the placeholder."""
    text = raw.strip()
    if text.casefold() == PLACEHOLDER.casefold():
        return ''
    return text


st.title("Простой классификатор статей")
title = st.text_input(label="Введите название статьи (обязательно)", value=PLACEHOLDER)
abstract = st.text_input(label="Введите abstract", value=PLACEHOLDER)
if st.button('Submit'):
    title_text = _user_text(title)
    abstract_text = _user_text(abstract)
    if title_text:
        # Pass the user's text unchanged. Calling str.title() here would
        # re-capitalise every word (e.g. "NLP" -> "Nlp") and distort the model input.
        text = '\n'.join(part for part in (title_text, abstract_text) if part)
        top_classes, probs = get_top_classes(text)
        for p, cls in zip(probs, top_classes):
            # float() turns the 0-d tensor into a number; otherwise the message
            # would show "tensor(...)".
            st.success(f'Статья относится к {cls} с вероятностью {float(p):.4f}')
    else:
        st.error('Введите название статьи')
|