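# Command-line inference for DiffRhythm: tokenize an .lrc lyrics file, build a
# style prompt from a reference audio clip, sample song latents with the CFM
# model, and decode them to a stereo 44.1 kHz waveform with the VAE.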
import torch
import torchaudio
from einops import rearrange
import argparse
import json
import os
from tqdm import tqdm
import random
import numpy as np
import time
import io
import pydub

from diffrhythm.infer.infer_utils import (
    get_reference_latent,
    get_lrc_token,
    get_style_prompt,
    prepare_model,
    get_negative_style_prompt
)


def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
    downsampling_ratio = 2048
    io_channels = 2
    if not chunked:
        # default behavior: decode the entire latent in parallel
        return vae_model.decode_export(latents)
    else:
        # chunked decoding
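        # Example, assuming the defaults above: chunk_size=128 and overlap=32 give
        # hop_size=96, so a 2048-frame latent is split into 21 overlapping chunks;
        # each latent frame maps to downsampling_ratio=2048 audio samples, i.e.
        # 2048 * 2048 = 4,194,304 samples (~95 s at 44.1 kHz).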
        hop_size = chunk_size - overlap
        total_size = latents.shape[2]
        batch_size = latents.shape[0]
        chunks = []
        i = 0
        for i in range(0, total_size - chunk_size + 1, hop_size):
            chunk = latents[:, :, i:i + chunk_size]
            chunks.append(chunk)
        if i + chunk_size != total_size:
            # final chunk
            chunk = latents[:, :, -chunk_size:]
            chunks.append(chunk)
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]
        # samples_per_latent is just the downsampling ratio
        samples_per_latent = downsampling_ratio
        # Create an empty waveform; we will populate it with chunks as we decode them
        y_size = total_size * samples_per_latent
        y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
        for i in range(num_chunks):
            x_chunk = chunks[i, :]
            # decode the chunk
            y_chunk = vae_model.decode_export(x_chunk)
            # figure out where to put the audio along the time domain
            if i == num_chunks - 1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * samples_per_latent
                t_end = t_start + chunk_size * samples_per_latent
            # remove the edges of the overlaps
            ol = (overlap // 2) * samples_per_latent
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks - 1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
        return y_final


def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type):
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=steps,
            cfg_strength=4.0,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time
        )

        generated = generated.to(torch.float32)
        latent = generated.transpose(1, 2)  # [b d t]

        output = decode_audio(latent, vae_model, chunked=False)

        # Rearrange audio batch to a single sequence
        output = rearrange(output, "b d n -> d (b n)")
        output_tensor = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
        output_np = output_tensor.numpy().T.astype(np.float32)

        if file_type == 'wav':
            return (44100, output_np)
        else:
            buffer = io.BytesIO()
            output_np = np.int16(output_np * 2**15)
            song = pydub.AudioSegment(output_np.tobytes(), frame_rate=44100, sample_width=2, channels=2)
            if file_type == 'mp3':
                song.export(buffer, format="mp3", bitrate="320k")
            else:
                song.export(buffer, format="ogg", bitrate="320k")
            return buffer.getvalue()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc")  # lyrics of the target song
    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3")  # reference audio used as the style prompt
    parser.add_argument('--audio-length', type=int, default=95)  # length of the target song in seconds (95 or 285)
    parser.add_argument('--output-dir', type=str, default="example/output")
    args = parser.parse_args()

    device = 'cuda'

    audio_length = args.audio_length
    if audio_length == 95:
        max_frames = 2048
    elif audio_length == 285:
        max_frames = 6144
    else:
        raise ValueError(f"Unsupported --audio-length {audio_length}; expected 95 or 285")
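    # Each latent frame decodes to 2048 audio samples at 44.1 kHz, so 2048 frames
    # correspond to roughly 95 s of audio and 6144 frames to roughly 285 s.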

    cfm, tokenizer, muq, vae = prepare_model(device)

    with open(args.lrc_path, 'r') as f:
        lrc = f.read()
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)

    style_prompt = get_style_prompt(muq, args.ref_audio_path)

    negative_style_prompt = get_negative_style_prompt(device)

    latent_prompt = get_reference_latent(device, max_frames)

    s_t = time.time()
    # steps, sway_sampling_coef and file_type are required by inference();
    # 32 / None / 'wav' are assumed defaults here. With file_type='wav' the
    # function returns a (sample_rate, numpy_audio) pair.
    sample_rate, generated_song = inference(cfm_model=cfm,
                                            vae_model=vae,
                                            cond=latent_prompt,
                                            text=lrc_prompt,
                                            duration=max_frames,
                                            style_prompt=style_prompt,
                                            negative_style_prompt=negative_style_prompt,
                                            steps=32,
                                            sway_sampling_coef=None,
                                            start_time=start_time,
                                            file_type='wav'
                                            )
    e_t = time.time() - s_t
    print(f"inference took {e_t:.2f} seconds")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "output.wav")
    # generated_song is float32 numpy audio of shape [n, 2]; convert it to a
    # [channels, samples] tensor for torchaudio.save
    torchaudio.save(output_path, torch.from_numpy(generated_song).T, sample_rate=sample_rate)
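
    # Example invocation (script name assumed; paths are the bundled defaults):
    #   python infer.py --lrc-path example/eg.lrc --ref-audio-path example/eg.mp3 \
    #       --audio-length 95 --output-dir example/output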