import types
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass


class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def extract_variable_length_features(self, x: torch.Tensor):
            """
            x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
                the mel spectrogram of the audio
            """
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
            x = x.permute(0, 2, 1)

            # Unlike the stock Whisper encoder, tolerate inputs shorter than the
            # fixed 30 s context: slice the positional embedding to the actual
            # sequence length instead of asserting an exact shape match.
            # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
            # x = (x + self.positional_embedding).to(x.dtype)
            x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

            for block in self.blocks:
                x = block(x)

            x = self.ln_post(x)
            return x

        import whisper
        encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
        # Bind the variable-length forward pass to the loaded encoder instance.
        encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
        return encoder
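
# Minimal usage sketch for WhisperWrappedEncoder (assumes `model_config.encoder_path`
# names a Whisper checkpoint such as "base"; file name and shapes are illustrative):
#
#     import whisper
#     encoder = WhisperWrappedEncoder.load(model_config)
#     mel = whisper.log_mel_spectrogram(whisper.load_audio("sample.wav"))  # (n_mels, n_frames)
#     feats = encoder.extract_variable_length_features(mel.unsqueeze(0))   # (1, ~n_frames/2, d_model)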


class BEATsEncoder:

    @classmethod
    def load(cls, model_config):
        from .BEATs.BEATs import BEATs, BEATsConfig
        # The BEATs checkpoint bundles both the model config and the weights.
        checkpoint = torch.load(model_config.encoder_path)
        cfg = BEATsConfig(checkpoint['cfg'])
        BEATs_model = BEATs(cfg)
        BEATs_model.load_state_dict(checkpoint['model'])
        return BEATs_model


@dataclass
class UserDirModule:
    """Minimal stand-in for an argparse namespace carrying the `user_dir`
    attribute that fairseq.utils.import_user_module expects."""
    user_dir: str


class EATEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # Register the EAT user modules (tasks/models) with fairseq before loading.
        model_path = UserDirModule(model_config.encoder_fairseq_dir)
        fairseq.utils.import_user_module(model_path)
        # Unpack the ensemble into `models` to avoid shadowing the class name.
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        return models[0]

    def extract_features(self, source, padding_mask):
        # Note: `load` returns the bare fairseq model rather than an instance of
        # this wrapper, so this method only applies if the class is instantiated
        # with a `model` attribute; callers of the loaded model typically use its
        # own extract_features directly.
        return self.model.extract_features(source, padding_mask=padding_mask, mask=False, remove_extra_tokens=False)['x']


class SpatialASTEncoder:

    @classmethod
    def load(cls, model_config):
        from functools import partial
        from .SpatialAST import SpatialAST

        # ViT-Base-sized binaural encoder (12 layers, 768-dim, 12 heads).
        binaural_encoder = SpatialAST.BinauralEncoder(
            num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
        )

        checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
        # strict=False: tolerate checkpoint keys that are absent or unused here.
        binaural_encoder.load_state_dict(checkpoint['model'], strict=False)
        return binaural_encoder


class WavLMEncoder(nn.Module):
    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .wavlm.WavLM import WavLM, WavLMConfig
        checkpoint = torch.load(model_config.encoder_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        WavLM_model = WavLM(cfg)
        WavLM_model.load_state_dict(checkpoint['model'])
        assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
        return cls(cfg, WavLM_model)

    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask)[0]
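
# Minimal usage sketch for WavLMEncoder (assumes `model_config` carries
# `encoder_path` and a `normalize` flag matching the checkpoint; input is a raw
# 16 kHz waveform, normalized as in the upstream WavLM README):
#
#     encoder = WavLMEncoder.load(model_config)
#     wav = torch.randn(1, 16000)  # one second of audio
#     if encoder.config.normalize:
#         wav = torch.nn.functional.layer_norm(wav, wav.shape)
#     feats = encoder.extract_features(wav, padding_mask=None)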


class AVHubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # These imports register the AV-HuBERT tasks/models with fairseq;
        # they must run before the checkpoint can be deserialized.
        from .avhubert import hubert_pretraining, hubert, hubert_asr
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        return models[0]


class HubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        if model_config.encoder_type == "pretrain":
            pass
        elif model_config.encoder_type == "finetune":
            # Drop the fine-tuned CTC projection and disable masking so the
            # model yields hidden states rather than CTC logits.
            model.w2v_encoder.proj = None
            model.w2v_encoder.apply_mask = False
        else:
            assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
        return model
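
# Usage sketch for HubertEncoder (hypothetical config values, shown for
# illustration only):
#
#     model_config.encoder_type = "finetune"  # or "pretrain" to keep the model as-is
#     hubert = HubertEncoder.load(model_config)
#     # With "finetune", the CTC head is removed so downstream code can read
#     # the encoder's hidden states directly.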


class HfTextEncoder:

    @classmethod
    def load(cls, model_config):
        from transformers import AutoModel
        model = AutoModel.from_pretrained(model_config.encoder_path)
        return model
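
# Usage sketch for HfTextEncoder (assumes `encoder_path` is a Hugging Face hub
# id or local directory, e.g. "bert-base-uncased"; the tokenizer line is an
# illustrative addition, not part of this module):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained(model_config.encoder_path)
#     encoder = HfTextEncoder.load(model_config)
#     batch = tokenizer(["hello world"], return_tensors="pt")
#     hidden = encoder(**batch).last_hidden_state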


class MusicFMEncoder(nn.Module):
    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .musicfm.model.musicfm_25hz import MusicFM25Hz
        model = MusicFM25Hz(
            stat_path=model_config.encoder_stat_path,
            model_path=model_config.encoder_path,
            w2v2_config_path=model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
        )
        return cls(model_config, model)

    def extract_features(self, source, padding_mask=None):
        # Return the hidden states of the layer selected in the config.
        _, hidden_states = self.model.get_predictions(source)
        return hidden_states[self.config.encoder_layer_idx]
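
# Usage sketch for MusicFMEncoder (assumes 24 kHz mono input as in the upstream
# MusicFM repo and a config exposing `encoder_layer_idx`; shapes are illustrative):
#
#     encoder = MusicFMEncoder.load(model_config)
#     wav = torch.randn(1, 24000 * 5)        # five seconds of audio
#     feats = encoder.extract_features(wav)  # hidden states at ~25 Hz frame rate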