fffiloni committed on
Commit
921f183
1 Parent(s): 15b5e49

Update inference_from_video.py

Browse files
Files changed (1) hide show
  1. inference_from_video.py +9 -18
inference_from_video.py CHANGED
@@ -65,16 +65,16 @@ def parse_args():
65
  help="How many test instances to evaluate.",
66
  )
67
  parser.add_argument(
68
- "--sample_rate", type=int, default=-1,
69
- help="Sample rate for audio output."
70
  )
71
  parser.add_argument(
72
  "--save_dir", type=str, default="./outputs/tmp",
73
- help="Output save directory"
74
  )
75
  parser.add_argument(
76
  "--data_path", type=str, default="data/video_processed/video_gt_augment",
77
- help="Inference data path"
78
  )
79
 
80
  args = parser.parse_args()
@@ -183,7 +183,7 @@ def main():
183
  mel = vae.decode_first_stage(latents)
184
  wave = vae.decode_to_waveform(mel)
185
 
186
- all_outputs += [item.cpu().numpy() for item in wave] # Ensure wave is on CPU and in numpy format
187
 
188
  # Save #
189
  exp_id = str(int(time.time()))
@@ -194,12 +194,7 @@ def main():
194
  output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}_augment".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate)
195
  os.makedirs(output_dir, exist_ok=True)
196
  for j, wav in enumerate(all_outputs):
197
- file_path = "{}/{}".format(output_dir, wavname[j])
198
- try:
199
- sf.write(file_path, wav, samplerate=sample_rate)
200
- print(f"Saved {file_path}")
201
- except Exception as e:
202
- print(f"Error saving {file_path}: {e}")
203
 
204
  else:
205
  for i in range(num_samples):
@@ -213,12 +208,8 @@ def main():
213
  ranked_wavs_for_text = [wavs_for_text[r] for r in rank]
214
 
215
  for i, wav in enumerate(ranked_wavs_for_text):
216
- file_path = "{}/{}".format(output_dir, wavname[k])
217
- try:
218
- sf.write(file_path, wav, samplerate=sample_rate)
219
- print(f"Saved {file_path}")
220
- except Exception as e:
221
- print(f"Error saving {file_path}: {e}")
222
 
223
  if __name__ == "__main__":
224
- main()
 
65
  help="How many test instances to evaluate.",
66
  )
67
  parser.add_argument(
68
+ "--sample_rate", type=int, default=48000,
69
+ help="Sample rate for audio output.",
70
  )
71
  parser.add_argument(
72
  "--save_dir", type=str, default="./outputs/tmp",
73
+ help="output save dir"
74
  )
75
  parser.add_argument(
76
  "--data_path", type=str, default="data/video_processed/video_gt_augment",
77
+ help="inference data path"
78
  )
79
 
80
  args = parser.parse_args()
 
183
  mel = vae.decode_first_stage(latents)
184
  wave = vae.decode_to_waveform(mel)
185
 
186
+ all_outputs += [item for item in wave]
187
 
188
  # Save #
189
  exp_id = str(int(time.time()))
 
194
  output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}_augment".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate)
195
  os.makedirs(output_dir, exist_ok=True)
196
  for j, wav in enumerate(all_outputs):
197
+ sf.write("{}/{}".format(output_dir, wavname[j]), wav, samplerate=sample_rate)
 
 
 
 
 
198
 
199
  else:
200
  for i in range(num_samples):
 
208
  ranked_wavs_for_text = [wavs_for_text[r] for r in rank]
209
 
210
  for i, wav in enumerate(ranked_wavs_for_text):
211
+ output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}/rank_{}".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate, i+1)
212
+ sf.write("{}/{}".format(output_dir, wavname[k]), wav, samplerate=sample_rate)
 
 
 
 
213
 
214
  if __name__ == "__main__":
215
+ main()