GMFTBY committed on
Commit
c2f7915
1 Parent(s): 8d678a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -65
app.py CHANGED
@@ -1,4 +1,5 @@
1
  from transformers import AutoModel, AutoTokenizer
 
2
  import os
3
  import ipdb
4
  import gradio as gr
@@ -10,9 +11,9 @@ import json
10
  # init the model
11
  args = {
12
  'model': 'openllama_peft',
13
- 'imagebind_ckpt_path': './pretrained_ckpt/imagebind_ckpt/',
14
  'vicuna_ckpt_path': 'openllmplayground/vicuna_7b_v0',
15
- 'delta_ckpt_path': './pretrained_ckpt/pandagpt_ckpt/pytorch_model.pt',
16
  'stage': 2,
17
  'max_tgt_len': 128,
18
  'lora_r': 32,
@@ -22,9 +23,8 @@ args = {
22
  model = OpenLLAMAPEFTModel(**args)
23
  delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
24
  model.load_state_dict(delta_ckpt, strict=False)
25
- model = model.eval().half().cuda()
26
- print(f'[!] init the model over ...')
27
-
28
 
29
  """Override Chatbot.postprocess"""
30
 
@@ -76,6 +76,25 @@ def parse_text(text):
76
  return text
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def predict(
80
  input,
81
  image_path,
@@ -90,11 +109,11 @@ def predict(
90
  modality_cache,
91
  ):
92
  if image_path is None and audio_path is None and video_path is None and thermal_path is None:
93
- return [(input, "There is no image/audio/video provided. Please upload the file to start a conversation.")]
94
  else:
95
- print(f'[!] image path: {image_path}\n[!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal pah: {thermal_path}')
96
- # prepare the prompt
97
 
 
98
  prompt_text = ''
99
  for idx, (q, a) in enumerate(history):
100
  if idx == 0:
@@ -106,17 +125,18 @@ def predict(
106
  else:
107
  prompt_text += f' Human: {input}'
108
 
109
- response = model.generate({
110
- 'prompt': prompt_text,
111
- 'image_paths': [image_path] if image_path else [],
112
- 'audio_paths': [audio_path] if audio_path else [],
113
- 'video_paths': [video_path] if video_path else [],
114
- 'thermal_paths': [thermal_path] if thermal_path else [],
115
- 'top_p': top_p,
116
- 'temperature': temperature,
117
- 'max_tgt_len': max_length,
118
- 'modality_embeds': modality_cache
119
- })
 
120
  chatbot.append((parse_text(input), parse_text(response)))
121
  history.append((input, response))
122
  return chatbot, history, modality_cache
@@ -125,70 +145,44 @@ def predict(
125
  def reset_user_input():
126
  return gr.update(value='')
127
 
 
 
128
 
129
  def reset_state():
130
  return None, None, None, None, [], [], []
131
 
132
 
133
- with gr.Blocks() as demo:
134
- gr.HTML("""<h1 align="center">PandaGPT</h1>""")
 
 
 
135
 
136
  with gr.Row(scale=4):
137
- with gr.Column(scale=2):
138
  image_path = gr.Image(type="filepath", label="Image", value=None)
139
-
140
- gr.Examples(
141
- [
142
- os.path.join(os.path.dirname(__file__), "/assets/images/bird_image.jpg"),
143
- os.path.join(os.path.dirname(__file__), "/assets/images/dog_image.jpg"),
144
- os.path.join(os.path.dirname(__file__), "/assets/images/car_image.jpg"),
145
- ],
146
- image_path
147
- )
148
- with gr.Column(scale=2):
149
  audio_path = gr.Audio(type="filepath", label="Audio", value=None)
150
- gr.Examples(
151
- [
152
- os.path.join(os.path.dirname(__file__), "/assets/audios/bird_audio.wav"),
153
- os.path.join(os.path.dirname(__file__), "/assets/audios/dog_audio.wav"),
154
- os.path.join(os.path.dirname(__file__), "/assets/audios/car_audio.wav"),
155
- ],
156
- audio_path
157
- )
158
- with gr.Row(scale=4):
159
- with gr.Column(scale=2):
160
  video_path = gr.Video(type='file', label="Video")
161
-
162
- gr.Examples(
163
- [
164
- os.path.join(os.path.dirname(__file__), "/assets/videos/world.mp4"),
165
- os.path.join(os.path.dirname(__file__), "/assets/videos/a.mp4"),
166
- ],
167
- video_path
168
- )
169
- with gr.Column(scale=2):
170
  thermal_path = gr.Image(type="filepath", label="Thermal Image", value=None)
171
 
172
- gr.Examples(
173
- [
174
- os.path.join(os.path.dirname(__file__), "/assets/thermals/190662.jpg"),
175
- os.path.join(os.path.dirname(__file__), "/assets/thermals/210009.jpg"),
176
- ],
177
- thermal_path
178
- )
179
-
180
- chatbot = gr.Chatbot()
181
  with gr.Row():
182
  with gr.Column(scale=4):
183
  with gr.Column(scale=12):
184
  user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
185
  with gr.Column(min_width=32, scale=1):
186
- submitBtn = gr.Button("Submit", variant="primary")
 
 
 
187
  with gr.Column(scale=1):
188
  emptyBtn = gr.Button("Clear History")
189
- max_length = gr.Slider(0, 512, value=128, step=1.0, label="Maximum length", interactive=True)
190
- top_p = gr.Slider(0, 1, value=0.4, step=0.01, label="Top P", interactive=True)
191
- temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True)
192
 
193
  history = gr.State([])
194
  modality_cache = gr.State([])
@@ -214,6 +208,28 @@ with gr.Blocks() as demo:
214
  show_progress=True
215
  )
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  submitBtn.click(reset_user_input, [], [user_input])
218
  emptyBtn.click(reset_state, outputs=[
219
  image_path,
@@ -225,4 +241,4 @@ with gr.Blocks() as demo:
225
  modality_cache
226
  ], show_progress=True)
227
 
228
- demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=24000)
 
1
  from transformers import AutoModel, AutoTokenizer
2
+ from copy import deepcopy
3
  import os
4
  import ipdb
5
  import gradio as gr
 
11
  # init the model
12
  args = {
13
  'model': 'openllama_peft',
14
+ 'imagebind_ckpt_path': 'pretrained_ckpt/imagebind_ckpt',
15
  'vicuna_ckpt_path': 'openllmplayground/vicuna_7b_v0',
16
+ 'delta_ckpt_path': 'pretrained_ckpt/pandagpt_ckpt/pytorch_model.pt',
17
  'stage': 2,
18
  'max_tgt_len': 128,
19
  'lora_r': 32,
 
23
  model = OpenLLAMAPEFTModel(**args)
24
  delta_ckpt = torch.load(args['delta_ckpt_path'], map_location=torch.device('cpu'))
25
  model.load_state_dict(delta_ckpt, strict=False)
26
+ model = model.half().cuda().eval()
27
+ print(f'[!] init the 13b model over ...')
 
28
 
29
  """Override Chatbot.postprocess"""
30
 
 
76
  return text
77
 
78
 
79
+ def re_predict(
80
+ input,
81
+ image_path,
82
+ audio_path,
83
+ video_path,
84
+ thermal_path,
85
+ chatbot,
86
+ max_length,
87
+ top_p,
88
+ temperature,
89
+ history,
90
+ modality_cache,
91
+ ):
92
+ # drop the latest query and answers and generate again
93
+ q, a = history.pop()
94
+ chatbot.pop()
95
+ return predict(q, image_path, audio_path, video_path, thermal_path, chatbot, max_length, top_p, temperature, history, modality_cache)
96
+
97
+
98
  def predict(
99
  input,
100
  image_path,
 
109
  modality_cache,
110
  ):
111
  if image_path is None and audio_path is None and video_path is None and thermal_path is None:
112
+ return [(input, "图片和音频以及视频为空!请重新上传才能开启对话。")]
113
  else:
114
+ print(f'[!] image path: {image_path}\n[!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal path: {thermal_path}')
 
115
 
116
+ # prepare the prompt
117
  prompt_text = ''
118
  for idx, (q, a) in enumerate(history):
119
  if idx == 0:
 
125
  else:
126
  prompt_text += f' Human: {input}'
127
 
128
+ with torch.no_grad():
129
+ response = model.generate({
130
+ 'prompt': prompt_text,
131
+ 'image_paths': [image_path] if image_path else [],
132
+ 'audio_paths': [audio_path] if audio_path else [],
133
+ 'video_paths': [video_path] if video_path else [],
134
+ 'thermal_paths': [thermal_path] if thermal_path else [],
135
+ 'top_p': top_p,
136
+ 'temperature': temperature,
137
+ 'max_tgt_len': max_length,
138
+ 'modality_embeds': modality_cache
139
+ })
140
  chatbot.append((parse_text(input), parse_text(response)))
141
  history.append((input, response))
142
  return chatbot, history, modality_cache
 
145
  def reset_user_input():
146
  return gr.update(value='')
147
 
148
+ def reset_dialog():
149
+ return [], []
150
 
151
  def reset_state():
152
  return None, None, None, None, [], [], []
153
 
154
 
155
+ with gr.Blocks(scale=4) as demo:
156
+ gr.HTML("""<h1 align="center">PandaGPT</h1>
157
+
158
+ We note that the current online demo uses the 7B version of PandaGPT due to the limitation of computation resource. Better results should be expected when switching to the 13B version of PandaGPT. For more details on how to run 13B PandaGPT, please refer to our [main project repository](https://github.com/yxuansu/PandaGPT).
159
+ """)
160
 
161
  with gr.Row(scale=4):
162
+ with gr.Column(scale=1):
163
  image_path = gr.Image(type="filepath", label="Image", value=None)
164
+ with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
165
  audio_path = gr.Audio(type="filepath", label="Audio", value=None)
166
+ with gr.Column(scale=1):
 
 
 
 
 
 
 
 
 
167
  video_path = gr.Video(type='file', label="Video")
168
+ with gr.Column(scale=1):
 
 
 
 
 
 
 
 
169
  thermal_path = gr.Image(type="filepath", label="Thermal Image", value=None)
170
 
171
+ chatbot = gr.Chatbot().style(height=300)
 
 
 
 
 
 
 
 
172
  with gr.Row():
173
  with gr.Column(scale=4):
174
  with gr.Column(scale=12):
175
  user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
176
  with gr.Column(min_width=32, scale=1):
177
+ with gr.Row(scale=1):
178
+ submitBtn = gr.Button("Submit", variant="primary")
179
+ with gr.Row(scale=1):
180
+ resubmitBtn = gr.Button("Resubmit", variant="primary")
181
  with gr.Column(scale=1):
182
  emptyBtn = gr.Button("Clear History")
183
+ max_length = gr.Slider(0, 400, value=256, step=1.0, label="Maximum length", interactive=True)
184
+ top_p = gr.Slider(0, 1, value=0.01, step=0.01, label="Top P", interactive=True)
185
+ temperature = gr.Slider(0, 1, value=1.0, step=0.01, label="Temperature", interactive=True)
186
 
187
  history = gr.State([])
188
  modality_cache = gr.State([])
 
208
  show_progress=True
209
  )
210
 
211
+ resubmitBtn.click(
212
+ re_predict, [
213
+ user_input,
214
+ image_path,
215
+ audio_path,
216
+ video_path,
217
+ thermal_path,
218
+ chatbot,
219
+ max_length,
220
+ top_p,
221
+ temperature,
222
+ history,
223
+ modality_cache,
224
+ ], [
225
+ chatbot,
226
+ history,
227
+ modality_cache
228
+ ],
229
+ show_progress=True
230
+ )
231
+
232
+
233
  submitBtn.click(reset_user_input, [], [user_input])
234
  emptyBtn.click(reset_state, outputs=[
235
  image_path,
 
241
  modality_cache
242
  ], show_progress=True)
243
 
244
+ demo.launch(enable_queue=True)