Inosen-Infinity committed on
Commit 6677015 · verified · 1 Parent(s): 12ae8a3

Create app.py

Files changed (1)
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
+ import streamlit as st
+ import torch
+ import torch.nn.functional as F
+ import pandas as pd
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from datasets import load_dataset
+
+
+ device = 'cpu'
+
+ @st.cache_resource
+ def get_model_and_tokenizer():
+     # Base RoBERTa encoder with a classification head sized to the arXiv category labels.
+     model_name = "FacebookAI/roberta-base"
+     num_labels = 157
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
+
+     # Load the fine-tuned weights stored alongside this app.
+     chkp = torch.load("arxiv_roberta_9unfrozen_5scaleloss.pt", map_location=device)
+     model.load_state_dict(chkp['model'])
+
+     return model, tokenizer
+
+ @st.cache_data
+ def get_categories():
+     # arXiv category tags and human-readable names used to decode predictions.
+     categories = load_dataset("TimSchopf/arxiv_categories", "arxiv_category_descriptions")
+
+     cat2id = dict((cat, id) for id, cat in enumerate(categories['arxiv_category_descriptions']['tag']))
+     id2cat = categories['arxiv_category_descriptions']['tag']
+     names = categories['arxiv_category_descriptions']['name']
+
+     return cat2id, id2cat, names
+
+ model, tokenizer = get_model_and_tokenizer()
+ cat2id, id2cat, cat_names = get_categories()
+
+ @torch.no_grad()
+ def predict_and_decode(model, title='', abstract=''):
+     model.eval()
+
+     inputs = tokenizer(title, abstract, return_tensors='pt', truncation=True, max_length=512).to(device)
+     logits = model(**inputs)['logits'][0].cpu()
+
+     # Multi-label decoding: an independent sigmoid probability per category.
+     df = pd.DataFrame([
+         (id2cat[cat_id], cat_names[cat_id], prob.item())
+         for cat_id, prob in enumerate(F.sigmoid(logits))
+     ], columns=("tag", "name", "probability"))
+     df.sort_values(by="probability", ascending=False, inplace=True)
+
+     return df.reset_index(drop=True)
+
+ st.header("Paper Category Classifier")
+ st.text("Input the title and/or abstract of a scientific paper to get a classification into arxiv.org categories")
+
+ title_default = "Attention Is All You Need"
+ abstract_default = (
+     "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks "
+     "in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through "
+     "an attention mechanism. We propose a new simple network architecture, the Transformer..."
+ )
+
+ line_height = 34
+ n_lines = 10
+ title = st.text_input("Paper title", value=title_default, help="Type in the paper's title")
+ abstract = st.text_area("Paper abstract", value=abstract_default, height=line_height*n_lines, help="Type in the paper's abstract")
+
+ result = predict_and_decode(model, title=title, abstract=abstract)
+
+ cnt = st.container(border=True)
+ with cnt:
+     st.markdown("#### Top category")
+     st.markdown(f"**{result.tag[0]}** -- {result.name[0]}")
+     st.markdown(f"Probability: {result.probability[0]*100:.2f}%")
+
+ # Show up to five additional categories whose probability clears the threshold.
+ threshold = 0.55
+ st.text("Other top categories:")
+ max_len = min(max(1, sum(result.iloc[1:].probability > threshold)), 5)
+
+ def format_p(example):
+     example.probability = f"{example.probability * 100 :.2f}%"
+     return example
+ st.table(result.iloc[1:1 + max_len].apply(format_p, axis=1))
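A quick illustration of the decoding step above (a minimal sketch, not part of the commit; the logit values are made up): the head is treated as multi-label, so each category gets an independent sigmoid probability rather than a softmax share, and any category clearing the app's 0.55 threshold can be reported alongside the top hit.

import torch

logits = torch.tensor([2.0, -1.0, 0.5])            # dummy scores for three hypothetical categories
probs = torch.sigmoid(logits)                       # independent per-category probabilities
print(probs)                                        # tensor([0.8808, 0.2689, 0.6225]) -- need not sum to 1
print((probs > 0.55).nonzero(as_tuple=True)[0])     # tensor([0, 2]): indices clearing the 0.55 threshold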