Author committed on
Commit 9a179e2 · 1 parent: 90432c0

first commit

Files changed (4)
  1. app.py +26 -0
  2. model.pt +3 -0
  3. model.py +65 -0
  4. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,26 @@
+ import streamlit as st
+ import torch
+ from model import (DistilBERTClassifier,  # imported so torch.load can unpickle the saved model
+                    load_model, load_tokenizer,
+                    get_preds, get_verdict)
+
+
+ st.markdown("## Hello, my dear friend!")
+ st.markdown("### This service can classify an article's topic from its title and abstract")
+ st.markdown("##### You can provide only the title or only the abstract (just leave the other field empty), but predictions may be less accurate in that case.")
+
+ title = st.text_area("Title:")
+ abstract = st.text_area("Abstract:")
+
+ device = 'cpu'
+ tokenizer = load_tokenizer()
+ model = load_model(device)
+
+ text = " ".join(filter(None, (title, abstract)))  # join with a space so the title and abstract don't run together
+ if not text:
+     verdict = "Both fields are empty"
+ else:
+     verdict = get_verdict(get_preds(text, model, tokenizer, device))
+
+ st.markdown("#### Verdict:")
+ st.markdown(verdict)
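
With the checkpoint fetched (see the Git LFS note under model.pt below) and dependencies installed, the app starts with the standard Streamlit entry point, streamlit run app.py. Streamlit re-executes the whole script on every widget interaction, which is why model.py caches the model and tokenizer with st.cache_resource.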
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f0883bdb976c57c2ca825375fd530c836f5d231ff323fb7f23f6fc14189db57a
+ size 263204889
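
Note that the diff above is a Git LFS pointer file, not the weights themselves: the actual checkpoint (~263 MB, per the size field) lives in LFS storage. A clone needs git-lfs installed (git lfs install once, then git lfs pull inside the repo) for model.pt to resolve to the real file that torch.load can read.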
model.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import streamlit as st
+ from transformers import DistilBertModel, DistilBertTokenizerFast
+
+
+ TARGET_IND2LABEL = {
+     0: 'Computer Science',
+     1: 'Economics',
+     2: 'Electrical Engineering and Systems Science',
+     3: 'Mathematics',
+     4: 'Physics',
+     5: 'Quantitative Biology',
+     6: 'Quantitative Finance',
+     7: 'Statistics',
+ }
+
+ class DistilBERTClassifier(nn.Module):
+     def __init__(self, num_classes=8):
+         super().__init__()
+         self.encoder = DistilBertModel.from_pretrained("distilbert-base-cased")
+         self.pre_classifier = nn.Linear(768, 768)
+         self.gelu = nn.GELU()
+         self.dropout = nn.Dropout(0.1)
+         self.classifier = nn.Linear(768, num_classes)
+
+     def forward(self, input_ids, attention_mask, labels=None):  # labels kept for training-loop compatibility; unused at inference
+         output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         hidden_state = output[0]
+         pooler = hidden_state[:, 0]  # embedding at the [CLS] position as the sequence representation
+         pooler = self.dropout(self.gelu(self.pre_classifier(pooler)))
+         preds = self.classifier(pooler)
+         return preds
+
+ @st.cache_resource
+ def load_tokenizer():
+     return DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
+
+ @st.cache_resource
+ def load_model(device):
+     model = torch.load('model.pt', map_location=torch.device('cpu')).to(device)
+     model.eval()
+     return model
+
+ def get_verdict(preds):
+     inds = np.argsort(preds)[::-1]  # class indices, most probable first
+     sum_prob = 0.0
+     verdict = []
+     for ind in inds:
+         prob = preds[ind]
+         sum_prob += prob
+         verdict.append(f"{TARGET_IND2LABEL[ind]}: {prob:.3f}")
+         if sum_prob >= 0.95:  # stop once 95% of the probability mass is covered
+             break
+     return "\n\n".join(verdict)
+
+ def get_preds(text, model, tokenizer, device):
+     tokens = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
+     tokens['input_ids'] = tokens['input_ids'].to(device)
+     tokens['attention_mask'] = tokens['attention_mask'].to(device)
+     tokens['labels'] = None  # kept for training convenience; forward() ignores it
+     with torch.no_grad():
+         preds = torch.softmax(model(**tokens)[0], dim=0).cpu().numpy()
+     return preds
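
For a quick smoke test outside the UI, the helpers above can be driven directly. A minimal sketch, assuming model.pt sits in the working directory (the st.cache_resource wrappers log a warning outside a Streamlit session but still execute); the example title is made up:

from model import load_tokenizer, load_model, get_preds, get_verdict

device = 'cpu'
tokenizer = load_tokenizer()
model = load_model(device)  # torch.load resolves DistilBERTClassifier via the model module import
text = "A transformer approach to abstractive summarization"  # hypothetical title
print(get_verdict(get_preds(text, model, tokenizer, device)))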
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ numpy
+ torch
+ transformers
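
Note that streamlit itself is not pinned here: if this runs as a Hugging Face Space with the Streamlit SDK, the runtime provides it, but a plain local run would also need pip install streamlit alongside pip install -r requirements.txt.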