hongfz16 commited on
Commit
e774ced
Β·
verified Β·
1 Parent(s): 45d6b3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -75
app.py CHANGED
@@ -227,87 +227,89 @@ def marching_cube(b, text, global_info):
227
  return path
228
 
229
  def infer(prompt, samples, steps, scale, seed, global_info):
230
- prompt = prompt.replace('/', '')
231
- pl.seed_everything(seed)
232
- batch_size = samples
233
- with torch.no_grad():
234
- noise = None
235
- c = model.get_learned_conditioning([prompt])
236
- unconditional_c = torch.zeros_like(c)
237
- sample, _ = sampler.sample(
238
- S=steps,
239
- batch_size=batch_size,
240
- shape=shape,
241
- verbose=False,
242
- x_T = noise,
243
- conditioning = c.repeat(batch_size, 1, 1),
244
- unconditional_guidance_scale=scale,
245
- unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
246
- )
247
- decode_res = model.decode_first_stage(sample)
248
-
249
- big_video_list = []
250
-
251
- global_info['decode_res'] = decode_res
252
-
253
- for b in range(batch_size):
254
- def render_img(v):
255
- rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
256
- decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
257
- )
258
- rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
259
- rgb_sample = np.stack(
260
- [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
261
- )
262
- rgb_sample = add_text(rgb_sample, str(b))
263
- return rgb_sample
264
-
265
- view_num = len(batch_rays_list)
266
- video_list = []
267
- for v in tqdm.tqdm(range(view_num//8*3, view_num//8*5, 2)):
268
- rgb_sample = render_img(v)
269
- video_list.append(rgb_sample)
270
- big_video_list.append(video_list)
271
- # if batch_size == 2:
272
- # cat_video_list = [
273
- # np.concatenate([big_video_list[j][i] for j in range(len(big_video_list))], 1) \
274
- # for i in range(len(big_video_list[0]))
275
- # ]
276
- # elif batch_size > 2:
277
- # if batch_size == 3:
278
- # big_video_list.append(
279
- # [np.zeros_like(f) for f in big_video_list[0]]
280
- # )
281
- # cat_video_list = [
282
- # np.concatenate([
283
- # np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
284
- # np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
285
- # ], 0) \
286
- # for i in range(len(big_video_list[0]))
287
- # ]
288
- # else:
289
- # cat_video_list = big_video_list[0]
290
-
291
- for _ in range(4 - batch_size):
292
- big_video_list.append(
293
- [np.zeros_like(f) + 255 for f in big_video_list[0]]
294
  )
295
- cat_video_list = [
296
- np.concatenate([
297
- np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
298
- np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
299
- ], 0) \
300
- for i in range(len(big_video_list[0]))
301
- ]
302
-
303
- path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
304
- imageio.mimwrite(path, np.stack(cat_video_list, 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  return global_info, path
307
 
308
  def infer_stage2(prompt, selection, seed, global_info, iters):
309
  prompt = prompt.replace('/', '')
310
- mesh_path = marching_cube(int(selection), prompt, global_info)
 
311
  mesh_name = mesh_path.split('/')[-1][:-4]
312
  # if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
313
  # print(if2_cmd)
 
227
  return path
228
 
229
  def infer(prompt, samples, steps, scale, seed, global_info):
230
+ with torch.cuda.device(1):
231
+ prompt = prompt.replace('/', '')
232
+ pl.seed_everything(seed)
233
+ batch_size = samples
234
+ with torch.no_grad():
235
+ noise = None
236
+ c = model.get_learned_conditioning([prompt])
237
+ unconditional_c = torch.zeros_like(c)
238
+ sample, _ = sampler.sample(
239
+ S=steps,
240
+ batch_size=batch_size,
241
+ shape=shape,
242
+ verbose=False,
243
+ x_T = noise,
244
+ conditioning = c.repeat(batch_size, 1, 1),
245
+ unconditional_guidance_scale=scale,
246
+ unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  )
248
+ decode_res = model.decode_first_stage(sample)
249
+
250
+ big_video_list = []
251
+
252
+ global_info['decode_res'] = decode_res
253
+
254
+ for b in range(batch_size):
255
+ def render_img(v):
256
+ rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
257
+ decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
258
+ )
259
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
260
+ rgb_sample = np.stack(
261
+ [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
262
+ )
263
+ rgb_sample = add_text(rgb_sample, str(b))
264
+ return rgb_sample
265
+
266
+ view_num = len(batch_rays_list)
267
+ video_list = []
268
+ for v in tqdm.tqdm(range(view_num//8*3, view_num//8*5, 2)):
269
+ rgb_sample = render_img(v)
270
+ video_list.append(rgb_sample)
271
+ big_video_list.append(video_list)
272
+ # if batch_size == 2:
273
+ # cat_video_list = [
274
+ # np.concatenate([big_video_list[j][i] for j in range(len(big_video_list))], 1) \
275
+ # for i in range(len(big_video_list[0]))
276
+ # ]
277
+ # elif batch_size > 2:
278
+ # if batch_size == 3:
279
+ # big_video_list.append(
280
+ # [np.zeros_like(f) for f in big_video_list[0]]
281
+ # )
282
+ # cat_video_list = [
283
+ # np.concatenate([
284
+ # np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
285
+ # np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
286
+ # ], 0) \
287
+ # for i in range(len(big_video_list[0]))
288
+ # ]
289
+ # else:
290
+ # cat_video_list = big_video_list[0]
291
+
292
+ for _ in range(4 - batch_size):
293
+ big_video_list.append(
294
+ [np.zeros_like(f) + 255 for f in big_video_list[0]]
295
+ )
296
+ cat_video_list = [
297
+ np.concatenate([
298
+ np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
299
+ np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
300
+ ], 0) \
301
+ for i in range(len(big_video_list[0]))
302
+ ]
303
+
304
+ path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
305
+ imageio.mimwrite(path, np.stack(cat_video_list, 0))
306
 
307
  return global_info, path
308
 
309
  def infer_stage2(prompt, selection, seed, global_info, iters):
310
  prompt = prompt.replace('/', '')
311
+ with torch.cuda.device(1):
312
+ mesh_path = marching_cube(int(selection), prompt, global_info)
313
  mesh_name = mesh_path.split('/')[-1][:-4]
314
  # if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
315
  # print(if2_cmd)