Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
# This code is modified from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/model/base.py | |
import math | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Union | |
import numpy as np | |
import torch | |
import tqdm | |
from audiotools import AudioSignal | |
from torch import nn | |
SUPPORTED_VERSIONS = ["1.0.0"]


@dataclass
class DACFile:
    """Compressed codes plus the metadata needed to decompress them.

    ``save`` writes the object to a ``.dac`` file via ``np.save``; ``load``
    reads it back and rejects files whose ``dac_version`` is not in
    ``SUPPORTED_VERSIONS``.
    """

    codes: torch.Tensor

    # Metadata
    chunk_length: int
    original_length: int
    # Annotated as float, but in practice it may be a tensor (save() used to
    # call .numpy() on it) or a NumPy array (what load() restores); save()
    # accepts all three.
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

    def save(self, path):
        """Serialize this object to ``path`` with its suffix replaced by ``.dac``.

        Parameters
        ----------
        path : str or Path
            Destination path; any existing suffix is replaced with ``.dac``.

        Returns
        -------
        Path
            The path that was actually written.
        """
        # detach().cpu() first: .numpy() fails on CUDA tensors and on tensors
        # that require grad.
        codes = self.codes.detach().cpu().numpy().astype(np.uint16)

        input_db = self.input_db
        if isinstance(input_db, torch.Tensor):
            input_db = input_db.detach().cpu().numpy()
        # np.asarray also handles NumPy arrays (e.g. an object returned by
        # load()) and plain floats, neither of which has a .numpy() method.
        input_db = np.asarray(input_db, dtype=np.float32)

        artifacts = {
            "codes": codes,
            "metadata": {
                "input_db": input_db,
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        """Load a ``DACFile`` previously written by ``save``.

        Parameters
        ----------
        path : str or Path
            Location of the ``.dac`` file.

        Returns
        -------
        DACFile
            The restored object; ``codes`` is an int64 tensor.

        Raises
        ------
        RuntimeError
            If the file's ``dac_version`` is missing or unsupported.
        """
        # NOTE(security): allow_pickle=True is required because save() stores a
        # pickled dict, but unpickling can execute arbitrary code -- only load
        # .dac files from trusted sources.
        artifacts = np.load(path, allow_pickle=True)[()]
        metadata = artifacts["metadata"]
        if metadata.get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        # Explicit int64: astype(int) is int32 on some platforms (Windows).
        codes = torch.from_numpy(artifacts["codes"].astype(np.int64))
        return cls(codes=codes, **metadata)
class CodecMixin:
    """Padding control and chunked compress/decompress for conv audio codecs.

    The host class is expected to be an ``nn.Module`` (``self.modules()`` and
    ``self.eval()`` are called) that also provides ``sample_rate``,
    ``hop_length``, ``delay``, ``device``, ``quantizer``, ``preprocess``,
    ``encode`` and ``decode`` -- these are used here but defined elsewhere.
    """

    # Signals at least this long (in seconds) use the ffmpeg-backed
    # AudioSignal resample/loudness methods. 10 * 60 * 60 s is 10 *hours*
    # (an earlier comment mislabelled it as 10 minutes).
    _FFMPEG_MIN_DURATION = 10 * 60 * 60

    @property
    def padding(self):
        """Whether Conv1d/ConvTranspose1d layers use their own padding.

        Initialised to True the first time it is read.
        """
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        assert isinstance(value, bool)

        # Read the current mode before changing anything: when the layers are
        # already unpadded, their current padding is zeros and must NOT be
        # saved over the real original_padding (repeated `padding = False`
        # used to make the next `padding = True` restore zeros).
        was_padded = self.padding

        layers = [
            m for m in self.modules() if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d))
        ]
        for layer in layers:
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                if was_padded or not hasattr(layer, "original_padding"):
                    layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """Return half the difference between inferred input and output lengths.

        Starts from ``get_output_length(0)`` and walks the Conv1d /
        ConvTranspose1d layers in reverse, inverting each layer's length
        formula (rounding up). ``compress`` zero-pads by ``self.delay`` on
        each side in chunked mode; presumably the host sets ``delay`` from
        this method -- TODO confirm.
        """
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = [
            m for m in self.modules() if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d))
        ]
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            # Inverse of the forward formulas used in get_output_length.
            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L
        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Propagate ``input_length`` through every conv layer in ``self.modules()``.

        Applies each Conv1d / ConvTranspose1d length formula in module order,
        flooring after each layer. The formulas contain no padding term, i.e.
        they correspond to zero layer padding.
        """
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Encode an audio file or AudioSignal into discrete codes.

        The signal is processed in windows of ``win_duration`` seconds, so GPU
        memory use stays constant regardless of input length. Each channel is
        encoded as a separate mono item. The caller's signal is cloned, not
        modified.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            Audio to compress; str/Path inputs are loaded with ffmpeg.
        win_duration : float, optional
            Window duration in seconds, by default 1.0. ``None`` encodes the
            whole signal as a single window.
        verbose : bool, optional
            Show a tqdm progress bar over windows, by default False.
        normalize_db : float, optional
            Loudness to normalize to before encoding, by default -16.
            ``None`` skips normalization.
        n_quantizers : int, optional
            Number of codebooks to keep; ``None`` keeps every codebook
            returned by ``encode``.

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression

        Raises
        ------
        ValueError
            If the signal contains no samples.
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        audio_signal = audio_signal.clone()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # Very long inputs use the ffmpeg-backed implementations.
        if audio_signal.signal_duration >= self._FFMPEG_MIN_DURATION:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        nb, nac, nt = audio_signal.audio_data.shape
        if nt == 0:
            # Otherwise range(0, 0, 0) below fails with an opaque message.
            raise ValueError("Cannot compress an empty audio signal.")
        # Fold channels into the batch so every channel is encoded as mono.
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        # try/finally so the model's padding mode is restored even if
        # encoding fails part-way.
        try:
            if audio_signal.signal_duration <= win_duration:
                # Unchunked compression (used if signal length < win duration)
                self.padding = True
                n_samples = nt
                hop = nt
            else:
                # Chunked inference
                self.padding = False
                # Zero-pad signal on either side by the delay
                audio_signal.zero_pad(self.delay, self.delay)
                n_samples = int(win_duration * self.sample_rate)
                # Round n_samples up to a multiple of the hop length
                n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
                hop = self.get_output_length(n_samples)

            codes = []
            range_fn = range if not verbose else tqdm.trange

            for i in range_fn(0, nt, hop):
                x = audio_signal[..., i : i + n_samples]
                # The last window may be short; zero-pad it to full size.
                x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

                audio_data = x.audio_data.to(self.device)
                audio_data = self.preprocess(audio_data, self.sample_rate)
                _, c, _, _, _ = self.encode(audio_data, n_quantizers)
                codes.append(c.to(original_device))
                chunk_length = c.shape[-1]

            codes = torch.cat(codes, dim=-1)
            # Truncate BEFORE building the DACFile; previously this ran after
            # construction and had no effect on the returned object.
            if n_quantizers is not None:
                codes = codes[:, :n_quantizers, :]

            return DACFile(
                codes=codes,
                chunk_length=chunk_length,
                original_length=original_length,
                input_db=input_db,
                channels=nac,
                sample_rate=original_sr,
                # Record the mode used during encoding, not the restored one.
                padding=self.padding,
                dac_version=SUPPORTED_VERSIONS[-1],
            )
        finally:
            self.padding = original_padding

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio, normalized to the stored
            ``input_db``, resampled to the stored sample rate and trimmed to
            the stored original length.
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding
        # Decode in the same padding mode that was used for encoding.
        self.padding = obj.padding
        try:
            range_fn = range if not verbose else tqdm.trange
            codes = obj.codes
            original_device = codes.device
            chunk_length = obj.chunk_length

            recons = []
            for i in range_fn(0, codes.shape[-1], chunk_length):
                c = codes[..., i : i + chunk_length].to(self.device)
                z = self.quantizer.from_codes(c)[0]
                r = self.decode(z)
                recons.append(r.to(original_device))

            recons = torch.cat(recons, dim=-1)
            recons = AudioSignal(recons, self.sample_rate)

            resample_fn = recons.resample
            # Very long outputs use the ffmpeg-backed resampler.
            if recons.signal_duration >= self._FFMPEG_MIN_DURATION:
                resample_fn = recons.ffmpeg_resample

            recons.normalize(obj.input_db)
            resample_fn(obj.sample_rate)
            # (A discarded loudness computation on the pre-slice signal used to
            # run here; its result was never used, so it has been removed.)
            recons = recons[..., : obj.original_length]
            # Undo the channel-into-batch folding done by compress().
            recons.audio_data = recons.audio_data.reshape(
                -1, obj.channels, obj.original_length
            )
            return recons
        finally:
            self.padding = original_padding