mqha commited on
Commit
8d3dced
·
1 Parent(s): ae4f900

加入gpt2模型

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -1,18 +1,37 @@
1
  import streamlit as st
 
2
 
3
- # 标题
4
- st.title("欢迎使用Streamlit应用程序")
5
 
6
- # 文本输入框
7
- name = st.text_input("请输入您的姓名")
8
 
9
- # 按钮
10
- button_clicked = st.button("点击这里")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # 根据按钮点击状态展示消息
13
- if button_clicked:
14
- st.write(f"你好,{name}!欢迎使用Streamlit应用程序。")
15
 
16
- # 图片
17
- st.image("https://www.streamlit.io/images/brand/streamlit-logo-primary-colormark-darktext.png",
18
- caption="Streamlit Logo", use_column_width=True)
 
1
  import streamlit as st
2
+ from transformers import pipeline, set_seed
3
 
4
+ # 设置全局随机种子,确保每次生成的结果相同
5
+ set_seed(42)
6
 
 
 
7
 
8
+ options = ['中文','英文']
9
+ choice = st.radio('不同语言使用不同模型:', options)
10
+
11
+ input_text = st.text_input("请输入您要生成的文本", value="")
12
+ maxlen = st.text_input("请输入生成文本的最大长度,越长越慢,不要超过1000", value="30")
13
+ button_generate = st.button("生成")
14
+ output_text = st.empty()
15
+
16
+ def generate_text(input_text):
17
+ # 加载预训练模型
18
+ if choice == '中文':
19
+ model = 'gpt2-chinese-cluecorpussmall' # 会自动下载
20
+ generator = pipeline("text-generation", model)
21
+
22
+ # 生成文本
23
+ output = generator(input_text, max_length=int(maxlen), num_return_sequences=1)
24
+
25
+ # 提取生成的文本
26
+ generated_text = output[0]["generated_text"].strip()
27
+
28
+ return generated_text
29
+
30
+ if button_generate:
31
+ # 生成文本
32
+ generated_text = generate_text(input_text)
33
+
34
+ # 显示生成的文本
35
+ output_text.success(generated_text)
36
 
 
 
 
37