Spaces:
Running
Running
Upload 2 files
Browse files
app.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
5 |
+
# from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
|
6 |
+
|
7 |
+
|
8 |
+
# UI header and service description (Russian-language UI text): explains that
# the classifier picks an article topic from its title + abstract, is backed
# by a fine-tuned DistilBERT, links the Kaggle arXiv training dataset, and
# notes that only English input is supported.
st.markdown('## Классификатор статей')
st.write('Данный сервис предназначен для выбора темы статьи, \n' \
'основываясь на ее названии и краткой выжимкой текста статьи. \n' \
'Сервис работает благодаря fine-tune версии модели distil bert. \n' \
'Данные для обучения были взяты [отсюда](https://www.kaggle.com/datasets/neelshah18/arxivdataset). \n' \
'Поддерживается ввод только английского языка.')
st.markdown('#### Введите название статьи и ее краткое содержание:')

# Module-level compute device: first CUDA GPU when available, else CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
17 |
+
|
18 |
+
def create_model_and_optimizer(model, lr=1e-4, beta1=0.9, beta2=0.999, device=None):
    """Move *model* to *device* and build an Adam optimizer over its trainable parameters.

    Args:
        model: torch.nn.Module to prepare.
        lr: Adam learning rate.
        beta1, beta2: Adam exponential-decay coefficients.
        device: target torch.device; when None, resolved at CALL time to the
            first CUDA GPU if available, else CPU.  (The original bound the
            module-level ``device`` at ``def`` time, which freezes the choice
            at import and breaks reuse of the function in isolation.)

    Returns:
        (model, optimizer): the moved model and its Adam optimizer.
    """
    if device is None:
        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)

    # Only trainable parameters are handed to the optimizer, so layers frozen
    # with requires_grad=False stay untouched during fine-tuning.
    params = [p for p in model.parameters() if p.requires_grad]

    # betas is documented as a tuple in the torch.optim.Adam API.
    optimizer = torch.optim.Adam(params, lr, (beta1, beta2))
    return model, optimizer
|
28 |
+
|
29 |
+
def title_summary_transform(items, tokenizer):
    """Tokenize a title/summary pair as one '[SEP]'-joined sequence.

    *items* must provide 'title' and 'summary' keys; the tokenizer is invoked
    with max-length padding and truncation enabled.
    """
    joined_text = items['title'] + '[SEP]' + items['summary']
    encoded = tokenizer(joined_text, padding="max_length", truncation=True)
    return encoded
|
35 |
+
|
36 |
+
def predict_category(case, model, tokenizer):
    """Predict the most likely categories for one article.

    Tokenizes *case* (dict with 'title'/'summary'), runs the classifier,
    sorts the class probabilities in descending order, and keeps classes
    until their cumulative probability exceeds 0.95.

    Returns:
        (pred, pred_prob): parallel lists of class indices and probabilities.
    """
    encoded = title_summary_transform(case, tokenizer)
    # Add a batch dimension and move every tensor to the module-level device.
    input_ = {name: torch.tensor(values).unsqueeze(0).to(device)
              for name, values in encoded.items()}

    with torch.no_grad():
        logits = model(**input_).logits[0]
        sorted_probs, sorted_indices = torch.sort(
            torch.nn.functional.softmax(logits, dim=-1), descending=True
        )

    pred = []
    pred_prob = []
    total = 0  # tensor accumulation, same as the original sum_prob
    for rank, prob in enumerate(sorted_probs):
        pred.append(sorted_indices[rank].item())
        pred_prob.append(prob.item())
        total = total + prob
        if total > 0.95:
            break
    return pred, pred_prob
|
56 |
+
|
57 |
+
@st.cache_resource  # cache across Streamlit reruns so the model is loaded only once
def load_model():
    """Load the fine-tuned DistilBERT article classifier from ./model.pt.

    Returns:
        (model, tokenizer, ind_to_cat): the classifier in eval mode on the
        current device, its tokenizer, and the index -> category-name map
        stored in the checkpoint.
    """
    chkp_folder = '.'
    model_name = 'model'
    cat_count = 358  # number of category labels the classification head was trained on

    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # NOTE(security): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.  map_location keeps a
    # GPU-saved checkpoint loadable on a CPU-only host.
    checkpoint = torch.load(
        os.path.join(chkp_folder, f"{model_name}.pt"),
        map_location=device,
        weights_only=False,
    )

    # Rebuild the same architecture that was fine-tuned: a DistilBERT body
    # with a fresh classification head sized to cat_count labels.
    model = AutoModelForSequenceClassification.from_pretrained(
        'distilbert/distilbert-base-cased', num_labels=cat_count
    ).to(device)

    # Restore the fine-tuned weights.  No optimizer (and no requires_grad
    # bookkeeping) is needed for inference, so none is created here.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # disable dropout so predictions are deterministic

    ind_to_cat = checkpoint['ind_to_cat']
    tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-cased')
    return model, tokenizer, ind_to_cat
|
84 |
+
|
85 |
+
# --- Streamlit entry point: gather input and show predictions ---------------
model, tokenizer, ind_to_cat = load_model()


# Collect the two text fields into the dict shape predict_category expects.
case_ = {}

case_['title'] = st.text_area("Название статьи:", value="")
case_['summary'] = st.text_area("Краткое содержание:", value="")


# Predict as soon as either field is non-empty.
if case_['title'] or case_['summary']:
    categories, probabilities = predict_category(case_, model, tokenizer)
    st.write('Возможные категории:')
    # categories are sorted by descending probability (probabilities is the
    # parallel list); only the human-readable names are displayed.
    for cat in categories:
        st.write(f'{ind_to_cat[cat]}')
|
model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:551d83cc776a221a789ac6e5be212d5a9dc708e6a41f3ec164b8099fae575000
|
3 |
+
size 384644922
|