yizhangliu commited on
Commit
655b569
·
1 Parent(s): af8001b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -31
app.py CHANGED
@@ -2,6 +2,8 @@ from transformers import pipeline
2
  import gradio as gr
3
  import random
4
  import paddlehub as hub
 
 
5
  from loguru import logger
6
 
7
  language_translation_model = hub.Module(directory=f'./baidu_translate')
@@ -23,10 +25,26 @@ def getTextTrans(text, source='zh', target='en'):
23
 
24
  extend_prompt_pipe = pipeline('text-generation', model='yizhangliu/prompt-extend', max_length=77, pad_token_id=0)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  space_ids = {
27
- "spaces/stabilityai/stable-diffusion":"SD 2.1",
28
- "spaces/runwayml/stable-diffusion-v1-5":"SD 1.5",
29
- "spaces/stabilityai/stable-diffusion-1":"SD 1.0",
 
30
  }
31
 
32
  tab_actions = []
@@ -34,6 +52,7 @@ tab_titles = []
34
 
35
  thanks_info = "Thanks: "
36
  thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend</font></a>]"
 
37
 
38
  for space_id in space_ids.keys():
39
  print(space_id, space_ids[space_id])
@@ -82,22 +101,20 @@ start_work = """async() => {
82
 
83
  if (typeof window['gradioEl'] === 'undefined') {
84
  window['gradioEl'] = gradioEl;
85
-
86
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
 
87
 
88
- for (var i = 0; i < tabitems.length; i++) {
89
- if ([0, 1, 2].includes(i)) {
90
  tabitems[i].childNodes[0].children[0].style.display='none';
91
  for (var j = 0; j < tabitems[i].childNodes[0].children[1].children.length; j++) {
92
  if (j != 1) {
93
  tabitems[i].childNodes[0].children[1].children[j].style.display='none';
94
  }
95
  }
96
- } else {
97
- tabitems[i].childNodes[0].children[0].style.display='none';
98
- tabitems[i].childNodes[0].children[1].style.display='none';
99
- tabitems[i].childNodes[0].children[2].children[0].style.display='none';
100
- tabitems[i].childNodes[0].children[3].style.display='none';
101
  }
102
  }
103
 
@@ -106,8 +123,12 @@ start_work = """async() => {
106
  tab_demo.setAttribute('style', 'height: 100%;');
107
  const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
108
  const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
109
- window['gradioEl'].querySelector('#input_col1_row2').children[0].setAttribute('style', 'min-width:0px;width:50%;');
110
- window['gradioEl'].querySelector('#input_col1_row2').children[1].setAttribute('style', 'min-width:0px;width:50%;');
 
 
 
 
111
  page1.style.display = "none";
112
  page2.style.display = "block";
113
  window['prevPrompt'] = '';
@@ -121,20 +142,25 @@ start_work = """async() => {
121
  window['doCheckPrompt'] = 1;
122
  window['prevPrompt'] = text_value;
123
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
124
- for (var i = 0; i < tabitems.length; i++) {
125
- if ([0, 1, 2].includes(i)) {
 
 
126
  inputText = tabitems[i].children[0].children[1].children[0].querySelectorAll('.gr-text-input')[0];
127
- } else {
128
- inputText = tabitems[i].childNodes[0].children[2].children[0].children[0].querySelectorAll('.gr-text-input')[0];
 
 
 
 
 
129
  }
130
- setNativeValue(inputText, text_value);
131
- inputText.dispatchEvent(new Event('input', { bubbles: true }));
132
  }
133
 
134
  setTimeout(function() {
135
  btns = window['gradioEl'].querySelectorAll('button');
136
  for (var i = 0; i < btns.length; i++) {
137
- if (['Generate image','Run'].includes(btns[i].innerText)) {
138
  btns[i].click();
139
  }
140
  }
@@ -150,25 +176,40 @@ start_work = """async() => {
150
  return false;
151
  }"""
152
 
153
- def prompt_extend(prompt):
154
  prompt_en = getTextTrans(prompt, source='zh', target='en')
155
- extend_prompt_en = extend_prompt_pipe(prompt_en+',', num_return_sequences=1)[0]["generated_text"]
 
 
 
 
156
  if (prompt != prompt_en):
157
- logger.info(f"extend_prompt__1__")
158
  extend_prompt_out = getTextTrans(extend_prompt_en, source='en', target='zh')
159
  else:
160
- logger.info(f"extend_prompt__2__")
161
  extend_prompt_out = extend_prompt_en
162
 
163
  return extend_prompt_out
164
 
 
 
 
 
 
 
 
 
165
  def prompt_draw(prompt):
166
  prompt_en = getTextTrans(prompt, source='zh', target='en')
167
  if (prompt != prompt_en):
168
  logger.info(f"draw_prompt______1__")
 
169
  else:
170
  logger.info(f"draw_prompt______2__")
171
- return prompt_en
 
 
172
 
173
  with gr.Blocks(title='Text-to-Image') as demo:
174
  with gr.Group(elem_id="page_1", visible=True) as page_1:
@@ -183,21 +224,25 @@ with gr.Blocks(title='Text-to-Image') as demo:
183
  with gr.Row(elem_id="input_col1_row1"):
184
  prompt_input0 = gr.Textbox(lines=2, label="Original prompt", visible=True)
185
  with gr.Row(elem_id="input_col1_row2"):
186
- with gr.Column(elem_id="input_col1_row2_col1"):
187
  draw_btn_0 = gr.Button(value = "Generate(original)", elem_id="draw-btn-0")
 
 
188
  with gr.Column(elem_id="input_col1_row2_col2"):
189
- extend_btn = gr.Button(value = "Extend prompt",elem_id="extend-btn")
190
  with gr.Column(id="input_col2"):
191
  prompt_input1 = gr.Textbox(lines=2, label="Extend prompt", visible=True)
192
  draw_btn_1 = gr.Button(value = "Generate(extend)", elem_id="draw-btn-1")
193
  prompt_work = gr.Textbox(lines=1, label="prompt_work", elem_id="prompt_work", visible=False)
 
194
  with gr.Row(elem_id='tab_demo', visible=True).style(height=200):
195
  tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
196
  with gr.Row():
197
- gr.HTML(f"<p>{thanks_info}</p>")
198
 
199
- extend_btn.click(fn=prompt_extend, inputs=[prompt_input0], outputs=[prompt_input1])
200
- draw_btn_0.click(fn=prompt_draw, inputs=[prompt_input0], outputs=[prompt_work])
201
- draw_btn_1.click(fn=prompt_draw, inputs=[prompt_input1], outputs=[prompt_work])
 
202
 
203
  demo.launch()
 
2
  import gradio as gr
3
  import random
4
  import paddlehub as hub
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from loguru import logger
8
 
9
  language_translation_model = hub.Module(directory=f'./baidu_translate')
 
25
 
26
  extend_prompt_pipe = pipeline('text-generation', model='yizhangliu/prompt-extend', max_length=77, pad_token_id=0)
27
 
28
+ def load_prompter():
29
+ prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
30
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
31
+ tokenizer.pad_token = tokenizer.eos_token
32
+ tokenizer.padding_side = "left"
33
+ return prompter_model, tokenizer
34
+ prompter_model, prompter_tokenizer = load_prompter()
35
+ def extend_prompt_microsoft(in_text):
36
+ input_ids = prompter_tokenizer(in_text.strip()+" Rephrase:", return_tensors="pt").input_ids
37
+ eos_id = prompter_tokenizer.eos_token_id
38
+ outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=8, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
39
+ output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
40
+ res = output_texts[0].replace(in_text+" Rephrase:", "").strip()
41
+ return res
42
+
43
  space_ids = {
44
+ "spaces/stabilityai/stable-diffusion": "SD 2.1",
45
+ "spaces/runwayml/stable-diffusion-v1-5": "SD 1.5",
46
+ "spaces/stabilityai/stable-diffusion-1": "SD 1.0",
47
+ "spaces/IDEA-CCNL/Taiyi-Stable-Diffusion-Chinese": "Taiyi(太乙)",
48
  }
49
 
50
  tab_actions = []
 
52
 
53
  thanks_info = "Thanks: "
54
  thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend</font></a>]"
55
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/microsoft/Promptist' _blank><font style='color:blue;weight:bold;'>Promptist</font></a>]"
56
 
57
  for space_id in space_ids.keys():
58
  print(space_id, space_ids[space_id])
 
101
 
102
  if (typeof window['gradioEl'] === 'undefined') {
103
  window['gradioEl'] = gradioEl;
 
104
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
105
+ tabitems_title = window['gradioEl'].querySelectorAll('#tab_demo')[0].children[0].children[0].children;
106
 
107
+ for (var i = 0; i < tabitems.length; i++) {
108
+ if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
109
  tabitems[i].childNodes[0].children[0].style.display='none';
110
  for (var j = 0; j < tabitems[i].childNodes[0].children[1].children.length; j++) {
111
  if (j != 1) {
112
  tabitems[i].childNodes[0].children[1].children[j].style.display='none';
113
  }
114
  }
115
+ } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
116
+ tabitems[3].children[0].children[0].children[1].style.display='none';
117
+ tabitems[i].children[0].children[0].children[0].children[0].children[1].style.display='none';
 
 
118
  }
119
  }
120
 
 
123
  tab_demo.setAttribute('style', 'height: 100%;');
124
  const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
125
  const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
126
+
127
+ btns_1 = window['gradioEl'].querySelector('#input_col1_row2').children;
128
+ btns_1_split = 100 / btns_1.length;
129
+ for (var i = 0; i < btns_1.length; i++) {
130
+ btns_1[i].setAttribute('style', 'min-width:0px;width:' + btns_1_split + '%;');
131
+ }
132
  page1.style.display = "none";
133
  page2.style.display = "block";
134
  window['prevPrompt'] = '';
 
142
  window['doCheckPrompt'] = 1;
143
  window['prevPrompt'] = text_value;
144
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
145
+ for (var i = 0; i < tabitems.length; i++) {
146
+ inputText = null;
147
+ if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
148
+ text_value = window['gradioEl'].querySelectorAll('#prompt_work')[0].querySelectorAll('textarea')[0].value;
149
  inputText = tabitems[i].children[0].children[1].children[0].querySelectorAll('.gr-text-input')[0];
150
+ } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
151
+ text_value = window['gradioEl'].querySelectorAll('#prompt_work_zh')[0].querySelectorAll('textarea')[0].value;
152
+ inputText = tabitems[i].children[0].children[0].children[1].querySelectorAll('.gr-text-input')[0];
153
+ }
154
+ if (inputText) {
155
+ setNativeValue(inputText, text_value);
156
+ inputText.dispatchEvent(new Event('input', { bubbles: true }));
157
  }
 
 
158
  }
159
 
160
  setTimeout(function() {
161
  btns = window['gradioEl'].querySelectorAll('button');
162
  for (var i = 0; i < btns.length; i++) {
163
+ if (['Generate image','Run', '生成图像(Generate)'].includes(btns[i].innerText)) {
164
  btns[i].click();
165
  }
166
  }
 
176
  return false;
177
  }"""
178
 
179
+ def prompt_extend(prompt, PM):
180
  prompt_en = getTextTrans(prompt, source='zh', target='en')
181
+ if PM == 1:
182
+ extend_prompt_en = extend_prompt_pipe(prompt_en+',', num_return_sequences=1)[0]["generated_text"]
183
+ else:
184
+ extend_prompt_en = extend_prompt_microsoft(prompt_en)
185
+
186
  if (prompt != prompt_en):
187
+ logger.info(f"extend_prompt__1_[{PM}]_")
188
  extend_prompt_out = getTextTrans(extend_prompt_en, source='en', target='zh')
189
  else:
190
+ logger.info(f"extend_prompt__2_[{PM}]_")
191
  extend_prompt_out = extend_prompt_en
192
 
193
  return extend_prompt_out
194
 
195
+ def prompt_extend_1(prompt):
196
+ extend_prompt_out = prompt_extend(prompt, 1)
197
+ return extend_prompt_out
198
+
199
+ def prompt_extend_2(prompt):
200
+ extend_prompt_out = prompt_extend(prompt, 2)
201
+ return extend_prompt_out
202
+
203
  def prompt_draw(prompt):
204
  prompt_en = getTextTrans(prompt, source='zh', target='en')
205
  if (prompt != prompt_en):
206
  logger.info(f"draw_prompt______1__")
207
+ prompt_zh = prompt
208
  else:
209
  logger.info(f"draw_prompt______2__")
210
+ prompt_zh = getTextTrans(prompt, source='en', target='zh')
211
+
212
+ return prompt_en, prompt_zh
213
 
214
  with gr.Blocks(title='Text-to-Image') as demo:
215
  with gr.Group(elem_id="page_1", visible=True) as page_1:
 
224
  with gr.Row(elem_id="input_col1_row1"):
225
  prompt_input0 = gr.Textbox(lines=2, label="Original prompt", visible=True)
226
  with gr.Row(elem_id="input_col1_row2"):
227
+ with gr.Column(elem_id="input_col1_row2_col0"):
228
  draw_btn_0 = gr.Button(value = "Generate(original)", elem_id="draw-btn-0")
229
+ with gr.Column(elem_id="input_col1_row2_col1"):
230
+ extend_btn_1 = gr.Button(value = "Extend_1",elem_id="extend-btn-1")
231
  with gr.Column(elem_id="input_col1_row2_col2"):
232
+ extend_btn_2 = gr.Button(value = "Extend_2",elem_id="extend-btn-2")
233
  with gr.Column(id="input_col2"):
234
  prompt_input1 = gr.Textbox(lines=2, label="Extend prompt", visible=True)
235
  draw_btn_1 = gr.Button(value = "Generate(extend)", elem_id="draw-btn-1")
236
  prompt_work = gr.Textbox(lines=1, label="prompt_work", elem_id="prompt_work", visible=False)
237
+ prompt_work_zh = gr.Textbox(lines=1, label="prompt_work_zh", elem_id="prompt_work_zh", visible=False)
238
  with gr.Row(elem_id='tab_demo', visible=True).style(height=200):
239
  tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
240
  with gr.Row():
241
+ gr.HTML(f"<p>{thanks_info}</p>")
242
 
243
+ extend_btn_1.click(fn=prompt_extend_1, inputs=[prompt_input0], outputs=[prompt_input1])
244
+ extend_btn_2.click(fn=prompt_extend_2, inputs=[prompt_input0], outputs=[prompt_input1])
245
+ draw_btn_0.click(fn=prompt_draw, inputs=[prompt_input0], outputs=[prompt_work, prompt_work_zh])
246
+ draw_btn_1.click(fn=prompt_draw, inputs=[prompt_input1], outputs=[prompt_work, prompt_work_zh])
247
 
248
  demo.launch()