Hong commited on
Commit
f5a1b52
ยท
1 Parent(s): 6023314

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from platform import processor
2
+ import streamlit as st
3
+ from load_data import candidate_labels
4
+ import numpy as np
5
+ from load_data import *
6
+ import pickle
7
+ import torch
8
+ from BART_utils import get_taggs, compare_trans
9
+
10
+ st.title("Flitto Domain Tagger Demo V0.4.1")
11
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
+ if device == "cpu":
13
+ processor = "๐Ÿ–ฅ๏ธ"
14
+ else:
15
+ processor = "๐Ÿ’ฝ"
16
+
17
+ st.subheader("Running on {}".format(device + processor))
18
+
19
+
20
+ user_input = st.text_area(
21
+ "๐Ÿ‘‡๋„๋ฉ”์ธ ํƒœ๊ทธ๋ฅผ ์ƒ์„ฑํ•  ๋ฌธ์žฅ์„ ์ž…๋ ฅํ•˜์„ธ์š”", "This app uses Facebook Zero-shot NLI model. It uses a pre-trained NLI models as a ready-made zero-shot sequence classifiers"
22
+ )
23
+
24
+ thred = st.slider(
25
+ "๐Ÿ‘‡ํƒœ๊ทธ ์ƒ์„ฑ thredhold ์„ค์ •. ๊ฒฐ๊ณผ๊ฐ€ ๋‚˜์˜ค์ง€ ์•Š์„๊ฒฝ์šฐ, threshold๋ฅผ 0์— ๊ฐ€๊น๊ฒŒ ๋‚ฎ์ถ”์„ธ์š”!",
26
+ 0.0,
27
+ 1.0,
28
+ 0.5,
29
+ step=0.01,
30
+ )
31
+ if thred:
32
+ st.write(thred, " ์ด์ƒ์˜ confidence level์ธ ํƒœ๊ทธ๋งŒ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.")
33
+
34
+ maximum = st.number_input("๐Ÿ‘‡์ตœ๋Œ€ ํƒœ๊ทธ ๊ฐฏ์ˆ˜ ์„ค์ •", 0, 10, 5, step=1)
35
+ st.write("์ตœ๋Œ€ {} ๊ฐœ์˜ ํƒœ๊ทธ ์ƒ์„ฑ".format(maximum))
36
+
37
+ check_source = st.checkbox("๐Ÿท๏ธ์šฉ์ฒ˜ / ์ถœ์ฒ˜ ํƒœ๊ทธ ์ƒ์„ฑ")
38
+ submit = st.button("๐Ÿ‘ˆํด๋ฆญํ•ด์„œ ํƒœ๊ทธ ์ƒ์„ฑ")
39
+ if submit:
40
+
41
+ with st.spinner("โŒ›ํƒœ๊ทธ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ์ค‘์ž…๋‹ˆ๋‹ค..."):
42
+ result = get_taggs(user_input, candidate_labels, thred)
43
+ result = result[:maximum]
44
+ st.subheader("๐Ÿ”ํ˜น์‹œ ์ด๋Ÿฐ ์ฃผ์ œ์˜ ๋ฌธ์žฅ์ธ๊ฐ€์š”? : ")
45
+ if len(result) == 0:
46
+ st.write("๐Ÿ˜ข์ €๋Ÿฐ..๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. Threshold๋ฅผ ๋‚ฎ์ถฐ๋ณด์„ธ์š”!")
47
+ for i in result:
48
+ st.write("โžก๏ธ " + i[0], "{}%".format(int(i[1] * 100)))
49
+
50
+ if check_source:
51
+ with st.spinner("โŒ›์‚ฌ์šฉ ๋ชฉ์  ํƒœ๊ทธ ์ƒ์„ฑ์ค‘..."):
52
+ source_result = get_taggs(user_input, source, thred=0)
53
+ st.subheader("๐Ÿ”ํ˜น์‹œ ์ด ์‚ฌ์šฉ๋ชฉ์ ์˜ ๋ฌธ์žฅ์ธ๊ฐ€์š”? : ")
54
+ for i in source_result[:3]:
55
+ st.write("๐Ÿท๏ธ " + i[0], "{}%".format(int(i[1] * 100)))