Spaces:
Runtime error
Runtime error
Hong
commited on
Commit
ยท
f5a1b52
1
Parent(s):
6023314
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from platform import processor
|
2 |
+
import streamlit as st
|
3 |
+
from load_data import candidate_labels
|
4 |
+
import numpy as np
|
5 |
+
from load_data import *
|
6 |
+
import pickle
|
7 |
+
import torch
|
8 |
+
from BART_utils import get_taggs, compare_trans
|
9 |
+
|
10 |
+
st.title("Flitto Domain Tagger Demo V0.4.1")
|
11 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
12 |
+
if device == "cpu":
|
13 |
+
processor = "๐ฅ๏ธ"
|
14 |
+
else:
|
15 |
+
processor = "๐ฝ"
|
16 |
+
|
17 |
+
st.subheader("Running on {}".format(device + processor))
|
18 |
+
|
19 |
+
|
20 |
+
user_input = st.text_area(
|
21 |
+
"๐๋๋ฉ์ธ ํ๊ทธ๋ฅผ ์์ฑํ ๋ฌธ์ฅ์ ์
๋ ฅํ์ธ์", "This app uses Facebook Zero-shot NLI model. It uses a pre-trained NLI models as a ready-made zero-shot sequence classifiers"
|
22 |
+
)
|
23 |
+
|
24 |
+
thred = st.slider(
|
25 |
+
"๐ํ๊ทธ ์์ฑ thredhold ์ค์ . ๊ฒฐ๊ณผ๊ฐ ๋์ค์ง ์์๊ฒฝ์ฐ, threshold๋ฅผ 0์ ๊ฐ๊น๊ฒ ๋ฎ์ถ์ธ์!",
|
26 |
+
0.0,
|
27 |
+
1.0,
|
28 |
+
0.5,
|
29 |
+
step=0.01,
|
30 |
+
)
|
31 |
+
if thred:
|
32 |
+
st.write(thred, " ์ด์์ confidence level์ธ ํ๊ทธ๋ง ์์ฑํฉ๋๋ค.")
|
33 |
+
|
34 |
+
maximum = st.number_input("๐์ต๋ ํ๊ทธ ๊ฐฏ์ ์ค์ ", 0, 10, 5, step=1)
|
35 |
+
st.write("์ต๋ {} ๊ฐ์ ํ๊ทธ ์์ฑ".format(maximum))
|
36 |
+
|
37 |
+
check_source = st.checkbox("๐ท๏ธ์ฉ์ฒ / ์ถ์ฒ ํ๊ทธ ์์ฑ")
|
38 |
+
submit = st.button("๐ํด๋ฆญํด์ ํ๊ทธ ์์ฑ")
|
39 |
+
if submit:
|
40 |
+
|
41 |
+
with st.spinner("โํ๊ทธ๋ฅผ ์์ฑํ๋ ์ค์
๋๋ค..."):
|
42 |
+
result = get_taggs(user_input, candidate_labels, thred)
|
43 |
+
result = result[:maximum]
|
44 |
+
st.subheader("๐ํน์ ์ด๋ฐ ์ฃผ์ ์ ๋ฌธ์ฅ์ธ๊ฐ์? : ")
|
45 |
+
if len(result) == 0:
|
46 |
+
st.write("๐ข์ ๋ฐ..๊ฒฐ๊ณผ๊ฐ ์์ต๋๋ค. Threshold๋ฅผ ๋ฎ์ถฐ๋ณด์ธ์!")
|
47 |
+
for i in result:
|
48 |
+
st.write("โก๏ธ " + i[0], "{}%".format(int(i[1] * 100)))
|
49 |
+
|
50 |
+
if check_source:
|
51 |
+
with st.spinner("โ์ฌ์ฉ ๋ชฉ์ ํ๊ทธ ์์ฑ์ค..."):
|
52 |
+
source_result = get_taggs(user_input, source, thred=0)
|
53 |
+
st.subheader("๐ํน์ ์ด ์ฌ์ฉ๋ชฉ์ ์ ๋ฌธ์ฅ์ธ๊ฐ์? : ")
|
54 |
+
for i in source_result[:3]:
|
55 |
+
st.write("๐ท๏ธ " + i[0], "{}%".format(int(i[1] * 100)))
|