nx_denoise/examples/mpnet/step_1_prepare_data.py
#!/usr/bin/python3
# -*- coding: utf-8 -*-
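"""
Step 1 of MPNet data preparation: scan a noise corpus and a speech corpus,
slice every wav file into fixed-length windows, pair noise and speech windows
with a randomly drawn SNR, and split the resulting rows into train.xlsx and
valid.xlsx.
"""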

import argparse
import os
from pathlib import Path
import random
import sys

# Make the project root importable when the script is run from this directory.
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import pandas as pd
from tqdm import tqdm

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)
    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--duration", default=2.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)
    parser.add_argument("--target_sample_rate", default=8000, type=int)
    parser.add_argument("--scale", default=1, type=float)
    args = parser.parse_args()
    return args
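

# Notes on the arguments, inferred from how they are used below:
#   --duration                 window length in seconds of each dataset segment.
#   --min_snr_db/--max_snr_db  bounds of the uniform distribution the per-pair
#                              mixing SNR is drawn from.
#   --scale                    keep-probability for each (noise, speech) pair;
#                              the default of 1 keeps every pair.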


def filename_generator(data_dir: str):
    """Yield the posix path of every wav file under data_dir, recursively."""
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        yield filename.as_posix()


def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000):
    """Slide a non-overlapping window of `duration` seconds over every wav
    file under data_dir and yield one metadata row per window."""
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
        if raw_duration < duration:
            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
            continue
        if signal.ndim != 1:
            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

        signal_length = len(signal)
        win_size = int(duration * sample_rate)
        # `+ 1` so that a file exactly `duration` seconds long still yields one window.
        for begin in range(0, signal_length - win_size + 1, win_size):
            row = {
                "filename": filename.as_posix(),
                "raw_duration": round(raw_duration, 4),
                "offset": round(begin / sample_rate, 4),
                "duration": round(duration, 4),
            }
            yield row
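

# For example, a 5.3 s mono file at 8 kHz with duration=2.0 yields two rows:
#   {"filename": ".../foo.wav", "raw_duration": 5.3, "offset": 0.0, "duration": 2.0}
#   {"filename": ".../foo.wav", "raw_duration": 5.3, "offset": 2.0, "duration": 2.0}
# (the filename here is a hypothetical placeholder.)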


def get_dataset(args):
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )

    dataset = list()
    count = 0
    process_bar = tqdm(desc="build dataset excel")
    for noise, speech in zip(noise_generator, speech_generator):
        # Randomly subsample pairs: keep each pair with probability args.scale.
        flag = random.random()
        if flag > args.scale:
            continue

        noise_filename = noise["filename"]
        noise_raw_duration = noise["raw_duration"]
        noise_offset = noise["offset"]
        noise_duration = noise["duration"]

        speech_filename = speech["filename"]
        speech_raw_duration = speech["raw_duration"]
        speech_offset = speech["offset"]
        speech_duration = speech["duration"]

        random1 = random.random()
        random2 = random.random()

        row = {
            "noise_filename": noise_filename,
            "noise_raw_duration": noise_raw_duration,
            "noise_offset": noise_offset,
            "noise_duration": noise_duration,

            "speech_filename": speech_filename,
            "speech_raw_duration": speech_raw_duration,
            "speech_offset": speech_offset,
            "speech_duration": speech_duration,

            # Per-pair mixing SNR, drawn uniformly from [min_snr_db, max_snr_db].
            "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

            # random1 is used to shuffle the sheet; random2 drives the 80/20 split.
            "random1": random1,
            "random2": random2,
            "flag": "TRAIN" if random2 < 0.8 else "TEST",
        }
        dataset.append(row)

        count += 1
        duration_seconds = count * args.duration
        duration_hours = duration_seconds / 3600

        process_bar.update(n=1)
        process_bar.set_postfix({
            # "duration_seconds": round(duration_seconds, 4),
            "duration_hours": round(duration_hours, 4),
        })
    process_bar.close()

    dataset = pd.DataFrame(dataset)
    dataset = dataset.sort_values(by=["random1"], ascending=False)
    dataset.to_excel(
        file_dir / "dataset.xlsx",
        index=False,
    )
    return
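

# dataset.xlsx holds one row per (noise window, speech window) pair with the
# columns built above: noise_*/speech_* segment metadata, snr_db, the shuffle
# key random1, the split key random2, and the TRAIN/TEST flag.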


def split_dataset(args):
    """Split dataset.xlsx into a train set and a validation set."""
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    df = pd.read_excel(file_dir / "dataset.xlsx")

    train = list()
    test = list()
    for i, row in df.iterrows():
        flag = row["flag"]
        if flag == "TRAIN":
            train.append(row)
        else:
            test.append(row)

    train = pd.DataFrame(train)
    train.to_excel(
        args.train_dataset,
        index=False,
        # encoding="utf_8_sig"
    )
    test = pd.DataFrame(test)
    test.to_excel(
        args.valid_dataset,
        index=False,
        # encoding="utf_8_sig"
    )
    return


def main():
    args = get_args()
    get_dataset(args)
    split_dataset(args)
    return


if __name__ == "__main__":
    main()
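

# Example invocation (the paths below are placeholders; point them at your
# local noise and speech corpora):
#   python step_1_prepare_data.py \
#     --file_dir ./ \
#     --noise_dir /path/to/noise \
#     --speech_dir /path/to/speech \
#     --duration 2.0 \
#     --target_sample_rate 8000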