# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
""" | |
Main model for using MelodyFlow. This will combine all the required components | |
and provide easy access to the generation API. | |
""" | |
import typing as tp

import torch

from audiocraft.utils.autocast import TorchAutocast
from .genmodel import BaseGenModel
from ..modules.conditioners import ConditioningAttributes
from ..utils.utils import vae_sample
from .loaders import load_compression_model, load_dit_model_melodyflow


class MelodyFlow(BaseGenModel):
    """MelodyFlow main model with convenient generation API.

    Args:
        See BaseGenModel.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.set_generation_params()
        self.set_editing_params()
        # Autocast to bfloat16 on accelerators that support it; cpu and mps do not.
        if self.device.type in ('cpu', 'mps'):
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.bfloat16)

    @staticmethod
    def get_pretrained(name: str = 'facebook/melodyflow-t24-30secs', device=None):
        # TODO complete the list of pretrained models
        """Return a pretrained MelodyFlow model.

        Args:
            name (str): Name of the pretrained checkpoint to load.
            device (str, optional): Device to run the model on. If None, selects
                cuda, then mps, then cpu, depending on availability.
        """
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            elif torch.backends.mps.is_available():
                device = 'mps'
            else:
                device = 'cpu'
        compression_model = load_compression_model(name, device=device)

        def _remove_weight_norm(module):
            # Bake weight-norm parametrizations into the plain weights of any
            # wrapped convolution / transposed convolution.
            if hasattr(module, "conv") and hasattr(module.conv, "conv"):
                torch.nn.utils.parametrize.remove_parametrizations(
                    module.conv.conv, "weight"
                )
            if hasattr(module, "convtr") and hasattr(module.convtr, "convtr"):
                torch.nn.utils.parametrize.remove_parametrizations(
                    module.convtr.convtr, "weight"
                )

        def _clear_weight_norm(module):
            # Recursively strip weight norm from the whole module tree.
            _remove_weight_norm(module)
            for child in module.children():
                _clear_weight_norm(child)

        # Weight norm is only needed at training time; remove it on CPU, then
        # move the compression model back to the target device.
        compression_model.to('cpu')
        _clear_weight_norm(compression_model)
        compression_model.to(device)
        lm = load_dit_model_melodyflow(name, device=device)
        kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
        return MelodyFlow(**kwargs)
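
    # Usage sketch (illustrative, not executed on import): load the default
    # checkpoint and let device auto-detection pick cuda/mps/cpu.
    #
    #   model = MelodyFlow.get_pretrained('facebook/melodyflow-t24-30secs')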

    def set_generation_params(
        self,
        solver: str = "midpoint",
        steps: int = 64,
        duration: float = 10.0,
    ) -> None:
        """Set the generation parameters for MelodyFlow.

        Args:
            solver (str, optional): ODE solver, either euler or midpoint.
            steps (int, optional): Number of inference steps.
            duration (float, optional): Duration of the generated waveform, in seconds.
        """
        self.generation_params = {
            'solver': solver,
            'steps': steps,
            'duration': duration,
        }

    def set_editing_params(
        self,
        solver: str = "euler",
        steps: int = 25,
        target_flowstep: float = 0.0,
        regularize: bool = True,
        regularize_iters: int = 4,
        keep_last_k_iters: int = 2,
        lambda_kl: float = 0.2,
    ) -> None:
        """Set the regularized inversion parameters for MelodyFlow editing.

        Args:
            solver (str, optional): ODE solver, either euler or midpoint.
            steps (int, optional): Number of inference steps.
            target_flowstep (float, optional): Target flow step.
            regularize (bool, optional): Regularize each solver step.
            regularize_iters (int, optional): Number of regularization iterations.
            keep_last_k_iters (int, optional): Number of meaningful regularization
                iterations for moving average computation.
            lambda_kl (float, optional): KL regularization loss weight.
        """
        self.editing_params = {
            'solver': solver,
            'steps': steps,
            'target_flowstep': target_flowstep,
            'regularize': regularize,
            'regularize_iters': regularize_iters,
            'keep_last_k_iters': keep_last_k_iters,
            'lambda_kl': lambda_kl,
        }

    def encode_audio(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode an audio waveform into a latent sequence."""
        assert waveform.dim() == 3
        with torch.no_grad():
            latent_sequence = self.compression_model.encode(waveform)[0].squeeze(1)
        return latent_sequence

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate audio from a latent sequence."""
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            if self.lm.latent_mean.shape[1] != gen_tokens.shape[1]:
                # tokens directly emanate from the VAE encoder: sample a latent
                # from the predicted mean and scale
                mean, scale = gen_tokens.chunk(2, dim=1)
                gen_tokens = vae_sample(mean, scale)
            else:
                # tokens emanate from the generator: undo the latent normalization
                gen_tokens = gen_tokens * (self.lm.latent_std + 1e-5) + self.lm.latent_mean
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio
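
    # Round-trip sketch (illustrative): encode_audio expects a waveform of shape
    # [B, C, T_samples] and returns a latent sequence that generate_audio decodes
    # back to audio; the batch/channel sizes below are assumptions for the example.
    #
    #   wav = torch.randn(1, 2, model.sample_rate * 10, device=model.device)
    #   latents = model.encode_audio(wav)
    #   reconstruction = model.generate_audio(latents)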

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Whether to also return the generated latent sequence. Defaults to False.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes=attributes,
                                       prompt_tokens=prompt_tokens,
                                       progress=progress,
                                       **self.generation_params,
                                       )
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Whether to also return the generated latent sequence. Defaults to False.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes=attributes,
                                       prompt_tokens=prompt_tokens,
                                       progress=progress,
                                       **self.generation_params,
                                       )
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)
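
    # Text-to-music sketch (illustrative): hyperparameters come from
    # set_generation_params and the description string is made up.
    #
    #   model.set_generation_params(solver="midpoint", steps=64, duration=10.0)
    #   wav = model.generate(["lo-fi hip hop with a warm piano loop"])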

    def edit(self,
             prompt_tokens: torch.Tensor,
             descriptions: tp.List[str],
             src_descriptions: tp.Optional[tp.List[str]] = None,
             progress: bool = False,
             return_tokens: bool = False,
             ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Edit an audio latent sequence conditioned on text: the prompt latents are
        first inverted towards the target flow step, then regenerated under the new
        conditioning. Solver, steps, and the inversion hyperparameters are taken
        from self.editing_params (see set_editing_params).

        Args:
            prompt_tokens (torch.Tensor): Audio prompt used as initial latent sequence.
            descriptions (list of str): A list of strings used as editing conditioning.
            src_descriptions (list of str, optional): A list of strings used as conditioning
                during latent inversion. Defaults to an empty description.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Whether to also return the generated latent sequence. Defaults to False.
        """
        empty_attributes, no_tokens = self._prepare_tokens_and_attributes(
            [""] if src_descriptions is None else src_descriptions, None)
        assert no_tokens is None
        edit_attributes, no_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert no_tokens is None
        inversion_params = self.editing_params.copy()
        # Total step count reported to the progress callback across both phases.
        override_total_steps = (
            inversion_params["steps"] * (inversion_params["regularize_iters"] + 1)
            if inversion_params["regularize"]
            else inversion_params["steps"] * 2)
        current_step_offset: int = 0

        def _progress_callback(elapsed_steps: int, total_steps: int):
            elapsed_steps += current_step_offset
            if self._progress_callback is not None:
                self._progress_callback(elapsed_steps, override_total_steps)
            else:
                print(f'{elapsed_steps: 6d} / {override_total_steps: 6d}', end='\r')

        # Phase 1: invert the prompt latents from flow step 1.0 towards the
        # target flow step under the source conditioning.
        intermediate_tokens = self._generate_tokens(attributes=empty_attributes,
                                                    prompt_tokens=prompt_tokens,
                                                    source_flowstep=1.0,
                                                    progress=progress,
                                                    callback=_progress_callback,
                                                    **inversion_params,
                                                    )
        if intermediate_tokens.shape[0] < len(descriptions):
            # Broadcast a single inverted sequence across all edit descriptions.
            intermediate_tokens = intermediate_tokens.repeat(
                len(descriptions) // intermediate_tokens.shape[0], 1, 1)
        current_step_offset += (
            inversion_params["steps"] * inversion_params["regularize_iters"]
            if inversion_params["regularize"]
            else inversion_params["steps"])
        inversion_params.pop("regularize")
        # Phase 2: re-integrate the flow from the target flow step back to 1.0
        # under the edit conditioning.
        final_tokens = self._generate_tokens(attributes=edit_attributes,
                                             prompt_tokens=intermediate_tokens,
                                             source_flowstep=inversion_params.pop("target_flowstep"),
                                             target_flowstep=1.0,
                                             progress=progress,
                                             callback=_progress_callback,
                                             **inversion_params,
                                             )
        if return_tokens:
            return self.generate_audio(final_tokens), final_tokens
        return self.generate_audio(final_tokens)
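
    # Editing sketch (illustrative): invert the latents of an existing waveform,
    # then regenerate them under a new text condition; hyperparameters come from
    # set_editing_params and the descriptions are made up.
    #
    #   latents = model.encode_audio(wav)
    #   model.set_editing_params(solver="euler", steps=25, target_flowstep=0.0)
    #   edited = model.edit(latents, ["orchestral rework of the same melody"])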

    def _generate_tokens(self,
                         attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor],
                         progress: bool = False,
                         callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                         **kwargs) -> torch.Tensor:
        """Generate continuous audio tokens given audio prompt and/or conditions.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used as initial latent sequence.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            callback (callable, optional): Progress callback taking (elapsed_steps, total_steps).
        Returns:
            torch.Tensor: Generated latent sequence, of shape [B, C, T], where T is
                defined by the generation params.
        """
        generate_params = kwargs.copy()
        # Length of the generated sequence: inherit it from the prompt if provided,
        # otherwise derive it from the requested duration.
        total_gen_len = prompt_tokens.shape[-1] if prompt_tokens is not None else int(
            generate_params.pop('duration') * self.frame_rate)
        current_step_offset: int = 0

        def _progress_callback(elapsed_steps: int, total_steps: int):
            elapsed_steps += current_step_offset
            if self._progress_callback is not None:
                self._progress_callback(elapsed_steps, total_steps)
            else:
                print(f'{elapsed_steps: 6d} / {total_steps: 6d}', end='\r')

        if progress and callback is None:
            callback = _progress_callback
        assert total_gen_len <= int(self.max_duration * self.frame_rate)
        with self.autocast:
            gen_tokens = self.lm.generate(
                prompt=prompt_tokens,
                conditions=attributes,
                callback=callback,
                max_gen_len=total_gen_len,
                **generate_params,
            )
        return gen_tokens