# Copyright (c) 2023 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # This code is modified from https://huggingface.co./m-a-p/MERT-v1-330M import torch from tqdm import tqdm import numpy as np from transformers import Wav2Vec2FeatureExtractor from transformers import AutoModel import torchaudio import torchaudio.transforms as T from sklearn.preprocessing import StandardScaler def mert_encoder(model, processor, audio_path, hps): """ # mert default sr: 24000 """ with torch.no_grad(): resample_rate = processor.sampling_rate device = next(model.parameters()).device input_audio, sampling_rate = torchaudio.load(audio_path) input_audio = input_audio.squeeze() if sampling_rate != resample_rate: resampler = T.Resample(sampling_rate, resample_rate) input_audio = resampler(input_audio) inputs = processor( input_audio, sampling_rate=resample_rate, return_tensors="pt" ).to( device ) # {input_values: tensor, attention_mask: tensor} outputs = model(**inputs, output_hidden_states=True) # list: len is 25 # [25 layer, Time steps, 1024 feature_dim] # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() # mert_features.append(all_layer_hidden_states) feature = outputs.hidden_states[ hps.mert_feature_layer ].squeeze() # [1, frame len, 1024] -> [frame len, 1024] return feature.cpu().detach().numpy() def mert_features_normalization(raw_mert_features): normalized_mert_features = list() mert_features = np.array(raw_mert_features) scaler = StandardScaler().fit(mert_features) for raw_mert_feature in raw_mert_feature: normalized_mert_feature = scaler.transform(raw_mert_feature) normalized_mert_features.append(normalized_mert_feature) return normalized_mert_features def get_mapped_mert_features(raw_mert_features, mapping_features, fast_mapping=True): source_hop = 320 target_hop = 256 factor = np.gcd(source_hop, target_hop) source_hop //= factor target_hop //= factor print( "Mapping source's {} frames => target's {} frames".format( target_hop, source_hop ) ) mert_features = [] for index, mapping_feat in enumerate(tqdm(mapping_features)): # mapping_feat: (mels_frame_len, n_mels) target_len = mapping_feat.shape[0] # (frame_len, 1024) raw_feats = raw_mert_features[index].cpu().numpy() source_len, width = raw_feats.shape # const ~= target_len * target_hop const = source_len * source_hop // target_hop * target_hop # (source_len * source_hop, dim) up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) down_sampling_feats = np.average( up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 ) err = abs(target_len - len(down_sampling_feats)) if err > 3: print("index:", index) print("mels:", mapping_feat.shape) print("raw mert vector:", raw_feats.shape) print("up_sampling:", up_sampling_feats.shape) print("const:", const) print("down_sampling_feats:", down_sampling_feats.shape) exit() if len(down_sampling_feats) < target_len: # (1, dim) -> (err, dim) end = down_sampling_feats[-1][None, :].repeat(err, axis=0) down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) # (target_len, dim) feats = down_sampling_feats[:target_len] mert_features.append(feats) return mert_features def load_mert_model(hps): print("Loading MERT Model: ", hps.mert_model) # Load model model_name = hps.mert_model model = AutoModel.from_pretrained(model_name, trust_remote_code=True) if torch.cuda.is_available(): model = model.cuda() # model = model.eval() preprocessor = Wav2Vec2FeatureExtractor.from_pretrained( model_name, trust_remote_code=True ) return model, preprocessor # loading the corresponding preprocessor config # def load_preprocessor (model_name="m-a-p/MERT-v1-330M"): # print('load_preprocessor...') # preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(model_name,trust_remote_code=True) # return preprocessor