Anonymous committed
Commit e84616f
1 Parent(s): e994f84

add spaces

Files changed (1):
  app.py +72 -57
app.py CHANGED
@@ -22,6 +22,9 @@ from funcs import (
 from utils.utils import instantiate_from_config
 from utils.utils_freetraj import plan_path
 
+video_length = 16
+width = 512
+height = 320
 MAX_KEYS = 5
 
 ckpt_dir_512 = "checkpoints/base_512_v2"
@@ -56,7 +59,7 @@ def check(radio_mode):
     video_bbox_path = "output_freetraj_bbox.mp4"
     return video_path, video_bbox_path
 
-@spaces.GPU(duration=270)
+
 def infer(*user_args):
     prompt_in = user_args[0]
     target_indices = user_args[1]
@@ -75,9 +78,6 @@ def infer(*user_args):
     w_positions = user_args[-MAX_KEYS:]
     print(user_args)
 
-    video_length = 16
-    width = 512
-    height = 320
     if radio_mode == 'ori':
         config_512 = "configs/inference_t2v_512_v2.0.yaml"
     else:
@@ -110,15 +110,6 @@ def infer(*user_args):
 
     config_512 = OmegaConf.load(config_512)
     model_config_512 = config_512.pop("model", OmegaConf.create())
-    model = instantiate_from_config(model_config_512)
-    model = model.cuda()
-    model = load_model_checkpoint(model, ckpt_path_512)
-    model.eval()
-
-    if seed is None:
-        seed = int.from_bytes(os.urandom(2), "big")
-    print(f"Using seed: {seed}")
-    seed_everything(seed)
 
     args = argparse.Namespace(
         mode="base",
@@ -127,57 +118,20 @@ def infer(*user_args):
         ddim_steps=ddim_steps,
         ddim_eta=0.0,
         bs=1,
-        height=height,
-        width=width,
-        frames=video_length,
         fps=video_fps,
         unconditional_guidance_scale=unconditional_guidance_scale,
         unconditional_guidance_scale_temporal=None,
         cond_input=None,
+        prompt_in=prompt_in,
+        seed=seed,
         ddim_edit=ddim_edit,
+        model_config_512=model_config_512,
+        idx_list=idx_list,
+        input_traj=input_traj,
     )
 
-    ## latent noise shape
-    h, w = args.height // 8, args.width // 8
-    frames = model.temporal_length if args.frames < 0 else args.frames
-    channels = model.channels
-
-    batch_size = 1
-    noise_shape = [batch_size, channels, frames, h, w]
-    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
-    prompts = [prompt_in]
-    text_emb = model.get_learned_conditioning(prompts)
-
-    cond = {"c_crossattn": [text_emb], "fps": fps}
-
-    ## inference
-    if radio_mode == 'ori':
-        batch_samples = batch_ddim_sampling(
-            model,
-            cond,
-            noise_shape,
-            args.n_samples,
-            args.ddim_steps,
-            args.ddim_eta,
-            args.unconditional_guidance_scale,
-            args=args,
-        )
-    else:
-        batch_samples = batch_ddim_sampling_freetraj(
-            model,
-            cond,
-            noise_shape,
-            args.n_samples,
-            args.ddim_steps,
-            args.ddim_eta,
-            args.unconditional_guidance_scale,
-            idx_list=idx_list,
-            input_traj=input_traj,
-            args=args,
-        )
-
-    vid_tensor = batch_samples[0]
-    video = vid_tensor.detach().cpu()
+    video = infer_gpu_part(args)
+
     video = torch.clamp(video.float(), -1.0, 1.0)
     video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
 
@@ -251,6 +205,67 @@ def infer(*user_args):
 
     return video_path, video_bbox_path
 
+
+
+@spaces.GPU(duration=270)
+def infer_gpu_part(args):
+
+    model = instantiate_from_config(args.model_config_512)
+    model = model.cuda()
+    model = load_model_checkpoint(model, ckpt_path_512)
+    model.eval()
+
+    if args.seed is None:
+        seed = int.from_bytes(os.urandom(2), "big")
+    else:
+        seed = args.seed
+    print(f"Using seed: {seed}")
+    seed_everything(seed)
+
+    ## latent noise shape
+    h, w = height // 8, width // 8
+    frames = video_length
+    channels = model.channels
+
+    batch_size = 1
+    noise_shape = [batch_size, channels, frames, h, w]
+    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
+    prompts = [args.prompt_in]
+    text_emb = model.get_learned_conditioning(prompts)
+
+    cond = {"c_crossattn": [text_emb], "fps": fps}
+
+    ## inference
+    if radio_mode == 'ori':
+        batch_samples = batch_ddim_sampling(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+            args=args,
+        )
+    else:
+        batch_samples = batch_ddim_sampling_freetraj(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+            idx_list=args.idx_list,
+            input_traj=args.input_traj,
+            args=args,
+        )
+
+    vid_tensor = batch_samples[0]
+    video = vid_tensor.detach().cpu()
+
+    return video
+
 
 examples = [
     ["A squirrel jumping from one tree to another.",],