"""Inference script for onset-controllable text-to-audio generation.

Loads a trained controllable-diffusion checkpoint plus the pretrained
AudioLDM VAE/vocoder, synthesizes audio for every prompt in a test JSON
file, and, when more than one sample per prompt is requested, keeps the
candidate with the highest CLAP text-audio similarity.
"""

import os
import json
import random
import argparse

import numpy as np
import soundfile as sf
import librosa
import torch
from tqdm import tqdm

import laion_clap
from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
from sklearn.metrics.pairwise import cosine_similarity
from diffusers import DDPMScheduler
from datasets import load_dataset
from torch.utils.data import DataLoader

import models.controllable_diffusion as ConDiffusion
import models.controllable_dataset as ConDataset

# Root directory under which synthesized audio is written. Assumed to be
# supplied via the environment; falls back to the current directory.
CONTROLLABLE_PATH = os.environ.get("CONTROLLABLE_PATH", ".")


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
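# Usage sketch (hypothetical key): dotdict({"snr_gamma": 5.0}).snr_gamma reads
# like attribute access. Because __getattr__ aliases dict.get, a missing key
# silently returns None instead of raising AttributeError.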


def parse_args():
    parser = argparse.ArgumentParser(description="Inference for the text-to-audio generation task.")
    parser.add_argument(
        "--exp_path", '-exp', type=str, required=True,
        help="Path to the experiment directory; other paths are resolved relative to it."
    )
    parser.add_argument(
        "--test_file", type=str, default="data/meta_data/test-onoff-control_multi-event.json",
        help="Path to the test file."
    )
    parser.add_argument(
        "--original_args", type=str, default="summary.jsonl",
        help="Path to the summary jsonl file saved during training."
    )
    parser.add_argument(
        "--model_pt", type=str, default="best.pt",
        help="Path to the saved model checkpoint file."
    )
    parser.add_argument(
        "--num_steps", type=int, default=200,
        help="Number of denoising steps for generation."
    )
    parser.add_argument(
        "--guidance", '-g', type=float, default=1.0,
        help="Guidance scale for classifier-free guidance."
    )
    parser.add_argument(
        "--batch_size", '-b', type=int, default=32,
        help="Batch size for generation."
    )
    parser.add_argument(
        "--num_samples", type=int, default=1,
        help="Number of samples to generate per prompt."
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Random seed."
    )

    args = parser.parse_args()

    # The training summary and checkpoint live inside the experiment directory.
    args.original_args = os.path.join(args.exp_path, args.original_args)
    args.model_pt = os.path.join(args.exp_path, args.model_pt)
    return args
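# Example invocation (script name, paths, and flag values are illustrative):
#   python inference.py -exp exp/onset_control -g 3.0 --num_samples 5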


def main():
    args = parse_args()
    # The first line of the training summary jsonl holds the run configuration.
    with open(args.original_args) as f:
        train_args = dotdict(json.loads(f.readline()))

    # Seed all RNGs and pin cuDNN for reproducible sampling.
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # load_dataset infers the builder ("json") from the test file's extension.
    extension = args.test_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files={"test": args.test_file})
    test_dataset = getattr(ConDataset, train_args.dataset_class)(raw_datasets["test"], train_args)
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=test_dataset.collate_fn)
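    # Each collated batch is unpacked later as a 6-tuple:
    # (idx, onset, event_info, <unused>, caption, onset_str). This layout is
    # inferred from the unpacking sites below, not documented upstream.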

    print("\n------Load model")
    # Pretrained AudioLDM components: the VAE decodes latents to mel
    # spectrograms and carries the vocoder used for waveform synthesis.
    name = "audioldm-s-full"
    vae, stft = ConDiffusion.build_pretrained_models(name)
    vae, stft = vae.cuda(), stft.cuda()
    print(train_args.model_class)
    model = getattr(ConDiffusion, train_args.model_class)(
        scheduler_name=train_args.scheduler_name,
        unet_model_config_path=train_args.unet_model_config,
        snr_gamma=train_args.snr_gamma
    ).cuda().eval()

    # Sampling device, read off the VAE parameters (both the VAE and the
    # diffusion model were moved to CUDA above).
    device = next(vae.parameters()).device
    model.load_state_dict(torch.load(args.model_pt, map_location=device))
    scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
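    # model.inference is assumed to set the scheduler's timesteps to num_steps
    # internally, so inference can use far fewer denoising steps than the
    # schedule the pretrained DDPM config was trained with.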

    num_steps, guidance, batch_size, num_samples = args.num_steps, args.guidance, args.batch_size, args.num_samples
    audio_len = 16000 * 10  # 10 seconds of audio at 16 kHz
    output_dir = os.path.join(
        CONTROLLABLE_PATH,
        f"synthesized/{'-'.join(args.model_pt.split('/')[-3:-1])}_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}_{args.test_file.split('/')[-1].split('.')[0]}/")
    os.makedirs(output_dir, exist_ok=True)
    print(f"------Writing files to {output_dir}")

    print("------Diffusion begin!")
    if args.num_samples == 1:
        # One sample per prompt: decode and write each waveform directly.
        for batch in tqdm(test_dataloader):
            idx, onset, event_info, _, _, onset_str = batch
            with torch.no_grad():
                latents = model.inference({"onset": onset.to(device), "event_info": event_info.to(device)}, scheduler, num_steps, guidance, num_samples, disable_progress=True)
                mel = vae.decode_first_stage(latents)
                wave = vae.decode_to_waveform(mel)
            for j, wav in enumerate(wave):
                sf.write(f"{output_dir}/{idx[j]}--{onset_str[j]}.wav", wav[:audio_len], samplerate=16000, subtype='PCM_16')
    else:
        # Several candidates per prompt: rerank with CLAP and keep the
        # candidate whose audio embedding best matches the caption.
        print("Clap scorer filter")
        clap_scorer = laion_clap.CLAP_Module(enable_fusion=False)
        # Environment-specific checkpoint location; adjust to your install.
        ckpt_path = 'miniconda3/envs/py3.10.11/lib/python3.10/site-packages/laion_clap/630k-audioset-best.pt'
        ckpt = clap_load_state_dict(ckpt_path, skip_params=True)
        # Drop the position_ids buffer (absent in newer transformers versions)
        # and prefix the keys to match CLAP_Module's "model." namespace.
        del_parameter_key = ["text_branch.embeddings.position_ids"]
        ckpt = {"model." + k: v for k, v in ckpt.items() if k not in del_parameter_key}
        clap_scorer.load_state_dict(ckpt)
        for batch in tqdm(test_dataloader):
            idx, onset, event_info, _, caption, onset_str = batch
            with torch.no_grad():
                latents = model.inference({"onset": onset.to(device), "event_info": event_info.to(device)}, scheduler, num_steps, guidance, num_samples, disable_progress=True)
                mel = vae.decode_first_stage(latents)
                wave = vae.decode_to_waveform(mel)
            # Iterate over the actual batch; the last one may be smaller
            # than batch_size.
            for j in range(len(idx)):
                # CLAP expects at least two texts per call, so pad with an
                # empty string and keep only the first embedding.
                text_embed = clap_scorer.get_text_embedding([caption[j], ""], use_tensor=False)[:1]
                best_idx, best_score = 0, float('-inf')
                for candidate_idx in range(num_samples):
                    # Convert 16-bit PCM to [-1, 1] floats and resample to the
                    # 48 kHz rate CLAP expects.
                    wav_48k = librosa.resample(wave[j * num_samples + candidate_idx].astype(np.float64)[:audio_len] / 32768, orig_sr=16000, target_sr=48000).reshape(1, -1)
                    audio_embed = clap_scorer.get_audio_embedding_from_data(x=wav_48k)
                    pair_similarity = cosine_similarity(audio_embed, text_embed)[0][0]
                    if pair_similarity > best_score:
                        best_score = pair_similarity
                        best_idx = candidate_idx
                sf.write(f"{output_dir}/{idx[j]}--{onset_str[j]}.wav", wave[j * num_samples + best_idx][:audio_len], samplerate=16000, subtype='PCM_16')
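    # Reranking note: CLAP embeddings are typically unit-normalized, in which
    # case the cosine similarity above reduces to a dot product in [-1, 1];
    # ties are broken in favor of the earliest candidate by the strict ">".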


if __name__ == "__main__":
    main()