MAGAer13 committed on
Commit
8e4225b
1 Parent(s): 48e8dfb
Files changed (1) hide show
  1. app.py +125 -24
app.py CHANGED
@@ -80,7 +80,7 @@ def contains_chinese(string):
80
  def clear_history(request: gr.Request):
81
  state = default_conversation.copy()
82
 
83
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
84
 
85
  def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
86
  prompt = after_process_image(state.get_prompt())
@@ -88,10 +88,10 @@ def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
88
  state.messages[-1][-1] = "▌"
89
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
90
 
91
- if contains_chinese(prompt):
92
- state.messages[-1][-1] = "**CURRENTLY WE ONLY SUPPORT ENGLISH. PLEASE REFRESH THIS PAGE TO RESTART.**"
93
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
94
- return
95
 
96
  try:
97
  data = get_inputs(prompt, images, topk, max_new_tokens, random_seed)
@@ -111,11 +111,105 @@ def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
111
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
112
  yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  title_markdown = ("""
115
  # mPLUG-Owl🦉 (GitHub Repo: https://github.com/X-PLUG/mPLUG-Owl)
 
116
  """)
117
 
118
  tos_markdown = ("""
 
 
119
  ### Terms of use
120
  By using this service, users are required to agree to the following terms:
121
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
@@ -135,11 +229,9 @@ version 1.0
135
  """
136
 
137
  def build_demo():
138
- #with gr.Blocks(title="mPLUG-Owl🦉", theme=gr.themes.Base(), css=css) as demo:
139
- with gr.Blocks(title="mPLUG-Owl🦉") as demo:
140
  state = gr.State()
141
- with gr.Box():
142
- gr.Markdown(SHARED_UI_WARNING)
143
 
144
  gr.Markdown(title_markdown)
145
 
@@ -148,6 +240,8 @@ def build_demo():
148
 
149
  imagebox = gr.Image(type="pil")
150
 
 
 
151
  with gr.Accordion("Parameters", open=True, visible=False) as parameter_row:
152
  topk = gr.Slider(minimum=1, maximum=5, value=5, step=1, interactive=True, label="Top K",)
153
  max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
@@ -155,7 +249,7 @@ def build_demo():
155
  gr.Markdown(tos_markdown)
156
 
157
  with gr.Column(scale=6):
158
- chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
159
  with gr.Row():
160
  with gr.Column(scale=8):
161
  textbox = gr.Textbox(show_label=False,
@@ -178,9 +272,14 @@ def build_demo():
178
  [f'examples/laundry.jpeg', 'Why this happens and how to fix it?'],
179
  [f'examples/ca.jpeg', "What do you think about the person's behavior?"],
180
  [f'examples/monalisa-fun.jpg', 'Do you know who drew this painting?​'],
181
- [f"examples/Yao_Ming.jpeg", "What is the name of the man on the right?"],
182
  ], inputs=[imagebox, textbox])
183
 
 
 
 
 
 
184
  gr.Markdown(learn_more_markdown)
185
  url_params = gr.JSON(visible=False)
186
 
@@ -191,18 +290,20 @@ def build_demo():
191
  [state], [textbox, upvote_btn, downvote_btn, flag_btn])
192
  flag_btn.click(flag_last_response,
193
  [state], [textbox, upvote_btn, downvote_btn, flag_btn])
194
- regenerate_btn.click(regenerate, state,
195
- [state, chatbot, textbox, imagebox] + btn_list).then(
196
- http_bot, [state, topk, max_output_tokens, temperature],
197
- [state, chatbot] + btn_list)
198
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
199
-
200
- textbox.submit(add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
201
- ).then(http_bot, [state, topk, max_output_tokens, temperature],
202
- [state, chatbot] + btn_list)
203
- submit_btn.click(add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
204
- ).then(http_bot, [state, topk, max_output_tokens, temperature],
205
- [state, chatbot] + btn_list)
 
 
206
 
207
  demo.load(load_demo, [url_params], [state,
208
  chatbot, textbox, submit_btn, button_row, parameter_row],
@@ -218,7 +319,7 @@ if __name__ == "__main__":
218
  parser.add_argument("--host", type=str, default="0.0.0.0")
219
  parser.add_argument("--debug", action="store_true", help="using debug mode")
220
  parser.add_argument("--port", type=int)
221
- parser.add_argument("--concurrency-count", type=int, default=100)
222
  args = parser.parse_args()
223
  demo = build_demo()
224
  demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch(server_name=args.host, debug=args.debug, server_port=args.port, share=False)
 
80
  def clear_history(request: gr.Request):
81
  state = default_conversation.copy()
82
 
83
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
84
 
85
  def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
86
  prompt = after_process_image(state.get_prompt())
 
88
  state.messages[-1][-1] = "▌"
89
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
90
 
91
+ # if contains_chinese(prompt):
92
+ # state.messages[-1][-1] = "**CURRENTLY WE ONLY SUPPORT ENGLISH. PLEASE REFRESH THIS PAGE TO RESTART.**"
93
+ # yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
94
+ # return
95
 
96
  try:
97
  data = get_inputs(prompt, images, topk, max_new_tokens, random_seed)
 
111
  state.messages[-1][-1] = state.messages[-1][-1][:-1]
112
  yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
113
 
114
+
115
def add_text_http_bot(state, text, image, video, topk, max_new_tokens, random_seed, request: gr.Request):
    """Append the user's turn to the conversation and stream the model reply.

    This is a generator. Every yielded value is a 10-tuple:
    ``(state, chatbot_history, textbox_value, None, None)`` followed by five
    button updates, which lines up with the
    ``[state, chatbot, textbox, imagebox, videobox] + btn_list`` outputs the
    submit handlers bind this function to.

    Args:
        state: Conversation object. It is mutated in place: two messages are
            appended and ``skip_next`` is set.
        text: Prompt entered by the user.
        image: Uploaded image, or ``None``.
        video: Uploaded video, or ``None``.
        topk: Forwarded unchanged to ``get_inputs``.
        max_new_tokens: Forwarded unchanged to ``get_inputs``.
        random_seed: Forwarded unchanged to ``get_inputs``.
        request: Accepted but not used in the body.

    Yields:
        UI update tuples as described above. Buttons are disabled while the
        reply is pending and re-enabled on the final update.
    """

    def _frame(buttons):
        # The textbox is blanked and both media inputs are cleared on every update.
        return (state, state.to_gradio_chatbot(), "", None, None) + buttons

    # Skip only when there is truly nothing to send: no prompt AND no media.
    # (The previous ``or`` also rejected image-only and video-only requests.)
    if not text and image is None and video is None:
        state.skip_next = True
        # A value passed to ``return`` inside a generator is never delivered
        # as an output, so the "no change" update has to be yielded.
        yield _frame((no_change_btn,) * 5)
        return

    if image is not None:
        if '<image>' not in text:
            text = text + '\n<image>'
        text = (text, image)
    elif video is not None:
        # NOTE(review): the image takes precedence when both inputs are set.
        # Previously this branch also ran after the image branch, where
        # ``text`` was already a tuple and the concatenation raised TypeError.
        # Confirm this precedence is the desired behaviour.
        # One placeholder per frame; presumably this must equal the number of
        # frames sampled from the video downstream -- TODO confirm.
        num_frames = 4
        if '<image>' not in text:
            text = text + '\n<image>' * num_frames
        text = (text, video)

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False

    yield _frame((disable_btn,) * 5)

    prompt = after_process_image(state.get_prompt())
    images = state.get_images()
    # Show a cursor glyph while the model is working.
    state.messages[-1][-1] = "▌"
    yield _frame((disable_btn,) * 5)

    try:
        data = get_inputs(prompt, images, topk, max_new_tokens, random_seed)
        output = model.prediction(data, log_dir)
        print(output)

        state.messages[-1][-1] = output + "▌"
        yield _frame((disable_btn,) * 5)
        time.sleep(0.03)

    except requests.exceptions.RequestException:
        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
        # Only the last two of the five buttons are re-enabled after a failure.
        yield _frame((disable_btn, disable_btn, disable_btn, enable_btn, enable_btn))
        return

    # Drop the trailing cursor glyph before the final update.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield _frame((enable_btn,) * 5)
169
+
170
+
171
def regenerate_http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
    """Blank the latest reply and stream a newly generated one.

    This is a generator. Every yielded value is a 10-tuple:
    ``(state, chatbot_history, "", None, None)`` followed by five button
    updates, which lines up with the
    ``[state, chatbot, textbox, imagebox, videobox] + btn_list`` outputs the
    regenerate button binds this function to.

    ``state`` is mutated in place. ``topk``, ``max_new_tokens`` and
    ``random_seed`` are forwarded unchanged to ``get_inputs``. ``request``
    is accepted but not used in the body.
    """

    def _outputs(button_updates):
        # Same five leading values on every update; only the buttons vary.
        return (state, state.to_gradio_chatbot(), "", None, None) + button_updates

    # Blank out the last message's content before the prompt is rebuilt.
    state.messages[-1][-1] = None
    state.skip_next = False
    yield _outputs((disable_btn,) * 5)

    prompt_text = after_process_image(state.get_prompt())
    media = state.get_images()
    # One-character placeholder; it is stripped again on the final update.
    state.messages[-1][-1] = " "
    yield _outputs((disable_btn,) * 5)

    try:
        payload = get_inputs(prompt_text, media, topk, max_new_tokens, random_seed)
        reply = model.prediction(payload, log_dir)
        print(">>>> output:", reply)

        state.messages[-1][-1] = reply + " "
        yield _outputs((disable_btn,) * 5)
        time.sleep(0.03)
    except requests.exceptions.RequestException as err:
        print(err)
        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
        # After a failure only the last two of the five buttons are re-enabled.
        yield _outputs((disable_btn, disable_btn, disable_btn, enable_btn, enable_btn))
        return
    else:
        # Remove the trailing placeholder character appended above.
        state.messages[-1][-1] = state.messages[-1][-1][:-1]
        yield _outputs((enable_btn,) * 5)
204
+
205
  title_markdown = ("""
206
  # mPLUG-Owl🦉 (GitHub Repo: https://github.com/X-PLUG/mPLUG-Owl)
207
+ If you like our project, please give us a star on Github for latest update.
208
  """)
209
 
210
  tos_markdown = ("""
211
+ **Notice:** The output is generated by top-k sampling scheme and may involve some randomness. For multiple image and video, we cannot ensure it's performance since only image-text pairs are used during training.
212
+
213
  ### Terms of use
214
  By using this service, users are required to agree to the following terms:
215
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
 
229
  """
230
 
231
  def build_demo():
232
+ # with gr.Blocks(title="mPLUG-Owl🦉", theme=gr.themes.Base(), css=css) as demo:
233
+ with gr.Blocks(title="mPLUG-Owl🦉", css=css) as demo:
234
  state = gr.State()
 
 
235
 
236
  gr.Markdown(title_markdown)
237
 
 
240
 
241
  imagebox = gr.Image(type="pil")
242
 
243
+ videobox = gr.Video()
244
+
245
  with gr.Accordion("Parameters", open=True, visible=False) as parameter_row:
246
  topk = gr.Slider(minimum=1, maximum=5, value=5, step=1, interactive=True, label="Top K",)
247
  max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
 
249
  gr.Markdown(tos_markdown)
250
 
251
  with gr.Column(scale=6):
252
+ chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=800)
253
  with gr.Row():
254
  with gr.Column(scale=8):
255
  textbox = gr.Textbox(show_label=False,
 
272
  [f'examples/laundry.jpeg', 'Why this happens and how to fix it?'],
273
  [f'examples/ca.jpeg', "What do you think about the person's behavior?"],
274
  [f'examples/monalisa-fun.jpg', 'Do you know who drew this painting?​'],
275
+ # [f"examples/Yao_Ming.jpeg", "What is the name of the man on the right?"],
276
  ], inputs=[imagebox, textbox])
277
 
278
+ gr.Examples(examples=[
279
+ [f"examples/surf.mp4", "What is the man doing?"],
280
+ [f"examples/yoga.mp4", "What did the woman doing?"],
281
+ ], inputs=[videobox, textbox])
282
+
283
  gr.Markdown(learn_more_markdown)
284
  url_params = gr.JSON(visible=False)
285
 
 
290
  [state], [textbox, upvote_btn, downvote_btn, flag_btn])
291
  flag_btn.click(flag_last_response,
292
  [state], [textbox, upvote_btn, downvote_btn, flag_btn])
293
+ # regenerate_btn.click(regenerate, state,
294
+ # [state, chatbot, textbox, imagebox] + btn_list).then(
295
+ # http_bot, [state, topk, max_output_tokens, temperature],
296
+ # [state, chatbot] + btn_list)
297
+ regenerate_btn.click(regenerate_http_bot, [state, topk, max_output_tokens, temperature],
298
+ [state, chatbot, textbox, imagebox, videobox] + btn_list)
299
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, videobox] + btn_list)
300
+
301
+ # textbox.submit(add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
302
+ # ).then(http_bot, [state, topk, max_output_tokens, temperature],
303
+ # [state, chatbot] + btn_list)
304
+ textbox.submit(add_text_http_bot, [state, textbox, imagebox, videobox, topk, max_output_tokens, temperature], [state, chatbot, textbox, imagebox, videobox] + btn_list)
305
+
306
+ submit_btn.click(add_text_http_bot, [state, textbox, imagebox, videobox, topk, max_output_tokens, temperature], [state, chatbot, textbox, imagebox, videobox] + btn_list)
307
 
308
  demo.load(load_demo, [url_params], [state,
309
  chatbot, textbox, submit_btn, button_row, parameter_row],
 
319
  parser.add_argument("--host", type=str, default="0.0.0.0")
320
  parser.add_argument("--debug", action="store_true", help="using debug mode")
321
  parser.add_argument("--port", type=int)
322
+ parser.add_argument("--concurrency-count", type=int, default=4)
323
  args = parser.parse_args()
324
  demo = build_demo()
325
  demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch(server_name=args.host, debug=args.debug, server_port=args.port, share=False)