Spaces:
Sleeping
Sleeping
timofeyzhdanovich
commited on
Commit
·
bba0e26
1
Parent(s):
ce274fb
init commit
Browse files- app.py +56 -0
- classes.tsv +39 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
+
import pandas as pd
|
4 |
+
from torch.nn import Softmax
|
5 |
+
from torch import sort
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache_resource  # caching
def load_model():
    """Load the fine-tuned sequence classifier and switch it to eval mode.

    Returns:
        The ``zhdantim/mydeberta-v3-small`` checkpoint loaded through
        ``AutoModelForSequenceClassification``, after ``.eval()`` was called.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(
        'zhdantim/mydeberta-v3-small'
    )
    # eval() hands back the module itself, so this is the same object.
    return classifier.eval()
|
11 |
+
|
12 |
+
|
13 |
+
@st.cache_resource  # caching
def load_id2classes():
    """Read ``classes.tsv`` and return its ``classes`` column as a dict.

    The file is parsed as tab-separated with its first column used as the
    index, so the result maps each index value to the matching class label.
    """
    table = pd.read_csv('classes.tsv', sep='\t', index_col=0)
    return table['classes'].to_dict()
|
16 |
+
|
17 |
+
|
18 |
+
@st.cache_resource
def load_tokenizer():
    """Load the tokenizer published with ``microsoft/deberta-v3-small``."""
    checkpoint = 'microsoft/deberta-v3-small'
    return AutoTokenizer.from_pretrained(checkpoint)
|
21 |
+
|
22 |
+
|
23 |
+
# Module-level resources shared by the rest of the script. Every loader is
# decorated with st.cache_resource.
model = load_model()  # sequence-classification model in eval mode
id2classes = load_id2classes()  # index value -> class label from classes.tsv
tokenizer = load_tokenizer()
|
26 |
+
|
27 |
+
|
28 |
+
def get_top_classes(text, threshold=0.95):
    """Return the most probable classes that together reach ``threshold``.

    The text is tokenized, passed through the module-level ``model``, and the
    class probabilities are sorted in descending order. The shortest prefix
    of that ranking whose cumulative probability is at least ``threshold``
    is returned.

    Args:
        text: Input string to classify.
        threshold: Cumulative probability the returned classes must reach.
            Defaults to 0.95, the value that used to be hard-coded.

    Returns:
        A tuple ``(labels, probs)`` where ``labels`` is a list of class names
        looked up in ``id2classes`` and ``probs`` is a 1-D tensor with the
        matching probabilities, both ordered from most to least probable.
    """
    # Local import keeps this block self-contained; the file-level imports
    # only bring in ``Softmax`` and ``sort``.
    from torch import no_grad

    encoded = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

    # Inference only: skip building the autograd graph instead of detaching
    # the logits afterwards.
    with no_grad():
        logits = model(**encoded).logits

    # Explicit class dimension; the implicit-dim form of Softmax is deprecated.
    probs = Softmax(dim=-1)(logits)
    probs_sorted, indices = sort(probs, dim=-1, descending=True)

    # Number of prefixes still below the threshold, plus one, is the smallest
    # prefix length that reaches it.
    cumulative = probs_sorted[0].cumsum(dim=0)
    k = int((cumulative < threshold).sum().item()) + 1
    # Clamp so an unreachable threshold (rounding, or threshold > 1) yields
    # all classes instead of looping forever as the manual while-loop did.
    k = min(k, probs_sorted.shape[-1])

    return [id2classes[idx.item()] for idx in indices[0, :k]], probs_sorted[0, :k]
|
39 |
+
|
40 |
+
st.title("Простой классификатор статей")

# Sentinel text pre-filled into both inputs; a field still holding it is
# treated as "not provided".
PLACEHOLDER = "Type Here ..."

title = st.text_input(label="Введите название статьи (обязательно)", value=PLACEHOLDER)
abstract = st.text_input(label="Введите abstract", value=PLACEHOLDER)


def _clean_field(raw):
    """Return the stripped input, or '' when it is blank or the placeholder."""
    value = raw.strip()
    return "" if value == PLACEHOLDER else value


if st.button('Submit'):
    # Use the text exactly as typed. The previous code called str.title(),
    # which re-capitalises every word and so altered the model input.
    title_text = _clean_field(title)
    abstract_text = _clean_field(abstract)

    if title_text:
        # The abstract is optional; append it on its own line when present.
        text = (title_text + '\n' + abstract_text) if abstract_text else title_text

        top_classes, probs = get_top_classes(text)

        for p, cls in zip(probs, top_classes):
            st.success(f'Статья относится к {cls} с вероятностью {p}')
    else:
        # Covers both an untouched placeholder and a blank title.
        st.error('Введите название статьи')
|
classes.tsv
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
classes
|
2 |
+
0 cmp-lg
|
3 |
+
1 math
|
4 |
+
2 comp-gas
|
5 |
+
3 acc-phys
|
6 |
+
4 patt-sol
|
7 |
+
5 funct-an
|
8 |
+
6 hep-th
|
9 |
+
7 ao-sci
|
10 |
+
8 mtrl-th
|
11 |
+
9 adap-org
|
12 |
+
10 bayes-an
|
13 |
+
11 alg-geom
|
14 |
+
12 nlin
|
15 |
+
13 chao-dyn
|
16 |
+
14 chem-ph
|
17 |
+
15 cond-mat
|
18 |
+
16 math-ph
|
19 |
+
17 eess
|
20 |
+
18 hep-ph
|
21 |
+
19 plasm-ph
|
22 |
+
20 dg-ga
|
23 |
+
21 stat
|
24 |
+
22 econ
|
25 |
+
23 nucl-th
|
26 |
+
24 q-alg
|
27 |
+
25 atom-ph
|
28 |
+
26 hep-ex
|
29 |
+
27 q-bio
|
30 |
+
28 cs
|
31 |
+
29 hep-lat
|
32 |
+
30 quant-ph
|
33 |
+
31 astro-ph
|
34 |
+
32 nucl-ex
|
35 |
+
33 q-fin
|
36 |
+
34 solv-int
|
37 |
+
35 physics
|
38 |
+
36 gr-qc
|
39 |
+
37 supr-con
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pandas
|
2 |
+
transformers
|
3 |
+
torch
|