shisheng7
commited on
Commit
•
0216866
1
Parent(s):
2477c25
fix cuda bug
Browse files- configs/inference/inference.yaml +1 -1
- scripts/inference.py +2 -1
configs/inference/inference.yaml
CHANGED
@@ -101,7 +101,7 @@ enable_zero_snr: True
|
|
101 |
stage1_ckpt_dir: "./exp_output/stage1/"
|
102 |
|
103 |
single_inference_times: 10
|
104 |
-
inference_steps:
|
105 |
cfg_scale: 3.5
|
106 |
|
107 |
seed: 42
|
|
|
101 |
stage1_ckpt_dir: "./exp_output/stage1/"
|
102 |
|
103 |
single_inference_times: 10
|
104 |
+
inference_steps: 20 # improve inference speed
|
105 |
cfg_scale: 3.5
|
106 |
|
107 |
seed: 42
|
scripts/inference.py
CHANGED
@@ -347,7 +347,8 @@ def inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_d
|
|
347 |
|
348 |
times = audio_emb.shape[0] // clip_length
|
349 |
tensor_result = []
|
350 |
-
generator = torch.manual_seed(42)
|
|
|
351 |
for t in range(times):
|
352 |
print(f"[{t+1}/{times}]")
|
353 |
|
|
|
347 |
|
348 |
times = audio_emb.shape[0] // clip_length
|
349 |
tensor_result = []
|
350 |
+
# generator = torch.manual_seed(42)
|
351 |
+
generator = torch.cuda.manual_seed_all(42) # use cuda seed all
|
352 |
for t in range(times):
|
353 |
print(f"[{t+1}/{times}]")
|
354 |
|