import argparse
import os

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from pydub import AudioSegment
from tqdm import trange
from transformers import (
    AutoFeatureExtractor,
    BertForSequenceClassification,
    BertJapaneseTokenizer,
    Wav2Vec2ForXVector,
)
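
# The config YAML is expected to provide the fields accessed below
# (values here are illustrative, not from the source):
#
#   path_csv: data/metadata.csv          # CSV with a "filename" column
#   path_data: data/audio/               # directory of "new_*.wav" files
#   sample_rate: 16000
#   path_audio_embedding: audio_embed.pt
#   path_text_embedding: text_embed.pt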


class Embedder:
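    """Compute and save audio (x-vector) and text (Japanese BERT) embeddings
    for the songs listed in the config's CSV file."""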
    def __init__(self, config):
        self.config = OmegaConf.load(config)
        # `config` is the YAML path; data paths come from the loaded config.
        self.df = pd.read_csv(self.config.path_csv)
        self.audio_feature_extractor = AutoFeatureExtractor.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        )
        self.audio_model = Wav2Vec2ForXVector.from_pretrained(
            "anton-l/wav2vec2-base-superb-sv"
        ).eval()  # disable dropout for inference
        self.text_tokenizer = BertJapaneseTokenizer.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking"
        )
        self.text_model = BertForSequenceClassification.from_pretrained(
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=True,
        ).eval()

    def run(self):
        self._create_audio_embed()
        self._create_text_embed()

    def _create_audio_embed(self):
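        """Embed each WAV file with the x-vector model and save the result,
        dropping dataframe rows whose audio fails to embed."""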
        audio_embed = None
        idx = []  # dataframe positions whose audio failed to embed
        for i in trange(len(self.df)):
            song = AudioSegment.from_wav(
                os.path.join(
                    self.config.path_data,
                    "new_" + self.df.iloc[i]["filename"].replace(".mp3", ".wav"),
                )
            )
            song = np.array(song.get_array_of_samples(), dtype="float")
            inputs = self.audio_feature_extractor(
                [song],
                sampling_rate=self.config.sample_rate,
                return_tensors="pt",
                padding=True,
            )
            try:
                with torch.no_grad():
                    embeddings = self.audio_model(**inputs).embeddings
                audio_embed = (
                    embeddings
                    if audio_embed is None
                    else torch.cat([audio_embed, embeddings])
                )
            except Exception:
                idx.append(i)

        audio_embed = torch.nn.functional.normalize(audio_embed, dim=-1).cpu()
        self.clean_and_save_data(audio_embed, idx)
        # Drop the failed rows so the CSV stays aligned with the saved
        # embeddings (read_csv yields the default RangeIndex that drop needs).
        self.df = self.df.drop(index=idx)
        self.df.to_csv(self.config.path_csv, index=False)

    def _create_text_embed(self):
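        """Embed each filename (minus its extension) with Japanese BERT,
        mean-pooling the second-to-last hidden layer over tokens."""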
        text_embed = None
        for i in range(len(self.df)):
            sentence = self.df.iloc[i]["filename"].replace(".mp3", "")
            # Tokenize without special tokens and map to vocabulary ids.
            tokenized_text = self.text_tokenizer.tokenize(sentence)
            indexed_tokens = self.text_tokenizer.convert_tokens_to_ids(tokenized_text)
            tokens_tensor = torch.tensor([indexed_tokens])
            with torch.no_grad():
                outputs = self.text_model(tokens_tensor)
            # Mean-pool the second-to-last hidden layer over the token axis.
            embedding = torch.mean(outputs.hidden_states[-2][0], dim=0).reshape(1, -1)
            text_embed = (
                embedding
                if text_embed is None
                else torch.cat([text_embed, embedding])
            )
        text_embed = torch.nn.functional.normalize(text_embed, dim=-1).cpu()
        torch.save(text_embed, self.config.path_text_embedding)

    def clean_and_save_data(self, audio_embed, idx):
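        """Persist the audio embeddings.

        Rows listed in ``idx`` raised during inference and never contributed
        an embedding, so ``audio_embed`` already aligns with the surviving
        dataframe rows and can be saved as-is.
        """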
        torch.save(audio_embed, self.config.path_audio_embedding)


def argparser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="config.yaml",
        help="Path to the config YAML file.",
    )
    return parser.parse_args()
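
# Example invocation (assuming this file is saved as embed.py):
#   python embed.py --config config.yaml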


if __name__ == "__main__":
    args = argparser()
    embedder = Embedder(args.config)
    embedder.run()