# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.

import sys
sys.path.append('../')

from typing import Optional
from copy import deepcopy

from transformers import AutoModelForCausalLM, AutoTokenizer
from laion_clap import CLAP_Module
from ms_clap.src.CLAPWrapper import CLAPWrapper

import torch
from torch import nn

try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance
class CLAP(nn.Module):
    """Frozen CLAP audio encoder that maps raw audio clips to embeddings."""

    def __init__(self, clap_config):
        super(CLAP, self).__init__()
        self.method = clap_config["method"]

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        if self.method == 'laion-clap':
            # https://github.com/LAION-AI/CLAP
            if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
                amodel = 'HTSAT-tiny'
            elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
                amodel = 'HTSAT-base'
            else:
                raise NotImplementedError

            enable_fusion = 'fusion' in clap_config["model_name"].lower()
            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device)
            self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])

            # CLAP is used as a frozen feature extractor.
            for param in self.laion_clap.parameters():
                param.requires_grad = False
            self.laion_clap.eval()

            print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))

        elif self.method == 'microsoft-clap':
            # https://github.com/microsoft/CLAP
            self.ms_clap = CLAPWrapper(
                clap_config["checkpoint"],
                config_root=clap_config["config_root"],
                version=clap_config['model_name'],
                use_cuda=torch.cuda.is_available()
            )

            # The 2022/2023 checkpoints expose the encoder as .clap;
            # the captioning variants expose it as .clapcap.
            if clap_config['model_name'] in ['2022', '2023']:
                for param in self.ms_clap.clap.parameters():
                    param.requires_grad = False
                self.ms_clap.clap.eval()
            else:
                for param in self.ms_clap.clapcap.parameters():
                    param.requires_grad = False
                self.ms_clap.clapcap.eval()

            print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))

        else:
            raise NotImplementedError
    def forward(self, audio_clips):
        # Accept (num_clips, samples) for a single example; add a batch dim.
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3  # (batch, num_clips, samples)

        audio_embeds = []
        for x in audio_clips:
            if self.method == 'laion-clap':
                audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
            elif self.method == 'microsoft-clap':
                audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # The encoders are frozen, so the embeddings should carry no gradient.
        audio_embeds.requires_grad = False

        return audio_embeds
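

# A minimal usage sketch of the CLAP wrapper above. The config values and the
# checkpoint path are illustrative placeholders, not defaults shipped with
# this repo:
#
#   clap_config = {
#       "method": "laion-clap",
#       "model_name": "630k-audioset-fusion-best",
#       "checkpoint": "/path/to/630k-audioset-fusion-best.pt",  # hypothetical
#       "audio_embed_dim": 512,
#   }
#   clap = CLAP(clap_config)
#   audio_clips = torch.randn(2, 4, 480000)  # (batch, num_clips, samples)
#   audio_embeds = clap(audio_clips)         # (batch, num_clips, embed_dim)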
def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # <audio> marks where audio embeddings are injected; <|endofchunk|>
    # delimits interleaved audio-text chunks.
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    unfreeze_clap = False
    model = Flamingo(
        clap,
        unfreeze_clap,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<audio>")[-1],
        text_tokenizer.sep_token_id,
        audio_embed_dim=clap_config["audio_embed_dim"],
        audio_transformer_kwargs=audio_transformer_kwargs,
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze everything, then selectively unfreeze the trained components:
    # the audio transformer, the gated cross-attention layers, and
    # (optionally) the LM input embeddings or the full LM.
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model.audio_transformer.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
    if unfreeze_full_lm:
        model.lang_encoder.requires_grad_(True)
    if unfreeze_clap:
        model.clap.requires_grad_(True)

    print(
        "Flamingo model initialized with {:,} trainable parameters "
        "(audio transformer has {:,}, LM has {:,})".format(
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
            sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad),
        )
    )

    return model, text_tokenizer
def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing "
        "the transformer block layers. Please supply this string manually."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
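

# A minimal, hedged smoke-test sketch of the factory function. All paths and
# config values below are hypothetical placeholders, not defaults shipped with
# this repo; substitute a real CLAP checkpoint, language-model path, and the
# audio_transformer_kwargs that flamingo.Flamingo expects before running.
if __name__ == "__main__":
    clap_config = {
        "method": "laion-clap",
        "model_name": "630k-audioset-fusion-best",
        "checkpoint": "/path/to/630k-audioset-fusion-best.pt",  # hypothetical path
        "audio_embed_dim": 512,
    }
    model, tokenizer = create_model_and_transforms(
        clap_config=clap_config,
        lang_encoder_path="facebook/opt-125m",  # any LM matched by __KNOWN_DECODER_LAYERS_ATTR_NAMES
        tokenizer_path="facebook/opt-125m",
        audio_transformer_kwargs={},  # fill in per flamingo.Flamingo's expectations
        cross_attn_every_n_layers=1,
    )
    print(type(model).__name__, len(tokenizer))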