Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import time | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
import torch.nn as nn | |
from collections import OrderedDict | |
import json | |
from models.tta.autoencoder.autoencoder import AutoencoderKL | |
from models.tta.ldm.inference_utils.vocoder import Generator | |
from models.tta.ldm.audioldm import AudioLDM | |
from transformers import T5EncoderModel, AutoTokenizer | |
from diffusers import PNDMScheduler | |
import matplotlib.pyplot as plt | |
from scipy.io.wavfile import write | |
class AttrDict(dict): | |
def __init__(self, *args, **kwargs): | |
super(AttrDict, self).__init__(*args, **kwargs) | |
self.__dict__ = self | |
class AudioLDMInference: | |
def __init__(self, args, cfg): | |
self.cfg = cfg | |
self.args = args | |
self.build_autoencoderkl() | |
self.build_textencoder() | |
self.model = self.build_model() | |
self.load_state_dict() | |
self.build_vocoder() | |
self.out_path = self.args.output_dir | |
self.out_mel_path = os.path.join(self.out_path, "mel") | |
self.out_wav_path = os.path.join(self.out_path, "wav") | |
os.makedirs(self.out_mel_path, exist_ok=True) | |
os.makedirs(self.out_wav_path, exist_ok=True) | |
def build_autoencoderkl(self): | |
self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl) | |
self.autoencoder_path = self.cfg.model.autoencoder_path | |
checkpoint = torch.load(self.autoencoder_path, map_location="cpu") | |
self.autoencoderkl.load_state_dict(checkpoint["model"]) | |
self.autoencoderkl.cuda(self.args.local_rank) | |
self.autoencoderkl.requires_grad_(requires_grad=False) | |
self.autoencoderkl.eval() | |
def build_textencoder(self): | |
self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512) | |
self.text_encoder = T5EncoderModel.from_pretrained("t5-base") | |
self.text_encoder.cuda(self.args.local_rank) | |
self.text_encoder.requires_grad_(requires_grad=False) | |
self.text_encoder.eval() | |
def build_vocoder(self): | |
config_file = os.path.join(self.args.vocoder_config_path) | |
with open(config_file) as f: | |
data = f.read() | |
json_config = json.loads(data) | |
h = AttrDict(json_config) | |
self.vocoder = Generator(h).to(self.args.local_rank) | |
checkpoint_dict = torch.load( | |
self.args.vocoder_path, map_location=self.args.local_rank | |
) | |
self.vocoder.load_state_dict(checkpoint_dict["generator"]) | |
def build_model(self): | |
self.model = AudioLDM(self.cfg.model.audioldm) | |
return self.model | |
def load_state_dict(self): | |
self.checkpoint_path = self.args.checkpoint_path | |
checkpoint = torch.load(self.checkpoint_path, map_location="cpu") | |
self.model.load_state_dict(checkpoint["model"]) | |
self.model.cuda(self.args.local_rank) | |
def get_text_embedding(self): | |
text = self.args.text | |
prompt = [text] | |
text_input = self.tokenizer( | |
prompt, | |
max_length=self.tokenizer.model_max_length, | |
truncation=True, | |
padding="do_not_pad", | |
return_tensors="pt", | |
) | |
text_embeddings = self.text_encoder( | |
text_input.input_ids.to(self.args.local_rank) | |
)[0] | |
max_length = text_input.input_ids.shape[-1] | |
uncond_input = self.tokenizer( | |
[""] * 1, padding="max_length", max_length=max_length, return_tensors="pt" | |
) | |
uncond_embeddings = self.text_encoder( | |
uncond_input.input_ids.to(self.args.local_rank) | |
)[0] | |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | |
return text_embeddings | |
def inference(self): | |
text_embeddings = self.get_text_embedding() | |
print(text_embeddings.shape) | |
num_steps = self.args.num_steps | |
guidance_scale = self.args.guidance_scale | |
noise_scheduler = PNDMScheduler( | |
num_train_timesteps=1000, | |
beta_start=0.00085, | |
beta_end=0.012, | |
beta_schedule="scaled_linear", | |
skip_prk_steps=True, | |
set_alpha_to_one=False, | |
steps_offset=1, | |
prediction_type="epsilon", | |
) | |
noise_scheduler.set_timesteps(num_steps) | |
latents = torch.randn( | |
( | |
1, | |
self.cfg.model.autoencoderkl.z_channels, | |
80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)), | |
624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)), | |
) | |
).to(self.args.local_rank) | |
self.model.eval() | |
for t in tqdm(noise_scheduler.timesteps): | |
t = t.to(self.args.local_rank) | |
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. | |
latent_model_input = torch.cat([latents] * 2) | |
latent_model_input = noise_scheduler.scale_model_input( | |
latent_model_input, timestep=t | |
) | |
# print(latent_model_input.shape) | |
# predict the noise residual | |
with torch.no_grad(): | |
noise_pred = self.model( | |
latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings | |
) | |
# perform guidance | |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
noise_pred = noise_pred_uncond + guidance_scale * ( | |
noise_pred_text - noise_pred_uncond | |
) | |
# compute the previous noisy sample x_t -> x_t-1 | |
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample | |
# print(latents.shape) | |
latents_out = latents | |
print(latents_out.shape) | |
with torch.no_grad(): | |
mel_out = self.autoencoderkl.decode(latents_out) | |
print(mel_out.shape) | |
melspec = mel_out[0, 0].cpu().detach().numpy() | |
plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec) | |
self.vocoder.eval() | |
self.vocoder.remove_weight_norm() | |
with torch.no_grad(): | |
melspec = np.expand_dims(melspec, 0) | |
melspec = torch.FloatTensor(melspec).to(self.args.local_rank) | |
y = self.vocoder(melspec) | |
audio = y.squeeze() | |
audio = audio * 32768.0 | |
audio = audio.cpu().numpy().astype("int16") | |
write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio) | |