Karpernik commited on
Commit
45164d7
·
verified ·
1 Parent(s): cf15218

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +98 -0
  2. model.pt +3 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ # from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
6
+
7
+
8
+ st.markdown('## Классификатор статей')
9
+ st.write('Данный сервис предназначен для выбора темы статьи, \n' \
10
+ 'основываясь на ее названии и краткой выжимкой текста статьи. \n' \
11
+ 'Сервис работает благодаря fine-tune версии модели distil bert. \n' \
12
+ 'Данные для обучения были взяты [отсюда](https://www.kaggle.com/datasets/neelshah18/arxivdataset). \n' \
13
+ 'Поддерживается ввод только английского языка.')
14
+ st.markdown('#### Введите название статьи и ее краткое содержание:')
15
+
16
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
17
+
18
+ def create_model_and_optimizer(model, lr=1e-4, beta1=0.9, beta2=0.999, device=device):
19
+ model = model.to(device)
20
+
21
+ params = []
22
+ for param in model.parameters():
23
+ if param.requires_grad:
24
+ params.append(param)
25
+
26
+ optimizer = torch.optim.Adam(params, lr, [beta1, beta2])
27
+ return model, optimizer
28
+
29
+ def title_summary_transform(items, tokenizer):
30
+ return tokenizer(
31
+ items['title'] + '[SEP]' + items['summary'],
32
+ padding="max_length",
33
+ truncation=True
34
+ )
35
+
36
+ def predict_category(case, model, tokenizer):
37
+ input_ = {key: torch.tensor(val).unsqueeze(0).to(device) for key, val in title_summary_transform(case, tokenizer).items()}
38
+
39
+ pred = []
40
+ pred_prob = []
41
+
42
+ with torch.no_grad():
43
+ logits = model(**input_).logits[0]
44
+ probs = torch.nn.functional.softmax(logits, dim=-1)
45
+ probs, indices = torch.sort(probs, descending=True)
46
+
47
+ sum_prob = 0
48
+ for i, prob_ in enumerate(probs):
49
+ pred.append(indices[i].item())
50
+ pred_prob.append(prob_.item())
51
+ sum_prob += prob_
52
+
53
+ if sum_prob > 0.95:
54
+ break
55
+ return pred, pred_prob
56
+
57
+ @st.cache_resource # кэширование
58
+ def load_model():
59
+ chkp_folder = '.'
60
+ model_name = 'model'
61
+ cat_count = 358
62
+
63
+ checkpoint = torch.load(os.path.join(chkp_folder, f"{model_name}.pt"), weights_only=False)
64
+
65
+ # Создаём те же классы, что и внутри чекпоинта
66
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
67
+
68
+ model_ = AutoModelForSequenceClassification.from_pretrained('distilbert/distilbert-base-cased', num_labels=cat_count).to(device)
69
+
70
+ for param in model_.distilbert.parameters():
71
+ param.requires_grad = False
72
+
73
+ for i in range(4, 6):
74
+ for param in model_.distilbert.transformer.layer[i].parameters():
75
+ param.requires_grad = True
76
+
77
+ model, optimizer = create_model_and_optimizer(model_)
78
+
79
+ # Загружаем состояния из чекпоинта
80
+ model.load_state_dict(checkpoint['model_state_dict'])
81
+ ind_to_cat = checkpoint['ind_to_cat']
82
+ tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-cased')
83
+ return model, tokenizer, ind_to_cat
84
+
85
+ model, tokenizer, ind_to_cat = load_model()
86
+
87
+
88
+ case_ = {}
89
+
90
+ case_['title'] = st.text_area("Название статьи:", value="")
91
+ case_['summary'] = st.text_area("Краткое содержание:", value="")
92
+
93
+
94
+ if case_['title'] or case_['summary']:
95
+ categories, probabilities = predict_category(case_, model, tokenizer)
96
+ st.write('Возможные категории:')
97
+ for i, cat in enumerate(categories):
98
+ st.write(f'{ind_to_cat[cat]}')
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:551d83cc776a221a789ac6e5be212d5a9dc708e6a41f3ec164b8099fae575000
3
+ size 384644922