tuanle commited on
Commit
5e92cc6
1 Parent(s): d1f0a49
Files changed (2) hide show
  1. app.py +62 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from generate import generate
3
+
4
+
5
+ if __name__ == "__main__":
6
+ # Batch modification
7
+ with st.form(key='my_form'):
8
+ print("Loading model...")
9
+
10
+
11
+ print("Model is ready to serve...")
12
+ desc = "Vietnamese News Generative Model - finetuned GPT2"
13
+
14
+ st.title('Vietnamese News Generative Model!')
15
+ st.write(desc)
16
+ st.write("##")
17
+ option = st.selectbox(
18
+ 'Choose a category',
19
+ ('thời sự ', 'thế giới', 'tài chính kinh doanh',
20
+ 'đời sống', 'văn hoá', 'giải trí', 'giới trẻ', 'giáo dục',
21
+ 'công nghệ', 'sức khoẻ'))
22
+
23
+ st.write("##")
24
+ category = str(option)
25
+ headline = st.text_input('Headline (or part of the headline)')
26
+ num_return_sequences = st.slider('Number of return sequences', min_value = 1, max_value = 5, value = 2)
27
+ max_len = st.slider('Max Length', min_value = 80, max_value = 500, value = 300)
28
+ with st.expander("Setting parameters"):
29
+ min_len = st.slider('Min Length', min_value = 0, max_value = 50, value = 50)
30
+ top_k = st.slider('Top k', min_value = 30, max_value = 200, value = 50)
31
+ top_p = st.slider('Top p', min_value = 0.0, max_value = 1.0, value = 0.8)
32
+ num_beams = st.slider('Num Beams', min_value = 1, max_value = 6, value = 2)
33
+
34
+
35
+ submit_button = st.form_submit_button(label='Generate', )
36
+
37
+ if submit_button:
38
+ print("Generating output")
39
+ with st.spinner('Wait for it...'):
40
+ outputs = generate(category = str(category), headline = str(headline), min_len = min_len, max_len = max_len, num_return_sequences = num_return_sequences)
41
+
42
+ for i, output in enumerate(outputs):
43
+ # Cut start of text
44
+ temp = output.split("<|startoftext|>")[1].strip()
45
+
46
+ temp = temp.split("<|headline|>")
47
+ category = temp[0]
48
+
49
+ temp = temp[1].split("<|content|>")
50
+ headline = temp[0].strip()
51
+ content = temp[1].strip()
52
+
53
+ st.header(f"Output: {i}")
54
+ st.subheader("Category")
55
+ st.write(category)
56
+ st.subheader(f"Headline")
57
+ st.write(headline)
58
+ st.subheader(f"Content")
59
+ st.write(content)
60
+ st.write("##")
61
+
62
+ st.balloons()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers==4.16.2
2
+ torch==1.10.0
3
+ streamlit==1.5.1