Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,014 Bytes
b96e750 9743771 b96e750 5fa1afb b96e750 ccebb03 b96e750 ccebb03 b96e750 9743771 b96e750 2a3c97e b96e750 2a3c97e 5fa1afb 9743771 b96e750 ccebb03 b96e750 ccebb03 b96e750 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import torch
import torchaudio
from einops import rearrange
import argparse
import json
import os
from tqdm import tqdm
import random
import numpy as np
import time
import io
import pydub
from diffrhythm.infer.infer_utils import (
get_reference_latent,
get_lrc_token,
get_style_prompt,
prepare_model,
get_negative_style_prompt
)
def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128,
                 downsampling_ratio=2048, io_channels=2):
    """Decode VAE latents into a waveform, optionally in overlapping chunks.

    Args:
        latents: Latent tensor; the batch is read from axis 0 and latent time
            from axis 2 (the chunked path slices ``latents[:, :, a:b]``).
        vae_model: Object exposing ``decode_export(latents)``.
        chunked: If False, decode the whole sequence in a single call. If
            True, decode windows of ``chunk_size`` latent frames that overlap
            by ``overlap`` frames. The windows are then stitched together,
            discarding half of each overlap on either side of a seam.
        overlap: Number of latent frames shared by neighbouring chunks.
        chunk_size: Number of latent frames per decoded chunk.
        downsampling_ratio: Audio samples produced per latent frame.
        io_channels: Number of audio channels in the stitched output.

    Returns:
        If ``chunked`` is False, whatever ``vae_model.decode_export`` returns.
        Otherwise a float32 tensor of shape
        ``[batch, io_channels, time * downsampling_ratio]`` on the latents'
        device.

    Raises:
        ValueError: If ``chunked`` is True and ``0 <= overlap < chunk_size``
            does not hold.
    """
    if not chunked:
        # Default behaviour: decode the entire latent sequence in one call.
        return vae_model.decode_export(latents)

    if not 0 <= overlap < chunk_size:
        # The hop must be positive. A zero hop makes range() raise. A negative
        # hop would silently produce only the tail chunk and leave the rest of
        # the output as zeros.
        raise ValueError(
            f"chunked decoding requires 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    hop_size = chunk_size - overlap
    total_size = latents.shape[2]
    batch_size = latents.shape[0]

    # Regular, evenly spaced windows that fit entirely inside the sequence.
    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    chunk_list = [latents[:, :, s:s + chunk_size] for s in starts]
    last_start = starts[-1] if starts else 0
    if last_start + chunk_size != total_size:
        # Tail window aligned to the end, so the final frames are always covered.
        # When the sequence is shorter than chunk_size, this is the whole sequence.
        chunk_list.append(latents[:, :, -chunk_size:])
    chunks = torch.stack(chunk_list)
    num_chunks = chunks.shape[0]

    samples_per_latent = downsampling_ratio
    y_size = total_size * samples_per_latent
    # Allocate the output directly on the target device (no CPU round-trip).
    y_final = torch.zeros((batch_size, io_channels, y_size), device=latents.device)

    # Half of each overlap (in audio samples) is trimmed from both sides of a seam.
    ol = (overlap // 2) * samples_per_latent
    for i in range(num_chunks):
        y_chunk = vae_model.decode_export(chunks[i, :])

        # Work out where this chunk lands in the output timeline.
        if i == num_chunks - 1:
            # The final chunk always goes at the end.
            t_end = y_size
            t_start = t_end - y_chunk.shape[2]
        else:
            t_start = i * hop_size * samples_per_latent
            t_end = t_start + chunk_size * samples_per_latent

        chunk_start = 0
        chunk_end = y_chunk.shape[2]
        if i > 0:
            # Trim the leading overlap (the first chunk has none).
            t_start += ol
            chunk_start += ol
        if i < num_chunks - 1:
            # Trim the trailing overlap (the last chunk has none).
            t_end -= ol
            chunk_end -= ol

        # Paste the trimmed chunk into the output waveform.
        y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
    return y_final
def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt,
              steps, sway_sampling_coef, start_time, file_type, cfg_strength=4.0):
    """Sample latents with the CFM model, decode them, and package the audio.

    Args:
        cfm_model: Object exposing ``sample(...)``, which returns a
            ``(generated, _)`` pair.
        vae_model: Object passed to ``decode_audio`` (needs ``decode_export``).
        cond, text, duration, style_prompt, negative_style_prompt, steps,
        sway_sampling_coef, start_time: Forwarded unchanged to
            ``cfm_model.sample``.
        file_type: ``'wav'`` returns raw samples. ``'mp3'`` returns MP3
            bytes. Any other value returns OGG bytes.
        cfg_strength: Guidance strength forwarded to ``cfm_model.sample``
            (defaults to the previously hard-coded 4.0).

    Returns:
        For ``'wav'``: a tuple ``(44100, samples)``, where ``samples`` is a
        float32 ndarray of shape ``(num_samples, channels)`` peak-normalised
        to [-1, 1]. Otherwise: the encoded file contents as ``bytes``
        (320 kbps).
    """
    sample_rate = 44100
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=steps,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time,
        )

        generated = generated.to(torch.float32)
        latent = generated.transpose(1, 2)  # [b d t]
        output = decode_audio(latent, vae_model, chunked=False)

        # Concatenate the batch along time into one sequence: [b, d, n] -> [d, b*n].
        output = rearrange(output, "b d n -> d (b n)").to(torch.float32)

        # Peak-normalise. Clamping the peak avoids 0/0 = NaN on all-silent output.
        peak = output.abs().max().clamp_min(1e-8)
        output_tensor = output.div(peak).clamp(-1, 1).cpu()
        output_np = output_tensor.numpy().T.astype(np.float32)  # (num_samples, channels)

        if file_type == 'wav':
            return (sample_rate, output_np)

        # Clip before casting: after normalisation the peak sample is exactly
        # +1.0. Scaled by 2**15 it is 32768, which overflows int16 and wraps
        # to -32768.
        pcm = np.clip(output_np * 2**15, -32768, 32767).astype(np.int16)
        song = pydub.AudioSegment(
            pcm.tobytes(),
            frame_rate=sample_rate,
            sample_width=2,
            channels=pcm.shape[1],
        )
        buffer = io.BytesIO()
        if file_type == 'mp3':
            song.export(buffer, format="mp3", bitrate="320k")
        else:
            song.export(buffer, format="ogg", bitrate="320k")
        return buffer.getvalue()
if __name__ == "__main__":
    # Command-line entry point: generate one song from a lyrics file and a
    # reference-audio style prompt, and write it to <output-dir>/output.wav.
    parser = argparse.ArgumentParser()
    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc")  # lyrics of target song
    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3")  # reference audio as style prompt for target song
    # Length of the target song in seconds. Only these lengths have a frame count defined below.
    parser.add_argument('--audio-length', type=int, default=95, choices=[95, 285])
    parser.add_argument('--steps', type=int, default=32)  # number of sampling steps
    # Passed through unchanged to cfm.sample (None = no explicit coefficient).
    parser.add_argument('--sway-sampling-coef', type=float, default=None)
    parser.add_argument('--output-dir', type=str, default="example/output")
    args = parser.parse_args()

    device = 'cuda'

    # Song length (seconds) -> number of latent frames to generate.
    # 2048 frames * 2048 samples / 44100 Hz ~= 95 s; 6144 frames ~= 285 s.
    max_frames = {95: 2048, 285: 6144}[args.audio_length]

    cfm, tokenizer, muq, vae = prepare_model(device)

    # Lyrics may contain non-ASCII text, so do not rely on the locale encoding.
    with open(args.lrc_path, 'r', encoding='utf-8') as f:
        lrc = f.read()
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
    style_prompt = get_style_prompt(muq, args.ref_audio_path)
    negative_style_prompt = get_negative_style_prompt(device)
    latent_prompt = get_reference_latent(device, max_frames)

    s_t = time.time()
    sample_rate, audio_np = inference(
        cfm_model=cfm,
        vae_model=vae,
        cond=latent_prompt,
        text=lrc_prompt,
        duration=max_frames,
        style_prompt=style_prompt,
        negative_style_prompt=negative_style_prompt,
        steps=args.steps,
        sway_sampling_coef=args.sway_sampling_coef,
        start_time=start_time,
        file_type='wav',
    )
    e_t = time.time() - s_t
    print(f"inference cost {e_t} seconds")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "output.wav")
    # inference() returns (num_samples, channels); torchaudio.save expects (channels, time).
    audio = torch.from_numpy(np.ascontiguousarray(audio_np.T))
    torchaudio.save(output_path, audio, sample_rate=sample_rate)
|