import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def extract_variable_length_features(self, x: torch.Tensor):
            """
            x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
                the mel spectrogram of the audio
            """
            x = F.gelu(self.conv1(x))
            x = F.gelu(self.conv2(x))
            x = x.permute(0, 2, 1)

            # Slice the positional embedding so mel inputs shorter than
            # Whisper's fixed 30 s context are accepted; the stock encoder
            # (commented out below) asserts a fixed shape instead.
            # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
            # x = (x + self.positional_embedding).to(x.dtype)
            x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)

            for block in self.blocks:
                x = block(x)

            x = self.ln_post(x)
            return x

        import whisper
        encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
        encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
        return encoder
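
# A minimal usage sketch (the SimpleNamespace config and the "base" model name
# are illustrative assumptions; `encoder_path` can be any name or checkpoint
# path accepted by whisper.load_model):
#
#   cfg = types.SimpleNamespace(encoder_path="base")
#   encoder = WhisperWrappedEncoder.load(cfg)
#   mel = torch.randn(1, 80, 2000)  # (batch_size, n_mels, n_ctx), < 30 s
#   features = encoder.extract_variable_length_features(mel)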

class BEATsEncoder:

    @classmethod
    def load(cls, model_config):
        from .BEATs.BEATs import BEATs, BEATsConfig
        checkpoint = torch.load(model_config.encoder_path)
        cfg = BEATsConfig(checkpoint['cfg'])
        BEATs_model = BEATs(cfg)
        BEATs_model.load_state_dict(checkpoint['model'])
        return BEATs_model
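
# Loading sketch (the checkpoint name is a placeholder; as the loader above
# expects, the file must store its config under 'cfg' and weights under
# 'model'):
#
#   cfg = types.SimpleNamespace(encoder_path="BEATs_iter3_plus_AS2M.pt")
#   beats = BEATsEncoder.load(cfg)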

# Minimal stand-in for the argparse namespace that
# fairseq.utils.import_user_module() expects; it only reads `user_dir`.
@dataclass
class UserDirModule:
    user_dir: str

class EATEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        model_path = UserDirModule(model_config.encoder_fairseq_dir)
        fairseq.utils.import_user_module(model_path)
        EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        EATEncoder = EATEncoder[0]
        return EATEncoder

    # Note: `load` returns the raw fairseq model, so this method only applies
    # when an EATEncoder instance holding `self.model` is built elsewhere.
    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask=padding_mask, mask=False, remove_extra_tokens=False)['x']
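
# Loading sketch (both paths are placeholders; `encoder_fairseq_dir` must point
# at the EAT user directory so fairseq registers the model class before the
# checkpoint is loaded):
#
#   cfg = types.SimpleNamespace(
#       encoder_fairseq_dir="/path/to/EAT",
#       encoder_path="/path/to/EAT_checkpoint.pt",
#   )
#   eat = EATEncoder.load(cfg)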

class SpatialASTEncoder:

    @classmethod
    def load(cls, model_config):
        from functools import partial
        from .SpatialAST import SpatialAST

        binaural_encoder = SpatialAST.BinauralEncoder(
            num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
        )

        checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
        binaural_encoder.load_state_dict(checkpoint['model'], strict=False)
        return binaural_encoder
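
# Note: unlike the other loaders, this one reads `model_config.encoder_ckpt`
# rather than `encoder_path`, and uses strict=False so missing or unexpected
# keys in the checkpoint are tolerated. A loading sketch with a placeholder
# path:
#
#   cfg = types.SimpleNamespace(encoder_ckpt="/path/to/SpatialAST.pth")
#   spatial_ast = SpatialASTEncoder.load(cfg)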

class WavLMEncoder(nn.Module):

    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .wavlm.WavLM import WavLM, WavLMConfig
        checkpoint = torch.load(model_config.encoder_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        WavLM_model = WavLM(cfg)
        WavLM_model.load_state_dict(checkpoint['model'])
        assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
        return cls(cfg, WavLM_model)

    def extract_features(self, source, padding_mask):
        return self.model.extract_features(source, padding_mask)[0]
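
# Usage sketch (WavLM consumes raw 16 kHz waveforms; the path and shapes are
# illustrative, and `normalize` must match the flag stored in the checkpoint):
#
#   cfg = types.SimpleNamespace(encoder_path="/path/to/WavLM.pt",
#                               normalize=True)
#   wavlm = WavLMEncoder.load(cfg)
#   wav = torch.randn(1, 16000)                     # (batch_size, num_samples)
#   mask = torch.zeros(1, 16000, dtype=torch.bool)  # True marks padding
#   feats = wavlm.extract_features(wav, mask)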

class AVHubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        # Imported for their side effect of registering the AV-HuBERT task and
        # model classes with fairseq before the checkpoint is loaded.
        from .avhubert import hubert_pretraining, hubert, hubert_asr
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        return model

class HubertEncoder:

    @classmethod
    def load(cls, model_config):
        import fairseq
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
        model = models[0]
        if model_config.encoder_type == "pretrain":
            pass
        elif model_config.encoder_type == "finetune":
            # Drop the output (CTC) projection and disable input masking so the
            # fine-tuned model can serve as a plain feature extractor.
            model.w2v_encoder.proj = None
            model.w2v_encoder.apply_mask = False
        else:
            assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
        return model
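
# Loading sketch (the path is a placeholder; `encoder_type` selects between a
# pretrained HuBERT checkpoint and a CTC fine-tuned one):
#
#   cfg = types.SimpleNamespace(encoder_path="/path/to/hubert.pt",
#                               encoder_type="finetune")
#   hubert = HubertEncoder.load(cfg)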

class HfTextEncoder:

    @classmethod
    def load(cls, model_config):
        from transformers import AutoModel
        model = AutoModel.from_pretrained(model_config.encoder_path)
        return model
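
# Usage sketch (`encoder_path` may be any Hugging Face model id or local
# directory accepted by AutoModel.from_pretrained; "bert-base-uncased" is just
# an example):
#
#   cfg = types.SimpleNamespace(encoder_path="bert-base-uncased")
#   text_encoder = HfTextEncoder.load(cfg)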

class MusicFMEncoder(nn.Module):

    def __init__(self, config, model):
        super().__init__()
        self.config = config
        self.model = model

    @classmethod
    def load(cls, model_config):
        from .musicfm.model.musicfm_25hz import MusicFM25Hz
        model = MusicFM25Hz(
            stat_path=model_config.encoder_stat_path,
            model_path=model_config.encoder_path,
            w2v2_config_path=model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
        )
        return cls(model_config, model)

    def extract_features(self, source, padding_mask=None):
        _, hidden_states = self.model.get_predictions(source)
        out = hidden_states[self.config.encoder_layer_idx]
        return out
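
# Usage sketch (paths are placeholders and the 24 kHz sample rate is an
# assumption about MusicFM's expected input; because `load` calls
# model_config.get(), the config must combine dict-style .get() with attribute
# access, e.g. an omegaconf DictConfig):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       "encoder_stat_path": "/path/to/msd_stats.json",
#       "encoder_path": "/path/to/musicfm_25hz.pt",
#       "encoder_layer_idx": -1,
#   })
#   musicfm = MusicFMEncoder.load(cfg)
#   feats = musicfm.extract_features(torch.randn(1, 24000 * 10))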