# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.

import sys
sys.path.append('../')

from typing import Optional
from copy import deepcopy

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

# CLAP_Module is needed for the 'laion-clap' branch below; it is provided by the laion_clap package.
from laion_clap import CLAP_Module
from ms_clap.src.CLAPWrapper import CLAPWrapper

try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance


class CLAP(nn.Module):
    """Frozen CLAP audio encoder; supports the LAION-CLAP and Microsoft CLAP backends."""

    def __init__(self, clap_config):
        super(CLAP, self).__init__()
        self.method = clap_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if self.method == 'laion-clap':
            # https://github.com/LAION-AI/CLAP
            if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
                amodel = 'HTSAT-tiny'
            elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
                amodel = 'HTSAT-base'
            else:
                raise NotImplementedError

            enable_fusion = 'fusion' in clap_config["model_name"].lower()
            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device_id)
            self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])

            # Freeze the encoder: audio embeddings are used as fixed features.
            for param in self.laion_clap.parameters():
                param.requires_grad = False
            self.laion_clap.eval()

            print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))

        elif self.method == 'microsoft-clap':
            # https://github.com/microsoft/CLAP
            self.ms_clap = CLAPWrapper(
                clap_config["checkpoint"],
                config_root=clap_config["config_root"],
                version=clap_config['model_name'],
                use_cuda=True
            )

            # Freeze whichever sub-model this version exposes
            # (.clap for the '2022'/'2023' encoders, .clapcap otherwise).
            if clap_config['model_name'] in ['2022', '2023']:
                for param in self.ms_clap.clap.parameters():
                    param.requires_grad = False
                self.ms_clap.clap.eval()
            else:
                for param in self.ms_clap.clapcap.parameters():
                    param.requires_grad = False
                self.ms_clap.clapcap.eval()

            print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))

        else:
            raise NotImplementedError

    def forward(self, audio_clips):
        # audio_clips: (num_clips, num_samples) or (batch, num_clips, num_samples)
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for x in audio_clips:
            if self.method == 'laion-clap':
                audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
            elif self.method == 'microsoft-clap':
                audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)

            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        audio_embeds.requires_grad = False

        return audio_embeds


def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["