ivanlau commited on
Commit
f0726f1
·
1 Parent(s): a3858c0

added app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import neattext.functions as nfx
3
+ import re
4
+ import torch
5
+ import streamlit as st
6
+
7
+ # labels
8
+ labels = [
9
+ 'bug',
10
+ 'enhancement',
11
+ 'question'
12
+ ]
13
+
14
+ # Model path
15
+ # LOCAL
16
+ # MODEL_DIR = "./model/distil-bert-uncased-finetuned-github-issues/"
17
+
18
+ # REMOTE
19
+ MODEL_DIR = "ivanlau/distil-bert-uncased-finetuned-github-issues"
20
+
21
+
22
+ @st.cache(allow_output_mutation=True, show_spinner=False)
23
+ def load_model():
24
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
25
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
26
+ return model, tokenizer
27
+
28
+ # Helpers
29
+ reg_obj = re.compile(r'[^\u0000-\u007F]+', re.UNICODE)
30
+ def is_english_text(text):
31
+ return (False if reg_obj.match(text) else True)
32
+
33
+ # remove the stopwords, emojis from the text and convert it into lower case
34
+ def neatify_text(text):
35
+ text = str(text).lower()
36
+ text = nfx.remove_stopwords(text)
37
+ text = nfx.remove_emojis(text)
38
+ return text
39
+
40
+
41
+
42
+ def main():
43
+ # st UI setting
44
+ st.set_page_config(
45
+ page_title="IntelliLabel",
46
+ page_icon="🏷",
47
+ layout="centered",
48
+ initial_sidebar_state="auto",
49
+ )
50
+ st.title("IntelliLabel")
51
+ st.write("IntelliLabel is a github issue classification app. It classifies issue into 3 categories (Bug, Enhancement, Question).")
52
+
53
+ # load model
54
+ with st.spinner("Downloading model (takes ~1 min)"):
55
+ model, tokenizer = load_model()
56
+
57
+
58
+
59
+ default_text = "Unable to run Speech2Text example in documentation"
60
+
61
+ text = st.text_area('Enter text here:', value=default_text)
62
+ submit = st.button('Predict 🏷')
63
+
64
+
65
+ if submit:
66
+ text = text.strip(" \n\t")
67
+ if is_english_text(text):
68
+ text = neatify_text(text)
69
+ tokenized_sentence = tokenizer(text, return_tensors='pt')
70
+ output = model(**tokenized_sentence)
71
+ predictions = torch.nn.functional.softmax(output.logits, dim=-1)
72
+ _, preds = torch.max(predictions, dim=-1)
73
+ predicted = labels[preds.item()]
74
+
75
+ predictions = predictions.tolist()[0]
76
+ c1, c2, c3 = st.columns(3)
77
+ c1.metric(label="Bug", value=round(predictions[0],3))
78
+ c2.metric(label="Enhancement", value=round(predictions[1],3))
79
+ c3.metric(label="Question", value=round(predictions[2],3))
80
+
81
+ st.info("Prediction")
82
+ st.write(predicted.capitalize())
83
+
84
+ else:
85
+ st.error(str("Please input english text."))
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()