File size: 3,783 Bytes
8c1bf05
 
 
 
 
 
 
 
 
 
 
cb0c99a
8c1bf05
 
 
 
 
 
 
 
 
cb0c99a
 
 
 
 
 
 
 
8c1bf05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb0c99a
8c1bf05
 
 
 
 
 
 
 
cb0c99a
 
 
 
 
8c1bf05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb0c99a
8c1bf05
 
cb0c99a
 
8c1bf05
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

import os
import json
import random
import argparse
import soundfile as sf
import numpy as np

import torch
from diffusers import DDPMScheduler
from pico_model import PicoDiffusion, build_pretrained_models
from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
class dotdict(dict):
    """dict subclass whose keys can also be read, set and deleted as attributes.

    ``d.key`` behaves like ``d.get("key")``: a missing key reads as ``None``
    instead of raising.  Missing *dunder* names (``__foo__``) are the
    exception: they raise ``AttributeError``.  Otherwise, protocol probes made
    by ``copy``, ``pickle``, ``hasattr`` etc. would be answered with ``None``
    and the caller would end up trying to call ``None``.

    Keys that collide with ``dict`` methods (``items``, ``get``, ...) are only
    reachable via ``d[key]``, because ``__getattr__`` runs only after normal
    attribute lookup fails.  Deleting a missing attribute raises ``KeyError``
    (it maps to ``dict.__delitem__``).
    """

    def __getattr__(self, name):
        # Called only when normal attribute lookup has already failed.
        if name.startswith("__") and name.endswith("__") and name not in self:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return dict.get(self, name)

    # Attribute assignment / deletion write straight through to the mapping.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

def parse_args(argv=None):
    """Parse command-line options for text-to-audio inference.

    Args:
        argv: Optional list of argument strings.  When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, exactly as before; passing a list
            lets callers and tests run the parser programmatically.

    Returns:
        argparse.Namespace with the parsed options plus two derived paths:
        ``original_args`` (``<exp_path>/summary.jsonl``) and
        ``diffusion_pt`` (``<exp_path>/diffusion.pt``).
    """
    parser = argparse.ArgumentParser(description="Inference for text to audio generation task.")
    parser.add_argument(
        "--text", '-t', type=str, default="spraying two times then gunshot three times.",
        help="free-text caption."
    )
    parser.add_argument(
        "--timestamp_caption", '-c', type=str,
        default=None,
        # Example value:
        #   "spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."
        help="timestamp caption, formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'."
    )
    parser.add_argument(
        "--exp_path", '-exp', type=str, default="/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/pico_model",
        help="Path for experiment."
    )
    parser.add_argument(
        "--freeze_text_encoder_ckpt", type=str, default='/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/huggingface/ckpts/laion_clap/630k-audioset-best.pt',
        help="Path for clap."
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="seed.",
    )

    args = parser.parse_args(argv)
    # Files expected inside the experiment directory.
    args.original_args = os.path.join(args.exp_path, "summary.jsonl")
    args.diffusion_pt = os.path.join(args.exp_path, "diffusion.pt")
    return args

def _load_train_args(path):
    """Parse the first line of ``path`` as JSON and wrap the result in a dotdict."""
    with open(path, encoding="utf-8") as f:
        first_line = f.readline()
    if not first_line.strip():
        raise ValueError(f"{path!r} is empty; expected a JSON object on its first line")
    return dotdict(json.loads(first_line))


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs; make cuDNN deterministic when CUDA is present."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _safe_filename(caption):
    """Return ``caption`` with path separators replaced so it names a single file."""
    for sep in (os.sep, os.altsep):
        if sep:
            caption = caption.replace(sep, "_")
    return caption


def main():
    """Generate one clip from a caption and write it as a 16 kHz PCM_16 WAV file.

    Steps:
        1. Parse CLI args and read the training options from
           ``<exp_path>/summary.jsonl``.
        2. Seed the RNGs.
        3. If ``--timestamp_caption`` was not given, derive it from ``--text``
           via ``preprocess_gemini``.
        4. Load the VAE/STFT and PicoDiffusion onto CUDA.
        5. Sample latents, decode them to a waveform and write
           ``<caption>.wav`` under a fixed output directory.
    """
    args = parse_args()
    train_args = _load_train_args(args.original_args)
    _seed_everything(args.seed)

    # Step1: preprocess via llm
    if args.timestamp_caption is None:
        # Alternative backend: preprocess_gpt(args.text)
        args.timestamp_caption = preprocess_gemini(args.text)

    # Load Models #
    print("------Load model")
    name = "audioldm-s-full"
    vae, stft = build_pretrained_models(name)
    vae, stft = vae.cuda(), stft.cuda()
    model = PicoDiffusion(
        scheduler_name=train_args.scheduler_name,
        unet_model_config_path=train_args.unet_model_config,
        snr_gamma=train_args.snr_gamma,
        freeze_text_encoder_ckpt=args.freeze_text_encoder_ckpt,
        diffusion_pt=args.diffusion_pt,
    ).cuda().eval()
    scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")

    # Generate #
    sample_rate = 16000
    num_steps, guidance, num_samples, audio_len = 200, 3.0, 1, sample_rate * 10  # 10 s clip
    output_dir = os.path.join("/hpc_stor03/sjtu_home/zeyu.xie/workspace/controllable_audio_generation/synthesized",
        f"huggingface_demo_steps-{num_steps}_guidance-{guidance}_samples-{num_samples}")
    os.makedirs(output_dir, exist_ok=True)
    # The caption may come from an LLM; strip separators so the file cannot
    # land in a missing subdirectory or outside output_dir.
    output_path = os.path.join(output_dir, f"{_safe_filename(args.timestamp_caption)}.wav")

    print("------Diffusion begin!")
    with torch.no_grad():
        latents = model.demo_inference(args.timestamp_caption, scheduler, num_steps, guidance, num_samples, disable_progress=True)
        mel = vae.decode_first_stage(latents)
        wave = vae.decode_to_waveform(mel)
        # wave[0] is presumably the first generated sample (batch-first) — TODO confirm;
        # it is trimmed to audio_len samples.
        sf.write(output_path, wave[0][:audio_len], samplerate=sample_rate, subtype='PCM_16')
    print(f"------Write to files to {output_path}")

# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()