import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def extract_variable_length_features(self, x: torch.Tensor):
            """
            x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
                the mel spectrogram of the audio
            """
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
            x = x.permute(0, 2, 1)

            # Unlike the stock Whisper encoder, do not assert a fixed input
            # length; slice the positional embedding to the actual frame count.
            # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
            # x = (x + self.positional_embedding).to(x.dtype)
            x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

            for block in self.blocks:
                x = block(x)

            x = self.ln_post(x)
            return x

        import whisper
        encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
        # Monkey-patch the loaded encoder so it can handle variable-length mels.
        encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
        return encoder
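
# Usage sketch (illustrative, not part of the original code): with a
# model_config whose encoder_path names a Whisper checkpoint (e.g. "base"),
# the patched encoder accepts mels shorter than Whisper's fixed 30 s window:
#
#   encoder = WhisperWrappedEncoder.load(model_config)
#   mel = torch.randn(1, 80, 500)   # (batch_size, n_mels, n_frames), n_frames < n_ctx
#   feats = encoder.extract_variable_length_features(mel)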


class BEATsEncoder:

    @classmethod
    def load(cls, model_config):
        from .BEATs.BEATs import BEATs, BEATsConfig
        checkpoint = torch.load(model_config.encoder_path)
        cfg = BEATsConfig(checkpoint['cfg'])
        BEATs_model = BEATs(cfg)
        BEATs_model.load_state_dict(checkpoint['model'])
        return BEATs_model


@dataclass
class UserDirModule:
    user_dir: str


class EATEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # Register EAT's custom fairseq user modules before loading the checkpoint.
        model_path = UserDirModule(model_config.encoder_fairseq_dir)
        fairseq.utils.import_user_module(model_path)
        EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        EATEncoder = EATEncoder[0]
        return EATEncoder

    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask=padding_mask, mask=False, remove_extra_tokens=False)['x']
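
# Usage note (illustrative): load() returns the raw fairseq model rather than
# an EATEncoder instance, so the extract_features method above is only
# reachable if the model is wrapped elsewhere with self.model set.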


class SpatialASTEncoder:

    @classmethod
    def load(cls, model_config):
        from functools import partial
        from .SpatialAST import SpatialAST

        binaural_encoder = SpatialAST.BinauralEncoder(
            num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
        )

        checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
        binaural_encoder.load_state_dict(checkpoint['model'], strict=False)
        return binaural_encoder


class WavLMEncoder(nn.Module):

    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .wavlm.WavLM import WavLM, WavLMConfig
        checkpoint = torch.load(model_config.encoder_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        WavLM_model = WavLM(cfg)
        WavLM_model.load_state_dict(checkpoint['model'])
        assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
        return cls(cfg, WavLM_model)

    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask)[0]
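
# Usage sketch (illustrative; attribute names on model_config are assumed):
#
#   wavlm = WavLMEncoder.load(model_config)        # encoder_path -> WavLM .pt checkpoint
#   wav = torch.randn(2, 16000)                    # batch of raw 16 kHz waveforms
#   mask = torch.zeros(2, 16000, dtype=torch.bool) # True marks padded samples
#   feats = wavlm.extract_features(wav, mask)      # (batch, frames, hidden_dim)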


class AVHubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # Importing these modules registers the AV-HuBERT tasks/models with fairseq.
        from .avhubert import hubert_pretraining, hubert, hubert_asr
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        return model


class HubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        if model_config.encoder_type == "pretrain":
            pass
        elif model_config.encoder_type == "finetune":
            # Drop the CTC projection head and disable masking so the model
            # emits frame-level features rather than CTC logits.
            model.w2v_encoder.proj = None
            model.w2v_encoder.apply_mask = False
        else:
            assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
        return model
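
# Usage sketch (illustrative; the config values are assumptions for the example):
#
#   model_config.encoder_path = "hubert_base_ls960.pt"
#   model_config.encoder_type = "pretrain"   # or "finetune" for CTC-finetuned checkpoints
#   hubert = HubertEncoder.load(model_config)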


class HfTextEncoder:

    @classmethod
    def load(cls, model_config):
        from transformers import AutoModel
        model = AutoModel.from_pretrained(model_config.encoder_path)
        return model


class MusicFMEncoder(nn.Module):

    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .musicfm.model.musicfm_25hz import MusicFM25Hz
        model = MusicFM25Hz(
            stat_path=model_config.encoder_stat_path,
            model_path=model_config.encoder_path,
            w2v2_config_path=model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
        )
        return cls(model_config, model)

    def extract_features(self, source, padding_mask=None):
        # Take the hidden state of the configured transformer layer as the feature.
        _, hidden_states = self.model.get_predictions(source)
        out = hidden_states[self.config.encoder_layer_idx]
        return out
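
# Usage sketch (illustrative; the MusicFM 25 Hz reference implementation
# expects 24 kHz mono audio, and encoder_layer_idx selects which transformer
# hidden state becomes the feature):
#
#   musicfm = MusicFMEncoder.load(model_config)
#   wav = torch.randn(1, 24000 * 10)   # 10 s of audio
#   feats = musicfm.extract_features(wav)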