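# NOTE(review): the three lines below appear to be web-page residue captured
# with this file (an avatar caption, a commit message and a commit hash);
# they are kept as comments so the module parses — confirm and remove.
# ing0's picture
# readme
# 5fa1afb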
import torch
import torchaudio
from einops import rearrange
import argparse
import json
import os
from tqdm import tqdm
import random
import numpy as np
import time
import io
import pydub
from diffrhythm.infer.infer_utils import (
get_reference_latent,
get_lrc_token,
get_style_prompt,
prepare_model,
get_negative_style_prompt
)
def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128,
                 *, downsampling_ratio=2048, io_channels=2):
    """Decode VAE latents into a waveform, optionally in overlapping chunks.

    Args:
        latents: Latent tensor indexed as ``[batch, channels, time]`` (batch is
            dim 0, latent time is dim 2).
        vae_model: Object exposing ``decode_export(latents)``.
        chunked: If False, decode the whole latent sequence in a single call.
            If True, decode windows of ``chunk_size`` latent frames that
            overlap by ``overlap`` frames and stitch them together, discarding
            ``overlap // 2`` frames on each side of every seam so chunk-edge
            artifacts are not kept.
        overlap: Number of latent frames shared by neighbouring chunks.
        chunk_size: Number of latent frames decoded per chunk.
        downsampling_ratio: Audio samples produced per latent frame (keyword
            only; used by the chunked path to place chunks in time).
        io_channels: Number of audio channels in the decoded output (keyword
            only; used by the chunked path to size the output buffer).

    Returns:
        Whatever ``vae_model.decode_export`` returns when ``chunked`` is False.
        In chunked mode, a float tensor of shape
        ``[batch, io_channels, time * downsampling_ratio]`` on
        ``latents.device``.

    Raises:
        ValueError: In chunked mode, if ``chunk_size`` is not positive or
            ``overlap`` is not in ``[0, chunk_size)``. Without this check a
            too-large overlap makes the hop negative, so only the final chunk
            would be decoded and the rest of the output silently left as zeros.
    """
    if not chunked:
        # Default behavior: decode the entire latent sequence in parallel.
        return vae_model.decode_export(latents)

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), "
            f"got {overlap}"
        )

    hop_size = chunk_size - overlap
    batch_size = latents.shape[0]
    total_size = latents.shape[2]

    # Regular windows start every hop_size frames while they still fit whole.
    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    chunks = [latents[:, :, i:i + chunk_size] for i in starts]
    # If the regular windows do not end exactly at total_size (or none fit at
    # all), add a final window flush with the end so the tail is decoded.
    if not starts or starts[-1] + chunk_size != total_size:
        chunks.append(latents[:, :, -chunk_size:])
    chunks = torch.stack(chunks)
    num_chunks = chunks.shape[0]

    samples_per_latent = downsampling_ratio
    y_size = total_size * samples_per_latent
    # Output buffer that the decoded chunks are pasted into.
    y_final = torch.zeros((batch_size, io_channels, y_size), device=latents.device)
    # Samples trimmed from each side of a seam (half the overlap).
    ol = (overlap // 2) * samples_per_latent

    for i, x_chunk in enumerate(chunks):
        y_chunk = vae_model.decode_export(x_chunk)
        is_first = i == 0
        is_last = i == num_chunks - 1

        # Figure out where this chunk's audio goes along the time axis.
        if is_last:
            # The final chunk is always aligned to the end of the output.
            t_end = y_size
            t_start = t_end - y_chunk.shape[2]
        else:
            t_start = i * hop_size * samples_per_latent
            t_end = t_start + chunk_size * samples_per_latent

        # Drop the overlapping edges, except at the very start and very end.
        chunk_start = 0
        chunk_end = y_chunk.shape[2]
        if not is_first:
            t_start += ol
            chunk_start += ol
        if not is_last:
            t_end -= ol
            chunk_end -= ol

        y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
    return y_final
def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type):
    """Generate a song with the CFM model and return it as audio data.

    Latents are sampled from ``cfm_model`` (classifier-free guidance strength
    4.0), decoded with ``vae_model`` via ``decode_audio``, flattened from a
    batch into one sequence, and peak-normalized to [-1, 1].

    Args:
        cfm_model: Model exposing ``sample(...)`` returning ``(latents, _)``.
        vae_model: Model passed to ``decode_audio`` for decoding.
        cond, text, duration, style_prompt, negative_style_prompt, steps,
        sway_sampling_coef, start_time: Forwarded unchanged to
            ``cfm_model.sample``.
        file_type: ``'wav'`` for raw samples, ``'mp3'`` for MP3 bytes; any
            other value produces Ogg bytes.

    Returns:
        For ``'wav'``: a tuple ``(44100, audio)`` where ``audio`` is a float32
        NumPy array of shape ``[samples, channels]``. Otherwise the bytes of
        the encoded file (320 kbit/s).
    """
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=steps,
            cfg_strength=4.0,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time
        )
        generated = generated.to(torch.float32)
        latent = generated.transpose(1, 2)  # [b d t]
        output = decode_audio(latent, vae_model, chunked=False)

        # Rearrange audio batch to a single sequence
        output = rearrange(output, "b d n -> d (b n)")
        output = output.to(torch.float32)
        peak = torch.max(torch.abs(output))
        # Guard against all-zero output: dividing by a zero peak would fill
        # the result with NaN, which clamp() does not remove.
        if peak > 0:
            output = output.div(peak)
        output_tensor = output.clamp(-1, 1).cpu()
        # Transpose to [samples, channels] as expected by the consumers below.
        output_np = output_tensor.numpy().T.astype(np.float32)

    if file_type == 'wav':
        return (44100, output_np)

    # Convert to 16-bit PCM. After peak normalization a sample can be exactly
    # +1.0, and 1.0 * 2**15 = 32768 does not fit in int16 (it would wrap to
    # -32768, an audible click), so clip to the int16 range before casting.
    pcm = np.clip(output_np * 2**15, -2**15, 2**15 - 1).astype(np.int16)
    # tobytes() serializes in C order, i.e. interleaved frames, as pydub expects.
    song = pydub.AudioSegment(pcm.tobytes(), frame_rate=44100, sample_width=2, channels=pcm.shape[1])
    export_format = "mp3" if file_type == 'mp3' else "ogg"
    with io.BytesIO() as buffer:
        song.export(buffer, format=export_format, bitrate="320k")
        return buffer.getvalue()
if __name__ == "__main__":
    # Supported --audio-length values (seconds) mapped to the latent frame
    # count passed to the model as `duration`. Using a table plus argparse
    # `choices` rejects unsupported lengths up front instead of failing later
    # with an undefined `max_frames`.
    frames_by_length = {95: 2048, 285: 6144}

    parser = argparse.ArgumentParser()
    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc")  # lyrics of target song
    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3")  # reference audio as style prompt for target song
    parser.add_argument('--audio-length', type=int, default=95, choices=sorted(frames_by_length))  # length of target song
    parser.add_argument('--output-dir', type=str, default="example/output")
    # `inference` requires these arguments; omitting them raised a TypeError.
    # NOTE(review): 32 steps and a sway coefficient of None (presumably
    # "disabled" inside cfm.sample) are assumed defaults — confirm.
    parser.add_argument('--steps', type=int, default=32)
    parser.add_argument('--sway-sampling-coef', type=float, default=None)
    args = parser.parse_args()

    device = 'cuda'
    max_frames = frames_by_length[args.audio_length]

    cfm, tokenizer, muq, vae = prepare_model(device)

    # Lyrics files may contain non-ASCII text; do not rely on the locale codec.
    with open(args.lrc_path, 'r', encoding='utf-8') as f:
        lrc = f.read()
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)

    style_prompt = get_style_prompt(muq, args.ref_audio_path)
    negative_style_prompt = get_negative_style_prompt(device)
    latent_prompt = get_reference_latent(device, max_frames)

    s_t = time.time()
    # With file_type='wav', inference returns (sample_rate, [samples, channels]).
    sample_rate, audio_np = inference(cfm_model=cfm,
                                      vae_model=vae,
                                      cond=latent_prompt,
                                      text=lrc_prompt,
                                      duration=max_frames,
                                      style_prompt=style_prompt,
                                      negative_style_prompt=negative_style_prompt,
                                      steps=args.steps,
                                      sway_sampling_coef=args.sway_sampling_coef,
                                      start_time=start_time,
                                      file_type='wav'
                                      )
    e_t = time.time() - s_t
    print(f"inference cost {e_t} seconds")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "output.wav")
    # torchaudio.save expects a [channels, time] tensor.
    waveform = torch.from_numpy(np.ascontiguousarray(audio_np.T))
    torchaudio.save(output_path, waveform, sample_rate=sample_rate)