Spaces:
Sleeping
Sleeping
Staticaliza
commited on
Upload 3 files
Browse files- dac/utils/__init__.py +123 -0
- dac/utils/decode.py +95 -0
- dac/utils/encode.py +94 -0
dac/utils/__init__.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import argbind
|
4 |
+
from audiotools import ml
|
5 |
+
|
6 |
+
import dac
|
7 |
+
|
8 |
+
DAC = dac.model.DAC
|
9 |
+
Accelerator = ml.Accelerator
|
10 |
+
|
11 |
+
__MODEL_LATEST_TAGS__ = {
|
12 |
+
("44khz", "8kbps"): "0.0.1",
|
13 |
+
("24khz", "8kbps"): "0.0.4",
|
14 |
+
("16khz", "8kbps"): "0.0.5",
|
15 |
+
("44khz", "16kbps"): "1.0.0",
|
16 |
+
}
|
17 |
+
|
18 |
+
__MODEL_URLS__ = {
|
19 |
+
(
|
20 |
+
"44khz",
|
21 |
+
"0.0.1",
|
22 |
+
"8kbps",
|
23 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
|
24 |
+
(
|
25 |
+
"24khz",
|
26 |
+
"0.0.4",
|
27 |
+
"8kbps",
|
28 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
|
29 |
+
(
|
30 |
+
"16khz",
|
31 |
+
"0.0.5",
|
32 |
+
"8kbps",
|
33 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
|
34 |
+
(
|
35 |
+
"44khz",
|
36 |
+
"1.0.0",
|
37 |
+
"16kbps",
|
38 |
+
): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
|
39 |
+
}
|
40 |
+
|
41 |
+
|
42 |
+
@argbind.bind(group="download", positional=True, without_prefix=True)
|
43 |
+
def download(
|
44 |
+
model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
|
45 |
+
):
|
46 |
+
"""
|
47 |
+
Function that downloads the weights file from URL if a local cache is not found.
|
48 |
+
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
model_type : str
|
52 |
+
The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
|
53 |
+
model_bitrate: str
|
54 |
+
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
55 |
+
Only 44khz model supports 16kbps.
|
56 |
+
tag : str
|
57 |
+
The tag of the model to download. Defaults to "latest".
|
58 |
+
|
59 |
+
Returns
|
60 |
+
-------
|
61 |
+
Path
|
62 |
+
Directory path required to load model via audiotools.
|
63 |
+
"""
|
64 |
+
model_type = model_type.lower()
|
65 |
+
tag = tag.lower()
|
66 |
+
|
67 |
+
assert model_type in [
|
68 |
+
"44khz",
|
69 |
+
"24khz",
|
70 |
+
"16khz",
|
71 |
+
], "model_type must be one of '44khz', '24khz', or '16khz'"
|
72 |
+
|
73 |
+
assert model_bitrate in [
|
74 |
+
"8kbps",
|
75 |
+
"16kbps",
|
76 |
+
], "model_bitrate must be one of '8kbps', or '16kbps'"
|
77 |
+
|
78 |
+
if tag == "latest":
|
79 |
+
tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
|
80 |
+
|
81 |
+
download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
|
82 |
+
|
83 |
+
if download_link is None:
|
84 |
+
raise ValueError(
|
85 |
+
f"Could not find model with tag {tag} and model type {model_type}"
|
86 |
+
)
|
87 |
+
|
88 |
+
local_path = (
|
89 |
+
Path.home()
|
90 |
+
/ ".cache"
|
91 |
+
/ "descript"
|
92 |
+
/ "dac"
|
93 |
+
/ f"weights_{model_type}_{model_bitrate}_{tag}.pth"
|
94 |
+
)
|
95 |
+
if not local_path.exists():
|
96 |
+
local_path.parent.mkdir(parents=True, exist_ok=True)
|
97 |
+
|
98 |
+
# Download the model
|
99 |
+
import requests
|
100 |
+
|
101 |
+
response = requests.get(download_link)
|
102 |
+
|
103 |
+
if response.status_code != 200:
|
104 |
+
raise ValueError(
|
105 |
+
f"Could not download model. Received response code {response.status_code}"
|
106 |
+
)
|
107 |
+
local_path.write_bytes(response.content)
|
108 |
+
|
109 |
+
return local_path
|
110 |
+
|
111 |
+
|
112 |
+
def load_model(
|
113 |
+
model_type: str = "44khz",
|
114 |
+
model_bitrate: str = "8kbps",
|
115 |
+
tag: str = "latest",
|
116 |
+
load_path: str = None,
|
117 |
+
):
|
118 |
+
if not load_path:
|
119 |
+
load_path = download(
|
120 |
+
model_type=model_type, model_bitrate=model_bitrate, tag=tag
|
121 |
+
)
|
122 |
+
generator = DAC.load(load_path)
|
123 |
+
return generator
|
dac/utils/decode.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import argbind
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from audiotools import AudioSignal
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from dac import DACFile
|
11 |
+
from dac.utils import load_model
|
12 |
+
|
13 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
14 |
+
|
15 |
+
|
16 |
+
@argbind.bind(group="decode", positional=True, without_prefix=True)
|
17 |
+
@torch.inference_mode()
|
18 |
+
@torch.no_grad()
|
19 |
+
def decode(
|
20 |
+
input: str,
|
21 |
+
output: str = "",
|
22 |
+
weights_path: str = "",
|
23 |
+
model_tag: str = "latest",
|
24 |
+
model_bitrate: str = "8kbps",
|
25 |
+
device: str = "cuda",
|
26 |
+
model_type: str = "44khz",
|
27 |
+
verbose: bool = False,
|
28 |
+
):
|
29 |
+
"""Decode audio from codes.
|
30 |
+
|
31 |
+
Parameters
|
32 |
+
----------
|
33 |
+
input : str
|
34 |
+
Path to input directory or file
|
35 |
+
output : str, optional
|
36 |
+
Path to output directory, by default "".
|
37 |
+
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
|
38 |
+
weights_path : str, optional
|
39 |
+
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
|
40 |
+
model_tag and model_type.
|
41 |
+
model_tag : str, optional
|
42 |
+
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
|
43 |
+
model_bitrate: str
|
44 |
+
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
45 |
+
device : str, optional
|
46 |
+
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
|
47 |
+
model_type : str, optional
|
48 |
+
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
|
49 |
+
"""
|
50 |
+
generator = load_model(
|
51 |
+
model_type=model_type,
|
52 |
+
model_bitrate=model_bitrate,
|
53 |
+
tag=model_tag,
|
54 |
+
load_path=weights_path,
|
55 |
+
)
|
56 |
+
generator.to(device)
|
57 |
+
generator.eval()
|
58 |
+
|
59 |
+
# Find all .dac files in input directory
|
60 |
+
_input = Path(input)
|
61 |
+
input_files = list(_input.glob("**/*.dac"))
|
62 |
+
|
63 |
+
# If input is a .dac file, add it to the list
|
64 |
+
if _input.suffix == ".dac":
|
65 |
+
input_files.append(_input)
|
66 |
+
|
67 |
+
# Create output directory
|
68 |
+
output = Path(output)
|
69 |
+
output.mkdir(parents=True, exist_ok=True)
|
70 |
+
|
71 |
+
for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
|
72 |
+
# Load file
|
73 |
+
artifact = DACFile.load(input_files[i])
|
74 |
+
|
75 |
+
# Reconstruct audio from codes
|
76 |
+
recons = generator.decompress(artifact, verbose=verbose)
|
77 |
+
|
78 |
+
# Compute output path
|
79 |
+
relative_path = input_files[i].relative_to(input)
|
80 |
+
output_dir = output / relative_path.parent
|
81 |
+
if not relative_path.name:
|
82 |
+
output_dir = output
|
83 |
+
relative_path = input_files[i]
|
84 |
+
output_name = relative_path.with_suffix(".wav").name
|
85 |
+
output_path = output_dir / output_name
|
86 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
87 |
+
|
88 |
+
# Write to file
|
89 |
+
recons.write(output_path)
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
args = argbind.parse_args()
|
94 |
+
with argbind.scope(args):
|
95 |
+
decode()
|
dac/utils/encode.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import argbind
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from audiotools import AudioSignal
|
9 |
+
from audiotools.core import util
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from dac.utils import load_model
|
13 |
+
|
14 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
15 |
+
|
16 |
+
|
17 |
+
@argbind.bind(group="encode", positional=True, without_prefix=True)
|
18 |
+
@torch.inference_mode()
|
19 |
+
@torch.no_grad()
|
20 |
+
def encode(
|
21 |
+
input: str,
|
22 |
+
output: str = "",
|
23 |
+
weights_path: str = "",
|
24 |
+
model_tag: str = "latest",
|
25 |
+
model_bitrate: str = "8kbps",
|
26 |
+
n_quantizers: int = None,
|
27 |
+
device: str = "cuda",
|
28 |
+
model_type: str = "44khz",
|
29 |
+
win_duration: float = 5.0,
|
30 |
+
verbose: bool = False,
|
31 |
+
):
|
32 |
+
"""Encode audio files in input path to .dac format.
|
33 |
+
|
34 |
+
Parameters
|
35 |
+
----------
|
36 |
+
input : str
|
37 |
+
Path to input audio file or directory
|
38 |
+
output : str, optional
|
39 |
+
Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
|
40 |
+
weights_path : str, optional
|
41 |
+
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
|
42 |
+
model_tag and model_type.
|
43 |
+
model_tag : str, optional
|
44 |
+
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
|
45 |
+
model_bitrate: str
|
46 |
+
Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
|
47 |
+
n_quantizers : int, optional
|
48 |
+
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
|
49 |
+
device : str, optional
|
50 |
+
Device to use, by default "cuda"
|
51 |
+
model_type : str, optional
|
52 |
+
The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
|
53 |
+
"""
|
54 |
+
generator = load_model(
|
55 |
+
model_type=model_type,
|
56 |
+
model_bitrate=model_bitrate,
|
57 |
+
tag=model_tag,
|
58 |
+
load_path=weights_path,
|
59 |
+
)
|
60 |
+
generator.to(device)
|
61 |
+
generator.eval()
|
62 |
+
kwargs = {"n_quantizers": n_quantizers}
|
63 |
+
|
64 |
+
# Find all audio files in input path
|
65 |
+
input = Path(input)
|
66 |
+
audio_files = util.find_audio(input)
|
67 |
+
|
68 |
+
output = Path(output)
|
69 |
+
output.mkdir(parents=True, exist_ok=True)
|
70 |
+
|
71 |
+
for i in tqdm(range(len(audio_files)), desc="Encoding files"):
|
72 |
+
# Load file
|
73 |
+
signal = AudioSignal(audio_files[i])
|
74 |
+
|
75 |
+
# Encode audio to .dac format
|
76 |
+
artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
|
77 |
+
|
78 |
+
# Compute output path
|
79 |
+
relative_path = audio_files[i].relative_to(input)
|
80 |
+
output_dir = output / relative_path.parent
|
81 |
+
if not relative_path.name:
|
82 |
+
output_dir = output
|
83 |
+
relative_path = audio_files[i]
|
84 |
+
output_name = relative_path.with_suffix(".dac").name
|
85 |
+
output_path = output_dir / output_name
|
86 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
87 |
+
|
88 |
+
artifact.save(output_path)
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
args = argbind.parse_args()
|
93 |
+
with argbind.scope(args):
|
94 |
+
encode()
|