fffiloni commited on
Commit
15b5e49
1 Parent(s): 4c7d0b3

Update inference_from_video.py

Browse files
Files changed (1) hide show
  1. inference_from_video.py +16 -7
inference_from_video.py CHANGED
@@ -66,15 +66,15 @@ def parse_args():
66
  )
67
  parser.add_argument(
68
  "--sample_rate", type=int, default=-1,
69
- help="How many test instances to evaluate.",
70
  )
71
  parser.add_argument(
72
  "--save_dir", type=str, default="./outputs/tmp",
73
- help="output save dir"
74
  )
75
  parser.add_argument(
76
  "--data_path", type=str, default="data/video_processed/video_gt_augment",
77
- help="inference data path"
78
  )
79
 
80
  args = parser.parse_args()
@@ -183,7 +183,7 @@ def main():
183
  mel = vae.decode_first_stage(latents)
184
  wave = vae.decode_to_waveform(mel)
185
 
186
- all_outputs += [item for item in wave]
187
 
188
  # Save #
189
  exp_id = str(int(time.time()))
@@ -194,7 +194,12 @@ def main():
194
  output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}_augment".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate)
195
  os.makedirs(output_dir, exist_ok=True)
196
  for j, wav in enumerate(all_outputs):
197
- sf.write("{}/{}".format(output_dir, wavname[j]), wav, samplerate=sample_rate)
 
 
 
 
 
198
 
199
  else:
200
  for i in range(num_samples):
@@ -208,8 +213,12 @@ def main():
208
  ranked_wavs_for_text = [wavs_for_text[r] for r in rank]
209
 
210
  for i, wav in enumerate(ranked_wavs_for_text):
211
- output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}/rank_{}".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate, i+1)
212
- sf.write("{}/{}".format(output_dir, wavname[k]), wav, samplerate=sample_rate)
 
 
 
 
213
 
214
  if __name__ == "__main__":
215
  main()
 
66
  )
67
  parser.add_argument(
68
  "--sample_rate", type=int, default=-1,
69
+ help="Sample rate for audio output."
70
  )
71
  parser.add_argument(
72
  "--save_dir", type=str, default="./outputs/tmp",
73
+ help="Output save directory"
74
  )
75
  parser.add_argument(
76
  "--data_path", type=str, default="data/video_processed/video_gt_augment",
77
+ help="Inference data path"
78
  )
79
 
80
  args = parser.parse_args()
 
183
  mel = vae.decode_first_stage(latents)
184
  wave = vae.decode_to_waveform(mel)
185
 
186
+ all_outputs += [item.cpu().numpy() for item in wave] # Ensure wave is on CPU and in numpy format
187
 
188
  # Save #
189
  exp_id = str(int(time.time()))
 
194
  output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}_augment".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate)
195
  os.makedirs(output_dir, exist_ok=True)
196
  for j, wav in enumerate(all_outputs):
197
+ file_path = "{}/{}".format(output_dir, wavname[j])
198
+ try:
199
+ sf.write(file_path, wav, samplerate=sample_rate)
200
+ print(f"Saved {file_path}")
201
+ except Exception as e:
202
+ print(f"Error saving {file_path}: {e}")
203
 
204
  else:
205
  for i in range(num_samples):
 
213
  ranked_wavs_for_text = [wavs_for_text[r] for r in rank]
214
 
215
  for i, wav in enumerate(ranked_wavs_for_text):
216
+ file_path = "{}/{}".format(output_dir, wavname[k])
217
+ try:
218
+ sf.write(file_path, wav, samplerate=sample_rate)
219
+ print(f"Saved {file_path}")
220
+ except Exception as e:
221
+ print(f"Error saving {file_path}: {e}")
222
 
223
  if __name__ == "__main__":
224
  main()