THUdyh committed
Commit 223aac8
1 Parent(s): 0e70eb4

Update app.py

Files changed (1)
  1. app.py +92 -38
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import torch
 import re
+import os
 from decord import VideoReader, cpu
 from PIL import Image
 import numpy as np
@@ -12,7 +13,6 @@ import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

 import sys
-# sys.path.append('/mnt/lzy/oryx-demo')
 from oryx.conversation import conv_templates, SeparatorStyle
 from oryx.model.builder import load_pretrained_model
 from oryx.utils import disable_torch_init
@@ -83,14 +83,23 @@ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_im
     return input_ids

 @spaces.GPU(duration=120)
-def oryx_inference(video, text):
-    vr = VideoReader(video, ctx=cpu(0))
-    total_frame_num = len(vr)
-    fps = round(vr.get_avg_fps())
-    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
-    frame_idx = uniform_sampled_frames.tolist()
-    spare_frames = vr.get_batch(frame_idx).asnumpy()
-    video = [Image.fromarray(frame) for frame in spare_frames]
+def oryx_inference(multimodal):
+    visual, text = multimodal["files"][0], multimodal["text"]
+    if visual.endswith(".mp4"):
+        modality = "video"
+    else:
+        modality = "image"
+    if modality == "video":
+        vr = VideoReader(visual, ctx=cpu(0))
+        total_frame_num = len(vr)
+        fps = round(vr.get_avg_fps())
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
+        frame_idx = uniform_sampled_frames.tolist()
+        spare_frames = vr.get_batch(frame_idx).asnumpy()
+        video = [Image.fromarray(frame) for frame in spare_frames]
+    else:
+        image = [Image.open(visual)]
+        image_sizes = [image[0].size]

     conv_mode = "qwen_1_5"

@@ -104,39 +113,73 @@ def oryx_inference(video, text):

     input_ids = preprocess_qwen([{'from': 'human','value': question},{'from': 'gpt','value': None}], tokenizer, has_image=True).to(device)

-    video_processed = []
-    for idx, frame in enumerate(video):
+    if modality == "video":
+        video_processed = []
+        for idx, frame in enumerate(video):
+            image_processor.do_resize = False
+            image_processor.do_center_crop = False
+            frame = process_anyres_video_genli(frame, image_processor)
+
+            if frame_idx is not None and idx in frame_idx:
+                video_processed.append(frame.unsqueeze(0))
+            elif frame_idx is None:
+                video_processed.append(frame.unsqueeze(0))
+
+        if frame_idx is None:
+            frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
+
+        video_processed = torch.cat(video_processed, dim=0).bfloat16().to(device)
+        video_processed = (video_processed, video_processed)
+
+        video_data = (video_processed, (384, 384), "video")
+    else:
         image_processor.do_resize = False
         image_processor.do_center_crop = False
-        frame = process_anyres_video_genli(frame, image_processor)
-
-        if frame_idx is not None and idx in frame_idx:
-            video_processed.append(frame.unsqueeze(0))
-        elif frame_idx is None:
-            video_processed.append(frame.unsqueeze(0))
-
-    if frame_idx is None:
-        frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
-
-    video_processed = torch.cat(video_processed, dim=0).bfloat16().to(device)
-    video_processed = (video_processed, video_processed)
-
-    video_data = (video_processed, (384, 384), "video")
+        image_tensor, image_highres_tensor = [], []
+        for visual in image:
+            image_tensor_, image_highres_tensor_ = process_anyres_highres_image_genli(visual, image_processor)
+            image_tensor.append(image_tensor_)
+            image_highres_tensor.append(image_highres_tensor_)
+        if all(x.shape == image_tensor[0].shape for x in image_tensor):
+            image_tensor = torch.stack(image_tensor, dim=0)
+        if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
+            image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
+        if type(image_tensor) is list:
+            image_tensor = [_image.bfloat16().to(device) for _image in image_tensor]
+        else:
+            image_tensor = image_tensor.bfloat16().to(device)
+        if type(image_highres_tensor) is list:
+            image_highres_tensor = [_image.bfloat16().to(device) for _image in image_highres_tensor]
+        else:
+            image_highres_tensor = image_highres_tensor.bfloat16().to(device)

     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
     keywords = [stop_str]

     with torch.inference_mode():
-        output_ids = model.generate(
-            inputs=input_ids,
-            images=video_data[0][0],
-            images_highres=video_data[0][1],
-            modalities=video_data[2],
-            do_sample=False,
-            temperature=0,
-            max_new_tokens=1024,
-            use_cache=True,
-        )
+        if modality == "video":
+            output_ids = model.generate(
+                inputs=input_ids,
+                images=video_data[0][0],
+                images_highres=video_data[0][1],
+                modalities=video_data[2],
+                do_sample=False,
+                temperature=0,
+                max_new_tokens=1024,
+                use_cache=True,
+            )
+        else:
+            output_ids = model.generate(
+                inputs=input_ids,
+                images=image_tensor,
+                images_highres=image_highres_tensor,
+                image_sizes=image_sizes,
+                modalities=['image'],
+                do_sample=False,
+                temperature=0,
+                max_new_tokens=1024,
+                use_cache=True,
+            )


     outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
@@ -147,12 +190,23 @@ def oryx_inference(video, text):
     return outputs

 # Define input and output for the Gradio interface
+cur_dir = os.path.dirname(os.path.abspath(__file__))
 demo = gr.Interface(
     fn=oryx_inference,
-    inputs=[gr.Video(label="Input Video"), gr.Textbox(label="Input Text")],
+    inputs=gr.MultimodalTextbox(file_types=[".mp4", "image"],placeholder="Enter message or upload file..."),
     outputs="text",
-    title="Oryx Inference",
-    description="This is a demo for Oryx inference."
+    examples=[
+        {
+            "files":[f"{cur_dir}/case/case1.mp4"],
+            "text":"Describe what is happening in this video in detail.",
+        },
+        {
+            "files":[f"{cur_dir}/case/image.png"],
+            "text":"Describe this icon.",
+        },
+    ],
+    title="Oryx Demo",
+    description="A huggingface space for Oryx-7B."
 )

 # Launch the Gradio app
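For context on the change: gr.MultimodalTextbox hands the wrapped function a single dict with "text" and "files" keys, and the new oryx_inference routes on the uploaded file's extension. The sketch below shows only that wiring, with the model call stubbed out (describe_visual is a hypothetical stand-in, not part of this commit):

import gradio as gr

def describe_visual(multimodal: dict) -> str:
    # Hypothetical stand-in for oryx_inference: MultimodalTextbox passes
    # {"text": str, "files": [file paths]} to the handler.
    visual, text = multimodal["files"][0], multimodal["text"]  # assumes one file was uploaded
    modality = "video" if visual.endswith(".mp4") else "image"
    # The real handler preprocesses the file and calls model.generate();
    # here we only echo how the request was routed.
    return f"modality={modality}, prompt={text!r}, file={visual}"

demo = gr.Interface(
    fn=describe_visual,
    inputs=gr.MultimodalTextbox(file_types=[".mp4", "image"],
                                placeholder="Enter message or upload file..."),
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()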
 
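The video branch keeps the original sampling strategy: 64 indices spread evenly across the clip with np.linspace, decoded in a single decord batch and converted to PIL images for the processor. A self-contained sketch of just that step, assuming only decord, NumPy, and Pillow (the function name, path, and frame count are illustrative):

import numpy as np
from decord import VideoReader, cpu
from PIL import Image

def sample_frames(video_path: str, num_frames: int = 64):
    # Decode on CPU and sample num_frames indices uniformly from first to last frame.
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    frame_idx = np.linspace(0, total_frame_num - 1, num_frames, dtype=int).tolist()
    frames = vr.get_batch(frame_idx).asnumpy()   # (num_frames, H, W, 3) uint8 array
    return [Image.fromarray(f) for f in frames]  # PIL images for the image processor

# Example call (path is a placeholder): frames = sample_frames("case/case1.mp4")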