shisheng7 commited on
Commit
0216866
1 Parent(s): 2477c25

fix cuda bug

Browse files
configs/inference/inference.yaml CHANGED
@@ -101,7 +101,7 @@ enable_zero_snr: True
101
  stage1_ckpt_dir: "./exp_output/stage1/"
102
 
103
  single_inference_times: 10
104
- inference_steps: 40
105
  cfg_scale: 3.5
106
 
107
  seed: 42
 
101
  stage1_ckpt_dir: "./exp_output/stage1/"
102
 
103
  single_inference_times: 10
104
+ inference_steps: 20 # improve inference speed
105
  cfg_scale: 3.5
106
 
107
  seed: 42
scripts/inference.py CHANGED
@@ -347,7 +347,8 @@ def inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_d
347
 
348
  times = audio_emb.shape[0] // clip_length
349
  tensor_result = []
350
- generator = torch.manual_seed(42)
 
351
  for t in range(times):
352
  print(f"[{t+1}/{times}]")
353
 
 
347
 
348
  times = audio_emb.shape[0] // clip_length
349
  tensor_result = []
350
+ # generator = torch.manual_seed(42)
351
+ generator = torch.cuda.manual_seed_all(42) # use cuda seed all
352
  for t in range(times):
353
  print(f"[{t+1}/{times}]")
354