update
app.py CHANGED
@@ -80,7 +80,7 @@ def contains_chinese(string):
 def clear_history(request: gr.Request):
     state = default_conversation.copy()
 
-    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+    return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
 
 def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
     prompt = after_process_image(state.get_prompt())
@@ -88,10 +88,10 @@ def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
     state.messages[-1][-1] = "▌"
     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
-    if contains_chinese(prompt):
-        state.messages[-1][-1] = "**CURRENTLY WE ONLY SUPPORT ENGLISH. PLEASE REFRESH THIS PAGE TO RESTART.**"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-        return
+    # if contains_chinese(prompt):
+    #     state.messages[-1][-1] = "**CURRENTLY WE ONLY SUPPORT ENGLISH. PLEASE REFRESH THIS PAGE TO RESTART.**"
+    #     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+    #     return
 
     try:
         data = get_inputs(prompt, images, topk, max_new_tokens, random_seed)
@@ -111,11 +111,105 @@ def http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
+
+def add_text_http_bot(state, text, image, video, topk, max_new_tokens, random_seed, request: gr.Request):
+    if len(text) <= 0 and (image is None or video is None):
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
+
+    if image is not None:
+        multimodal_msg = None
+        if '<image>' not in text:
+            text = text + '\n<image>'
+
+        if multimodal_msg is not None:
+            return (state, state.to_gradio_chatbot(), multimodal_msg, None, None) + (
+                no_change_btn,) * 5
+        text = (text, image)
+
+    if video is not None:
+        num_frames = 4
+        if '<image>' not in text:
+            text = text + '\n<image>' * num_frames
+        text = (text, video)
+
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+
+    yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+    prompt = after_process_image(state.get_prompt())
+    images = state.get_images()
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+    # if contains_chinese(prompt):
+    #     state.messages[-1][-1] = "**CURRENTLY WE ONLY SUPPORT ENGLISH. PLEASE REFRESH THIS PAGE TO RESTART OR CLEAR HISTORY.**"
+    #     yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+    #     return
+
+    try:
+        data = get_inputs(prompt, images, topk, max_new_tokens, random_seed)
+        output = model.prediction(data, log_dir)
+        print(output)
+        # output = output.replace("```", "")
+
+        state.messages[-1][-1] = output + "▌"
+        yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+        time.sleep(0.03)
+
+    except requests.exceptions.RequestException as e:
+        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+        yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        return
+
+    state.messages[-1][-1] = state.messages[-1][-1][:-1]
+    yield (state, state.to_gradio_chatbot(), "", None, None) + (enable_btn,) * 5
+
+
+def regenerate_http_bot(state, topk, max_new_tokens, random_seed, request: gr.Request):
+    state.messages[-1][-1] = None
+    state.skip_next = False
+    yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+    prompt = after_process_image(state.get_prompt())
+    images = state.get_images()
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+
+    # if contains_chinese(prompt):
+    #     state.messages[-1][-1] = "**CURRENTLY WE ONLY SUPPORT ENGLISH. PLEASE REFRESH THIS PAGE TO RESTART OR CLEAR HISTORY.**"
+    #     yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+    #     return
+
+    try:
+        data = get_inputs(prompt, images, topk, max_new_tokens, random_seed)
+        output = model.prediction(data, log_dir)
+        print(">>>> output:", output)
+        # output = output.replace("```", "")
+
+        state.messages[-1][-1] = output + "▌"
+        yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
+        time.sleep(0.03)
+
+    except requests.exceptions.RequestException as e:
+        print(e)
+        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+        yield (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        return
+
+    state.messages[-1][-1] = state.messages[-1][-1][:-1]
+    yield (state, state.to_gradio_chatbot(), "", None, None) + (enable_btn,) * 5
+
 title_markdown = ("""
 # mPLUG-Owl🦉 (GitHub Repo: https://github.com/X-PLUG/mPLUG-Owl)
+If you like our project, please give us a star on Github for latest update.
 """)
 
 tos_markdown = ("""
+**Notice:** The output is generated by top-k sampling scheme and may involve some randomness. For multiple image and video, we cannot ensure it's performance since only image-text pairs are used during training.
+
 ### Terms of use
 By using this service, users are required to agree to the following terms:
 The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
@@ -135,11 +229,9 @@ version 1.0
 """
 
 def build_demo():
-    #with gr.Blocks(title="mPLUG-Owl🦉", theme=gr.themes.Base(), css=css) as demo:
-    with gr.Blocks(title="mPLUG-Owl🦉") as demo:
+    # with gr.Blocks(title="mPLUG-Owl🦉", theme=gr.themes.Base(), css=css) as demo:
+    with gr.Blocks(title="mPLUG-Owl🦉", css=css) as demo:
         state = gr.State()
-        with gr.Box():
-            gr.Markdown(SHARED_UI_WARNING)
 
         gr.Markdown(title_markdown)
 
@@ -148,6 +240,8 @@ def build_demo():
 
                 imagebox = gr.Image(type="pil")
 
+                videobox = gr.Video()
+
                 with gr.Accordion("Parameters", open=True, visible=False) as parameter_row:
                     topk = gr.Slider(minimum=1, maximum=5, value=5, step=1, interactive=True, label="Top K",)
                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
@@ -155,7 +249,7 @@ def build_demo():
                 gr.Markdown(tos_markdown)
 
             with gr.Column(scale=6):
-                chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=
+                chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=800)
                 with gr.Row():
                     with gr.Column(scale=8):
                         textbox = gr.Textbox(show_label=False,
@@ -178,9 +272,14 @@ def build_demo():
                     [f'examples/laundry.jpeg', 'Why this happens and how to fix it?'],
                     [f'examples/ca.jpeg', "What do you think about the person's behavior?"],
                     [f'examples/monalisa-fun.jpg', 'Do you know who drew this painting?'],
-                    [f"examples/Yao_Ming.jpeg", "What is the name of the man on the right?"],
+                    # [f"examples/Yao_Ming.jpeg", "What is the name of the man on the right?"],
                 ], inputs=[imagebox, textbox])
 
+                gr.Examples(examples=[
+                    [f"examples/surf.mp4", "What is the man doing?"],
+                    [f"examples/yoga.mp4", "What did the woman doing?"],
+                ], inputs=[videobox, textbox])
+
         gr.Markdown(learn_more_markdown)
         url_params = gr.JSON(visible=False)
 
@@ -191,18 +290,20 @@ def build_demo():
                       [state], [textbox, upvote_btn, downvote_btn, flag_btn])
         flag_btn.click(flag_last_response,
                       [state], [textbox, upvote_btn, downvote_btn, flag_btn])
-        regenerate_btn.click(regenerate, state,
-                            [state, chatbot, textbox, imagebox] + btn_list).then(
-            http_bot, [state, topk, max_output_tokens, temperature],
-            [state, chatbot] + btn_list)
-
-
-
-
-
-
-
-
+        # regenerate_btn.click(regenerate, state,
+        #                     [state, chatbot, textbox, imagebox] + btn_list).then(
+        #     http_bot, [state, topk, max_output_tokens, temperature],
+        #     [state, chatbot] + btn_list)
+        regenerate_btn.click(regenerate_http_bot, [state, topk, max_output_tokens, temperature],
+                            [state, chatbot, textbox, imagebox, videobox] + btn_list)
+        clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox, videobox] + btn_list)
+
+        # textbox.submit(add_text, [state, textbox, imagebox], [state, chatbot, textbox, imagebox] + btn_list
+        #     ).then(http_bot, [state, topk, max_output_tokens, temperature],
+        #            [state, chatbot] + btn_list)
+        textbox.submit(add_text_http_bot, [state, textbox, imagebox, videobox, topk, max_output_tokens, temperature], [state, chatbot, textbox, imagebox, videobox] + btn_list)
+
+        submit_btn.click(add_text_http_bot, [state, textbox, imagebox, videobox, topk, max_output_tokens, temperature], [state, chatbot, textbox, imagebox, videobox] + btn_list)
 
         demo.load(load_demo, [url_params], [state,
                   chatbot, textbox, submit_btn, button_row, parameter_row],
@@ -218,7 +319,7 @@ if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--debug", action="store_true", help="using debug mode")
     parser.add_argument("--port", type=int)
-    parser.add_argument("--concurrency-count", type=int, default=
+    parser.add_argument("--concurrency-count", type=int, default=4)
     args = parser.parse_args()
     demo = build_demo()
     demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False).launch(server_name=args.host, debug=args.debug, server_port=args.port, share=False)
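For reference, below is a minimal, self-contained sketch of the single-generator callback pattern that the new add_text_http_bot wiring follows: one Gradio callback appends the user turn and then streams the assistant reply, so textbox.submit and the submit button can share the same handler. This is an illustrative sketch, not the app's code: it assumes Gradio 3.x (queue with concurrency_count, list-of-pairs Chatbot), and generate() / add_text_and_respond are hypothetical stand-ins for the real model.prediction call and the diff's add_text_http_bot.

import time
import gradio as gr

def generate(prompt):
    # Stand-in for the model call; yields the reply one character at a time.
    reply = "echo: " + prompt
    for i in range(1, len(reply) + 1):
        yield reply[:i]

def add_text_and_respond(history, text):
    # One generator does both jobs: add the user turn, then stream the answer.
    if not text:
        yield history, ""
        return
    history = history + [[text, None]]
    yield history, ""                      # clear the textbox right away
    for partial in generate(text):
        history[-1][1] = partial + "▌"     # streaming cursor, as in the diff
        yield history, ""
        time.sleep(0.03)
    history[-1][1] = history[-1][1][:-1]   # strip the cursor when done
    yield history, ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER")
    submit_btn = gr.Button("Submit")
    # Both triggers reuse the same generator, mirroring the diff's wiring of
    # textbox.submit and submit_btn.click to add_text_http_bot.
    textbox.submit(add_text_and_respond, [chatbot, textbox], [chatbot, textbox])
    submit_btn.click(add_text_and_respond, [chatbot, textbox], [chatbot, textbox])

if __name__ == "__main__":
    # Queueing is required for generator (streaming) callbacks in Gradio 3.x.
    demo.queue(concurrency_count=4).launch()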