Spaces:
Runtime error
Runtime error
start
Browse files- app.py +62 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from generate import generate
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
# Batch modification
|
7 |
+
with st.form(key='my_form'):
|
8 |
+
print("Loading model...")
|
9 |
+
|
10 |
+
|
11 |
+
print("Model is ready to serve...")
|
12 |
+
desc = "Vietnamese News Generative Model - finetuned GPT2"
|
13 |
+
|
14 |
+
st.title('Vietnamese News Generative Model!')
|
15 |
+
st.write(desc)
|
16 |
+
st.write("##")
|
17 |
+
option = st.selectbox(
|
18 |
+
'Choose a category',
|
19 |
+
('thời sự ', 'thế giới', 'tài chính kinh doanh',
|
20 |
+
'đời sống', 'văn hoá', 'giải trí', 'giới trẻ', 'giáo dục',
|
21 |
+
'công nghệ', 'sức khoẻ'))
|
22 |
+
|
23 |
+
st.write("##")
|
24 |
+
category = str(option)
|
25 |
+
headline = st.text_input('Headline (or part of the headline)')
|
26 |
+
num_return_sequences = st.slider('Number of return sequences', min_value = 1, max_value = 5, value = 2)
|
27 |
+
max_len = st.slider('Max Length', min_value = 80, max_value = 500, value = 300)
|
28 |
+
with st.expander("Setting parameters"):
|
29 |
+
min_len = st.slider('Min Length', min_value = 0, max_value = 50, value = 50)
|
30 |
+
top_k = st.slider('Top k', min_value = 30, max_value = 200, value = 50)
|
31 |
+
top_p = st.slider('Top p', min_value = 0.0, max_value = 1.0, value = 0.8)
|
32 |
+
num_beams = st.slider('Num Beams', min_value = 1, max_value = 6, value = 2)
|
33 |
+
|
34 |
+
|
35 |
+
submit_button = st.form_submit_button(label='Generate', )
|
36 |
+
|
37 |
+
if submit_button:
|
38 |
+
print("Generating output")
|
39 |
+
with st.spinner('Wait for it...'):
|
40 |
+
outputs = generate(category = str(category), headline = str(headline), min_len = min_len, max_len = max_len, num_return_sequences = num_return_sequences)
|
41 |
+
|
42 |
+
for i, output in enumerate(outputs):
|
43 |
+
# Cut start of text
|
44 |
+
temp = output.split("<|startoftext|>")[1].strip()
|
45 |
+
|
46 |
+
temp = temp.split("<|headline|>")
|
47 |
+
category = temp[0]
|
48 |
+
|
49 |
+
temp = temp[1].split("<|content|>")
|
50 |
+
headline = temp[0].strip()
|
51 |
+
content = temp[1].strip()
|
52 |
+
|
53 |
+
st.header(f"Output: {i}")
|
54 |
+
st.subheader("Category")
|
55 |
+
st.write(category)
|
56 |
+
st.subheader(f"Headline")
|
57 |
+
st.write(headline)
|
58 |
+
st.subheader(f"Content")
|
59 |
+
st.write(content)
|
60 |
+
st.write("##")
|
61 |
+
|
62 |
+
st.balloons()
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.16.2
|
2 |
+
torch==1.10.0
|
3 |
+
streamlit==1.5.1
|