# AudioSep / pipeline.py
# Repository page metadata: badayvedat's picture · Initial commit · ae29df4 · raw · history blame · 2.02 kB
import yaml
from typing import Dict, List
import torch
import torch.nn as nn
import numpy as np
import librosa
from scipy.io.wavfile import write
from utils import ignore_warnings; ignore_warnings()
from utils import parse_yaml, load_ss_model
from models.clap_encoder import CLAP_Encoder
def build_audiosep(config_yaml, checkpoint_path, device):
    """Build the AudioSep separation model and place it on ``device``.

    The YAML file is parsed with ``parse_yaml``, a ``CLAP_Encoder`` is created
    in eval mode to serve as the query encoder, and both are handed to
    ``load_ss_model`` together with the checkpoint path. The resulting model
    is switched to eval mode and moved to ``device`` before being returned.

    Args:
        config_yaml: Path to the YAML configuration passed to ``parse_yaml``.
        checkpoint_path: Checkpoint location forwarded to ``load_ss_model``.
        device: Target passed to ``.to(...)`` on the loaded model.

    Returns:
        The object produced by ``load_ss_model`` after ``.eval().to(device)``.
    """
    parsed_config = parse_yaml(config_yaml)
    text_encoder = CLAP_Encoder().eval()

    separator = load_ss_model(
        configs=parsed_config,
        checkpoint_path=checkpoint_path,
        query_encoder=text_encoder,
    )
    # Inference-only usage: eval mode first, then move to the target device.
    separator = separator.eval()
    separator = separator.to(device)

    print(f'Load AudioSep model from [{checkpoint_path}]')
    return separator
def inference(model, audio_file, text, output_file, device='cuda'):
    """Separate the sound described by ``text`` from ``audio_file`` and save it.

    The mixture is loaded through ``librosa.load`` with ``sr=32000`` and
    ``mono=True``, the text query is embedded with
    ``model.query_encoder.get_query_embed``, and ``model.ss_model`` is run on
    the mixture/condition pair with gradients disabled. The ``"waveform"``
    entry of its output is written to ``output_file`` as 16-bit PCM at 32 kHz.

    Args:
        model: Object exposing ``query_encoder.get_query_embed(...)`` and a
            callable ``ss_model`` (as returned by ``build_audiosep``).
        audio_file: Path of the input mixture.
        text: A single textual query; it is wrapped in a one-element list
            before being passed to the query encoder.
        output_file: Destination path of the separated WAV file.
        device: Device on which the mixture tensor is placed and which is
            forwarded to the query encoder. Defaults to ``'cuda'``.
    """
    sample_rate = 32000  # used for both loading the mixture and writing the result

    print(f'Separate audio from [{audio_file}] with textual query [{text}]')
    mixture, _ = librosa.load(audio_file, sr=sample_rate, mono=True)

    with torch.no_grad():
        conditions = model.query_encoder.get_query_embed(
            modality='text',
            text=[text],
            device=device
        )

        input_dict = {
            # Two leading singleton axes are added in front of the sample axis.
            "mixture": torch.Tensor(mixture)[None, None, :].to(device),
            "condition": conditions,
        }

        sep_segment = model.ss_model(input_dict)["waveform"]
        # Drop the two leading axes again and bring the result back to NumPy.
        sep_segment = sep_segment.squeeze(0).squeeze(0).data.cpu().numpy()

    # Scale to the int16 range and saturate: without the clip, any sample whose
    # scaled value falls outside [-32768, 32767] would wrap around when cast
    # to int16 and produce loud clicks instead of plain clipping.
    # NOTE(review): assumes the model output is a float waveform nominally in
    # [-1, 1] — confirm against ss_model.
    pcm = np.clip(np.round(sep_segment * 32767), -32768, 32767).astype(np.int16)

    write(output_file, sample_rate, pcm)
    print(f'Write separated audio to [{output_file}]')
if __name__ == '__main__':
    # Demo entry point: load the base AudioSep checkpoint and run one
    # text-queried separation on a fixed example file.

    # Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    model = build_audiosep(
        config_yaml='config/audiosep_base.yaml',
        checkpoint_path='checkpoint/step=3920000.ckpt',
        device=device,
    )

    inference(
        model,
        audio_file='/mnt/bn/data-xubo/project/AudioShop/YT_audios/Y3VHpLxtd498.wav',
        text='pigeons are cooing in the background',
        output_file='separated_audio.wav',
        device=device,
    )