HoneyTian committed
Commit f74ae8e · 1 Parent(s): 7d18e1c
examples/mpnet_aishell/run.sh ADDED
@@ -0,0 +1,178 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
+
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0  # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \"\$$name\")";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ dataset="${file_dir}/dataset.xlsx"
+ train_dataset="${file_dir}/train.xlsx"
+ valid_dataset="${file_dir}/valid.xlsx"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+   alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+   #source /data/local/bin/nx_denoise/bin/activate
+   alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: prepare data"
+   cd "${work_dir}" || exit 1
+   python3 step_1_prepare_data.py \
+     --file_dir "${file_dir}" \
+     --noise_dir "${noise_dir}" \
+     --speech_dir "${speech_dir}" \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: train model"
+   cd "${work_dir}" || exit 1
+   python3 step_2_train_model.py \
+     --train_dataset "${train_dataset}" \
+     --valid_dataset "${valid_dataset}" \
+     --serialization_dir "${file_dir}" \
+     --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+   $verbose && echo "stage 3: test model"
+   cd "${work_dir}" || exit 1
+   python3 step_3_evaluation.py \
+     --valid_dataset "${valid_dataset}" \
+     --model_dir "${file_dir}/best" \
+     --evaluation_audio_dir "${evaluation_audio_dir}" \
+     --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+   $verbose && echo "stage 4: export model"
+   cd "${work_dir}" || exit 1
+   python3 step_5_export_models.py \
+     --vocabulary_dir "${vocabulary_dir}" \
+     --model_dir "${file_dir}/best" \
+     --serialization_dir "${file_dir}" \
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+   $verbose && echo "stage 5: collect files"
+   cd "${work_dir}" || exit 1
+
+   mkdir -p ${final_model_dir}
+
+   cp "${file_dir}/best"/* "${final_model_dir}"
+   cp -r "${file_dir}/vocabulary" "${final_model_dir}"
+
+   cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
+
+   cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
+   cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
+   cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
+   cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
+
+   cd "${final_model_dir}/.." || exit 1;
+
+   if [ -e "${final_model_name}.zip" ]; then
+     rm -rf "${final_model_name}_backup.zip"
+     mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+   fi
+
+   zip -r "${final_model_name}.zip" "${final_model_name}"
+   rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+   $verbose && echo "stage 6: clear file_dir"
+   cd "${work_dir}" || exit 1
+
+   rm -rf "${file_dir}";
+
+ fi
examples/mpnet_aishell/step_1_prepare_data.py ADDED
@@ -0,0 +1,197 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ from scipy.io import wavfile
+ from tqdm import tqdm
+ import librosa
+
+ from project_settings import project_path
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--file_dir", default="./", type=str)
+
+     parser.add_argument(
+         "--noise_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+         type=str
+     )
+     parser.add_argument(
+         "--speech_dir",
+         default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
+         type=str
+     )
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--duration", default=2.0, type=float)
+     parser.add_argument("--min_snr_db", default=-10, type=float)
+     parser.add_argument("--max_snr_db", default=20, type=float)
+
+     parser.add_argument("--target_sample_rate", default=8000, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def filename_generator(data_dir: str):
+     data_dir = Path(data_dir)
+     for filename in data_dir.glob("**/*.wav"):
+         yield filename.as_posix()
+
+
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
+     data_dir = Path(data_dir)
+     for filename in data_dir.glob("**/*.wav"):
+         signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+         raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+         if raw_duration < duration:
+             # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+             continue
+         if signal.ndim != 1:
+             raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+         signal_length = len(signal)
+         win_size = int(duration * sample_rate)
+         for begin in range(0, signal_length - win_size, win_size):
+             row = {
+                 "filename": filename.as_posix(),
+                 "raw_duration": round(raw_duration, 4),
+                 "offset": round(begin / sample_rate, 4),
+                 "duration": round(duration, 4),
+             }
+             yield row
+
+
+ def get_dataset(args):
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     noise_dir = Path(args.noise_dir)
+     speech_dir = Path(args.speech_dir)
+
+     noise_generator = target_second_signal_generator(
+         noise_dir.as_posix(),
+         duration=args.duration,
+         sample_rate=args.target_sample_rate
+     )
+     speech_generator = target_second_signal_generator(
+         speech_dir.as_posix(),
+         duration=args.duration,
+         sample_rate=args.target_sample_rate
+     )
+
+     dataset = list()
+
+     count = 0
+     process_bar = tqdm(desc="build dataset excel")
+     for noise, speech in zip(noise_generator, speech_generator):
+
+         noise_filename = noise["filename"]
+         noise_raw_duration = noise["raw_duration"]
+         noise_offset = noise["offset"]
+         noise_duration = noise["duration"]
+
+         speech_filename = speech["filename"]
+         speech_raw_duration = speech["raw_duration"]
+         speech_offset = speech["offset"]
+         speech_duration = speech["duration"]
+
+         random1 = random.random()
+         random2 = random.random()
+
+         row = {
+             "noise_filename": noise_filename,
+             "noise_raw_duration": noise_raw_duration,
+             "noise_offset": noise_offset,
+             "noise_duration": noise_duration,
+
+             "speech_filename": speech_filename,
+             "speech_raw_duration": speech_raw_duration,
+             "speech_offset": speech_offset,
+             "speech_duration": speech_duration,
+
+             "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
+
+             "random1": random1,
+             "random2": random2,
+             "flag": "TRAIN" if random2 < 0.8 else "TEST",
+         }
+         dataset.append(row)
+         count += 1
+         duration_seconds = count * args.duration
+         duration_hours = duration_seconds / 3600
+
+         process_bar.update(n=1)
+         process_bar.set_postfix({
+             # "duration_seconds": round(duration_seconds, 4),
+             "duration_hours": round(duration_hours, 4),
+
+         })
+
+     dataset = pd.DataFrame(dataset)
+     dataset = dataset.sort_values(by=["random1"], ascending=False)
+     dataset.to_excel(
+         file_dir / "dataset.xlsx",
+         index=False,
+     )
+     return
+
+
+ def split_dataset(args):
+     """Split the dataset into train and validation sets."""
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     df = pd.read_excel(file_dir / "dataset.xlsx")
+
+     train = list()
+     test = list()
+
+     for i, row in df.iterrows():
+         flag = row["flag"]
+         if flag == "TRAIN":
+             train.append(row)
+         else:
+             test.append(row)
+
+     train = pd.DataFrame(train)
+     train.to_excel(
+         args.train_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+     test = pd.DataFrame(test)
+     test.to_excel(
+         args.valid_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def main():
+     args = get_args()
+
+     get_dataset(args)
+     split_dataset(args)
+     return
+
+
+ if __name__ == "__main__":
+     main()
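
Note: the offset/duration/snr_db columns written above are meant to be consumed by the training-side dataset loader (DenoiseExcelDataset in step 2, not shown in this commit). A minimal sketch of how one row could be cut back into a 2-second noise/speech pair, assuming only pandas and librosa; the mixing step at snr_db is an illustrative assumption, not necessarily how the project's loader does it:

    import librosa
    import numpy as np
    import pandas as pd

    row = pd.read_excel("file_dir/train.xlsx").iloc[0]
    speech, _ = librosa.load(row["speech_filename"], sr=8000,
                             offset=row["speech_offset"], duration=row["speech_duration"])
    noise, _ = librosa.load(row["noise_filename"], sr=8000,
                            offset=row["noise_offset"], duration=row["noise_duration"])
    # scale the noise so the pair mixes at the sampled snr_db
    snr_db = row["snr_db"]
    gain = np.sqrt(np.sum(speech ** 2) / (np.sum(noise ** 2) * 10 ** (snr_db / 10) + 1e-8))
    mix = speech + gain * noise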
examples/mpnet_aishell/step_2_train_model.py ADDED
@@ -0,0 +1,396 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
+ """
+ import argparse
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import random
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import numpy as np
+ import torch
+ from torch.distributed import init_process_group
+ import torch.multiprocessing as mp
+ from torch.nn.parallel import DistributedDataParallel
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data import DistributedSampler
+ from torch.utils.data.dataloader import DataLoader
+ import torchaudio
+ from tqdm import tqdm
+
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
+ from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
+ from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminator, batch_pesq
+ from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
+ from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--max_epochs", default=100, type=int)
+
+     parser.add_argument("--batch_size", default=64, type=int)
+     parser.add_argument("--learning_rate", default=1e-4, type=float)
+     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+     parser.add_argument("--patience", default=5, type=int)
+     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+     parser.add_argument("--seed", default=0, type=int)
+
+     parser.add_argument("--config_file", default="config.yaml", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.INFO)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self,
+                  n_fft: int = 512,
+                  win_length: int = 200,
+                  hop_length: int = 80,
+                  window_fn: str = "hamming",
+                  irm_beta: float = 1.0,
+                  epsilon: float = 1e-8,
+                  ):
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.window_fn = window_fn
+         self.irm_beta = irm_beta
+         self.epsilon = epsilon
+
+         self.transform = torchaudio.transforms.Spectrogram(
+             n_fft=self.n_fft,
+             win_length=self.win_length,
+             hop_length=self.hop_length,
+             power=2.0,
+             window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+         )
+
+     @staticmethod
+     def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
+         batch_size, channels, freq_dim, time_steps = x.shape
+
+         # kernel: [freq_dim, n_time_step]
+         kernel_size = (freq_dim, n_time_steps)
+
+         # pad
+         pad = n_time_steps // 2
+         x = torch.concat(tensors=[
+             x[:, :, :, :pad],
+             x,
+             x[:, :, :, -pad:],
+         ], dim=-1)
+
+         x = F.unfold(
+             input=x,
+             kernel_size=kernel_size,
+         )
+         # x shape: [batch_size, fold, time_steps]
+         return x
+
+     def __call__(self, batch: List[dict]):
+         mix_spec_list = list()
+         speech_irm_list = list()
+         snr_db_list = list()
+         for sample in batch:
+             noise_wave: torch.Tensor = sample["noise_wave"]
+             speech_wave: torch.Tensor = sample["speech_wave"]
+             mix_wave: torch.Tensor = sample["mix_wave"]
+             # snr_db: float = sample["snr_db"]
+
+             noise_spec = self.transform.forward(noise_wave)
+             speech_spec = self.transform.forward(speech_wave)
+             mix_spec = self.transform.forward(mix_wave)
+
+             # noise_irm = noise_spec / (noise_spec + speech_spec)
+             speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
+             speech_irm = torch.pow(speech_irm, self.irm_beta)
+
+             # noise_spec, speech_spec, mix_spec, speech_irm
+             # shape: [freq_dim, time_steps]
+
+             snr_db: torch.Tensor = 10 * torch.log10(
+                 speech_spec / (noise_spec + self.epsilon)
+             )
+             snr_db = torch.clamp(snr_db, min=self.epsilon)
+
+             snr_db_ = torch.unsqueeze(snr_db, dim=0)
+             snr_db_ = torch.unsqueeze(snr_db_, dim=0)
+             snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
+             snr_db_ = torch.squeeze(snr_db_, dim=0)
+             # snr_db_ shape: [fold, time_steps]
+
+             snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
+             # snr_db shape: [1, time_steps]
+
+             mix_spec_list.append(mix_spec)
+             speech_irm_list.append(speech_irm)
+             snr_db_list.append(snr_db)
+
+         mix_spec_list = torch.stack(mix_spec_list)
+         speech_irm_list = torch.stack(speech_irm_list)
+         snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size, time_steps, 1)
+
+         mix_spec_list = mix_spec_list[:, :-1, :]
+         speech_irm_list = speech_irm_list[:, :-1, :]
+
+         # mix_spec_list shape: [batch_size, freq_dim, time_steps]
+         # speech_irm_list shape: [batch_size, freq_dim, time_steps]
+         # snr_db shape: [batch_size, 1, time_steps]
+
+         # assert
+         if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
+             raise AssertionError("nan or inf in mix_spec_list")
+         if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
+             raise AssertionError("nan or inf in speech_irm_list")
+         if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
+             raise AssertionError("nan or inf in snr_db_list")
+
+         return mix_spec_list, speech_irm_list, snr_db_list
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+     args = get_args()
+
+     config = MPNetConfig.from_pretrained(
+         pretrained_model_name_or_path=args.config_file,
+     )
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(serialization_dir)
+
+     random.seed(config.seed)
+     np.random.seed(config.seed)
+     torch.manual_seed(config.seed)
+     logger.info(f"set seed: {config.seed}")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info(f"GPU available count: {n_gpu}; device: {device}")
+
+     # datasets
+     train_dataset = DenoiseExcelDataset(
+         excel_file=args.train_dataset,
+         expected_sample_rate=8000,
+         max_wave_value=32768.0,
+     )
+     valid_dataset = DenoiseExcelDataset(
+         excel_file=args.valid_dataset,
+         expected_sample_rate=8000,
+         max_wave_value=32768.0,
+     )
+     train_data_loader = DataLoader(
+         dataset=train_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         sampler=None,
+         # On Linux multiple worker processes can be used to load data; on Windows they cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         # prefetch_factor=64,
+     )
+     valid_data_loader = DataLoader(
+         dataset=valid_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         sampler=None,
+         # On Linux multiple worker processes can be used to load data; on Windows they cannot.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+         collate_fn=collate_fn,
+         pin_memory=False,
+         # prefetch_factor=64,
+     )
+
+     # models
+     logger.info(f"prepare models. config_file: {args.config_file}")
+     generator = MPNetPretrainedModel(config).to(device)
+     discriminator = MetricDiscriminator().to(device)
+
+     # optimizer
+     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
+     num_params = 0
+     for p in generator.parameters():
+         num_params += p.numel()
+     print("Total Parameters (generator): {:.3f}M".format(num_params/1e6))
+
+     optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
+     optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
+
+     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=-1)
+     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)
+
+     # training loop
+     logger.info("training")
+     for idx_epoch in range(args.max_epochs):
+         generator.train()
+         discriminator.train()
+
+         total_loss_d = 0.
+         total_loss_g = 0.
+         total_batches = 0.
+         progress_bar = tqdm(
+             total=len(train_data_loader),
+             desc="Training; epoch: {}".format(idx_epoch),
+         )
+         for batch in train_data_loader:
+             clean_audio, noisy_audio = batch
+             clean_audio = torch.autograd.Variable(clean_audio.to(device, non_blocking=True))
+             noisy_audio = torch.autograd.Variable(noisy_audio.to(device, non_blocking=True))
+             one_labels = torch.ones(config.batch_size).to(device, non_blocking=True)
+
+             clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+             noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+
+             mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
+
+             audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+             mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+
+             audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
+             batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)
+
+             # Discriminator
+             optim_d.zero_grad()
+             metric_r = discriminator.forward(clean_mag, clean_mag)
+             metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
+             loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
+
+             if batch_pesq_score is not None:
+                 loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
+             else:
+                 print("pesq is None!")
+                 loss_disc_g = 0
+
+             loss_disc_all = loss_disc_r + loss_disc_g
+             loss_disc_all.backward()
+             optim_d.step()
+
+             # Generator
+             optim_g.zero_grad()
+             # L2 Magnitude Loss
+             loss_mag = F.mse_loss(clean_mag, mag_g)
+             # Anti-wrapping Phase Loss
+             loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
+             loss_pha = loss_ip + loss_gd + loss_iaf
+             # L2 Complex Loss
+             loss_com = F.mse_loss(clean_com, com_g) * 2
+             # L2 Consistency Loss
+             loss_stft = F.mse_loss(com_g, com_g_hat) * 2
+             # Time Loss
+             loss_time = F.l1_loss(clean_audio, audio_g)
+             # Metric Loss
+             metric_g = discriminator.forward(clean_mag, mag_g_hat)
+             loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
+
+             loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2
+
+             loss_gen_all.backward()
+             optim_g.step()
+
+             total_loss_d += loss_disc_all.item()
+             total_loss_g += loss_gen_all.item()
+             total_batches += 1
+
+             progress_bar.update(1)
+             progress_bar.set_postfix({
+                 "loss_d": round(total_loss_d / total_batches, 4),
+                 "loss_g": round(total_loss_g / total_batches, 4),
+             })
+
+         generator.eval()
+         torch.cuda.empty_cache()
+         total_pesq_score = 0.
+         total_mag_err = 0.
+         total_pha_err = 0.
+         total_com_err = 0.
+         total_stft_err = 0.
+         total_batches = 0.
+
+         progress_bar = tqdm(
+             total=len(valid_data_loader),
+             desc="Evaluation; epoch: {}".format(idx_epoch),
+         )
+         with torch.no_grad():
+             for batch in valid_data_loader:
+                 clean_audio, noisy_audio = batch
+                 clean_audio = torch.autograd.Variable(clean_audio.to(device, non_blocking=True))
+                 noisy_audio = torch.autograd.Variable(noisy_audio.to(device, non_blocking=True))
+
+                 clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+                 noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+
+                 mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
+
+                 audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+                 mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
+
+                 total_pesq_score += pesq_score(
+                     torch.split(clean_audio, 1, dim=0),
+                     torch.split(audio_g, 1, dim=0),
+                     config
+                 ).item()
+                 total_mag_err += F.mse_loss(clean_mag, mag_g).item()
+                 val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
+                 total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
+                 total_com_err += F.mse_loss(clean_com, com_g).item()
+                 total_stft_err += F.mse_loss(com_g, com_g_hat).item()
+
+                 total_batches += 1
+
+                 progress_bar.update(1)
+                 progress_bar.set_postfix({
+                     "pesq_score": round(total_pesq_score / total_batches, 4),
+                     "mag_err": round(total_mag_err / total_batches, 4),
+                     "pha_err": round(total_pha_err / total_batches, 4),
+                     "com_err": round(total_com_err / total_batches, 4),
+                     "stft_err": round(total_stft_err / total_batches, 4),
+                 })
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/mpnet_aishell/step_3_evaluation.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
examples/mpnet_aishell/yaml/config.yaml ADDED
@@ -0,0 +1,27 @@
+ model_name: "mpnet"
+
+ num_gpus: 0
+ batch_size: 4
+ learning_rate: 0.0005
+ adam_b1: 0.8
+ adam_b2: 0.99
+ lr_decay: 0.99
+ seed: 1234
+
+ dense_channel: 64
+ compress_factor: 0.3
+ num_tsconformers: 4
+ beta: 2.0
+
+ sample_rate: 16000
+ segment_size: 32000
+ n_fft: 400
+ hop_size: 100
+ win_size: 400
+
+ num_workers: 4
+
+ dist_config:
+   dist_backend: nccl
+   dist_url: tcp://localhost:54321
+   world_size: 1
examples/spectrum_dfnet_aishell/step_3_evaluation.py CHANGED
@@ -255,7 +255,6 @@ def main():
   # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
   batch_size, _, time_steps = speech_irm_prediction.shape
 
-
   mix_spec_complex = torch.concat(
       [
           mix_spec_complex,
requirements-python-3-9-9.txt CHANGED
@@ -11,3 +11,4 @@ overrides==7.7.0
  torch-pesq
  torchmetrics
  torchmetrics[audio]
+ einops
requirements.txt CHANGED
@@ -10,4 +10,5 @@ torchaudio==2.5.1
  overrides==7.7.0
  torch-pesq==0.1.2
  torchmetrics==1.6.1
- torchmetrics[audio]
+ torchmetrics[audio]==1.6.1
+ einops==0.8.1
toolbox/torchaudio/models/mpnet/configuation_mpnet.py ADDED
@@ -0,0 +1,68 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from typing import Tuple
+
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
+
+
+ class MPNetConfig(PretrainedConfig):
+     """
+     https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
+     """
+     def __init__(self,
+                  num_gpus: int = 0,
+                  batch_size: int = 4,
+                  learning_rate: float = 0.0005,
+                  adam_b1: float = 0.8,
+                  adam_b2: float = 0.99,
+                  lr_decay: float = 0.99,
+                  seed: int = 1234,
+
+                  dense_channel: int = 64,
+                  compress_factor: float = 0.3,
+                  num_tsconformers: int = 4,
+                  beta: float = 2.0,
+
+                  sample_rate: int = 16000,
+                  segment_size: int = 32000,
+                  n_fft: int = 400,
+                  hop_size: int = 100,
+                  win_size: int = 400,
+
+                  num_workers: int = 4,
+
+                  dist_config: dict = None,
+
+                  **kwargs
+                  ):
+         super(MPNetConfig, self).__init__(**kwargs)
+         self.num_gpus = num_gpus
+         self.batch_size = batch_size
+         self.learning_rate = learning_rate
+         self.adam_b1 = adam_b1
+         self.adam_b2 = adam_b2
+         self.lr_decay = lr_decay
+         self.seed = seed
+
+         self.dense_channel = dense_channel
+         self.compress_factor = compress_factor
+         self.num_tsconformers = num_tsconformers
+         self.beta = beta
+
+         self.sample_rate = sample_rate
+         self.segment_size = segment_size
+         self.n_fft = n_fft
+         self.hop_size = hop_size
+         self.win_size = win_size
+
+         self.num_workers = num_workers
+
+         self.dist_config = dist_config or {
+             "dist_backend": "nccl",
+             "dist_url": "tcp://localhost:54321",
+             "world_size": 1
+         }
+
+
+ if __name__ == "__main__":
+     pass
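
For reference, the training entry point above loads this configuration from the example's yaml/config.yaml via MPNetConfig.from_pretrained. A minimal sketch of reading it back, mirroring the call in step_2_train_model.py (the printed fields are only illustrative):

    from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig

    config = MPNetConfig.from_pretrained(pretrained_model_name_or_path="yaml/config.yaml")
    print(config.sample_rate, config.n_fft, config.hop_size, config.win_size)
    print(config.dense_channel, config.compress_factor)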
toolbox/torchaudio/models/mpnet/conformer.py ADDED
@@ -0,0 +1,83 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from einops.layers.torch import Rearrange
+ import torch.nn as nn
+
+
+ def get_padding(kernel_size: int, dilation: int = 1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ class FeedForwardModule(nn.Module):
+     def __init__(self, dim, mult=4, dropout=0):
+         super(FeedForwardModule, self).__init__()
+         self.ffm = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, dim * mult),
+             nn.SiLU(),
+             nn.Dropout(dropout),
+             nn.Linear(dim * mult, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.ffm(x)
+
+
+ class ConformerConvModule(nn.Module):
+     def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0.):
+         super(ConformerConvModule, self).__init__()
+         inner_dim = dim * expansion_factor
+         self.ccm = nn.Sequential(
+             nn.LayerNorm(dim),
+             Rearrange('b n c -> b c n'),
+             nn.Conv1d(dim, inner_dim*2, 1),
+             nn.GLU(dim=1),
+             nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size,
+                       padding=get_padding(kernel_size), groups=inner_dim),  # DepthWiseConv1d
+             nn.BatchNorm1d(inner_dim),
+             nn.SiLU(),
+             nn.Conv1d(inner_dim, dim, 1),
+             Rearrange('b c n -> b n c'),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x):
+         return self.ccm(x)
+
+
+ class AttentionModule(nn.Module):
+     def __init__(self, dim, n_head=8, dropout=0.):
+         super(AttentionModule, self).__init__()
+         self.attn = nn.MultiheadAttention(dim, n_head, dropout=dropout)
+         self.layernorm = nn.LayerNorm(dim)
+
+     def forward(self, x, attn_mask=None, key_padding_mask=None):
+         x = self.layernorm(x)
+         x, _ = self.attn(x, x, x,
+                          attn_mask=attn_mask,
+                          key_padding_mask=key_padding_mask)
+         return x
+
+
+ class ConformerBlock(nn.Module):
+     def __init__(self, dim, n_head=8, ffm_mult=4, ccm_expansion_factor=2, ccm_kernel_size=31,
+                  ffm_dropout=0., attn_dropout=0., ccm_dropout=0.):
+         super(ConformerBlock, self).__init__()
+         self.ffm1 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout)
+         self.attn = AttentionModule(dim, n_head, dropout=attn_dropout)
+         self.ccm = ConformerConvModule(dim, ccm_expansion_factor, ccm_kernel_size, dropout=ccm_dropout)
+         self.ffm2 = FeedForwardModule(dim, ffm_mult, dropout=ffm_dropout)
+         self.post_norm = nn.LayerNorm(dim)
+
+     def forward(self, x):
+         x = x + 0.5 * self.ffm1(x)
+         x = x + self.attn(x)
+         x = x + self.ccm(x)
+         x = x + 0.5 * self.ffm2(x)
+         x = self.post_norm(x)
+         return x
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/mpnet/discriminator.py ADDED
@@ -0,0 +1,71 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import torch.nn.functional as F
+ from pesq import pesq
+ from joblib import Parallel, delayed
+
+ from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
+
+
+ def cal_pesq(clean, noisy, sr=16000):
+     try:
+         pesq_score = pesq(sr, clean, noisy, 'wb')
+     except:
+         # error can happen due to silent period
+         pesq_score = -1
+     return pesq_score
+
+
+ def batch_pesq(clean, noisy):
+     pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
+     pesq_score = np.array(pesq_score)
+     if -1 in pesq_score:
+         return None
+     pesq_score = (pesq_score - 1) / 3.5
+     return torch.FloatTensor(pesq_score)
+
+
+ def metric_loss(metric_ref, metrics_gen):
+     loss = 0
+     for metric_gen in metrics_gen:
+         metric_loss = F.mse_loss(metric_ref, metric_gen.flatten())
+         loss += metric_loss
+
+     return loss
+
+
+ class MetricDiscriminator(nn.Module):
+     def __init__(self, dim=16, in_channel=2):
+         super(MetricDiscriminator, self).__init__()
+         self.layers = nn.Sequential(
+             nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
+             nn.InstanceNorm2d(dim, affine=True),
+             nn.PReLU(dim),
+             nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
+             nn.InstanceNorm2d(dim*2, affine=True),
+             nn.PReLU(dim*2),
+             nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
+             nn.InstanceNorm2d(dim*4, affine=True),
+             nn.PReLU(dim*4),
+             nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
+             nn.InstanceNorm2d(dim*8, affine=True),
+             nn.PReLU(dim*8),
+             nn.AdaptiveMaxPool2d(1),
+             nn.Flatten(),
+             nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
+             nn.Dropout(0.3),
+             nn.PReLU(dim*4),
+             nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
+             LearnableSigmoid1d(1)
+         )
+
+     def forward(self, x, y):
+         xy = torch.stack((x, y), dim=1)
+         return self.layers(xy)
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/mpnet/modeling_mpnet.py CHANGED
@@ -2,8 +2,295 @@
  # -*- coding: utf-8 -*-
  """
  https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py
+
+ https://arxiv.org/abs/2305.13686
+ https://github.com/yxlu-0102/MP-SENet
+
  """
+ import os
+ from typing import Optional, Union
+
+ from pesq import pesq
+ from joblib import Parallel, delayed
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
+ from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock
+ from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock
+ from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
+ from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d
+
+
+ class SPConvTranspose2d(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, r=1):
+         super(SPConvTranspose2d, self).__init__()
+         self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
+         self.out_channels = out_channels
+         self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
+         self.r = r
+
+     def forward(self, x):
+         x = self.pad1(x)
+         out = self.conv(x)
+         batch_size, nchannels, H, W = out.shape
+         out = out.view((batch_size, self.r, nchannels // self.r, H, W))
+         out = out.permute(0, 2, 3, 4, 1)
+         out = out.contiguous().view((batch_size, nchannels // self.r, H, -1))
+         return out
+
+
+ class DenseBlock(nn.Module):
+     def __init__(self, h, kernel_size=(2, 3), depth=4):
+         super(DenseBlock, self).__init__()
+         self.h = h
+         self.depth = depth
+         self.dense_block = nn.ModuleList([])
+         for i in range(depth):
+             dilation = 2 ** i
+             pad_length = dilation
+             dense_conv = nn.Sequential(
+                 nn.ConstantPad2d((1, 1, pad_length, 0), value=0.),
+                 nn.Conv2d(h.dense_channel*(i+1), h.dense_channel, kernel_size, dilation=(dilation, 1)),
+                 nn.InstanceNorm2d(h.dense_channel, affine=True),
+                 nn.PReLU(h.dense_channel)
+             )
+             self.dense_block.append(dense_conv)
+
+     def forward(self, x):
+         skip = x
+         for i in range(self.depth):
+             x = self.dense_block[i](skip)
+             skip = torch.cat([x, skip], dim=1)
+         return x
+
+
+ class DenseEncoder(nn.Module):
+     def __init__(self, h, in_channel):
+         super(DenseEncoder, self).__init__()
+         self.h = h
+         self.dense_conv_1 = nn.Sequential(
+             nn.Conv2d(in_channel, h.dense_channel, (1, 1)),
+             nn.InstanceNorm2d(h.dense_channel, affine=True),
+             nn.PReLU(h.dense_channel))
+
+         self.dense_block = DenseBlock(h, depth=4)
+
+         self.dense_conv_2 = nn.Sequential(
+             nn.Conv2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2), padding=(0, 1)),
+             nn.InstanceNorm2d(h.dense_channel, affine=True),
+             nn.PReLU(h.dense_channel))
+
+     def forward(self, x):
+         x = self.dense_conv_1(x)  # [b, 64, T, F]
+         x = self.dense_block(x)  # [b, 64, T, F]
+         x = self.dense_conv_2(x)  # [b, 64, T, F//2]
+         return x
+
+
+ class MaskDecoder(nn.Module):
+     def __init__(self, h, out_channel=1):
+         super(MaskDecoder, self).__init__()
+         self.dense_block = DenseBlock(h, depth=4)
+         self.mask_conv = nn.Sequential(
+             SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
+             nn.InstanceNorm2d(h.dense_channel, affine=True),
+             nn.PReLU(h.dense_channel),
+             nn.Conv2d(h.dense_channel, out_channel, (1, 2))
+         )
+         self.lsigmoid = LearnableSigmoid2d(h.n_fft//2+1, beta=h.beta)
+
+     def forward(self, x):
+         x = self.dense_block(x)
+         x = self.mask_conv(x)
+         x = x.permute(0, 3, 2, 1).squeeze(-1)  # [B, F, T]
+         x = self.lsigmoid(x)
+         return x
+
+
+ class PhaseDecoder(nn.Module):
+     def __init__(self, h, out_channel=1):
+         super(PhaseDecoder, self).__init__()
+         self.dense_block = DenseBlock(h, depth=4)
+         self.phase_conv = nn.Sequential(
+             SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
+             nn.InstanceNorm2d(h.dense_channel, affine=True),
+             nn.PReLU(h.dense_channel)
+         )
+         self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, (1, 2))
+         self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, (1, 2))
+
+     def forward(self, x):
+         x = self.dense_block(x)
+         x = self.phase_conv(x)
+         x_r = self.phase_conv_r(x)
+         x_i = self.phase_conv_i(x)
+         x = torch.atan2(x_i, x_r)
+         x = x.permute(0, 3, 2, 1).squeeze(-1)  # [B, F, T]
+         return x
+
+
+ class TSTransformerBlock(nn.Module):
+     def __init__(self, h):
+         super(TSTransformerBlock, self).__init__()
+         self.h = h
+         self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
+         self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
+
+     def forward(self, x):
+         b, c, t, f = x.size()
+         x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
+         x = self.time_transformer(x) + x
+         x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
+         x = self.freq_transformer(x) + x
+         x = x.view(b, t, f, c).permute(0, 3, 1, 2)
+         return x
+
+
+ class MPNet(nn.Module):
+     def __init__(self, config: MPNetConfig, num_tsblocks=4):
+         super(MPNet, self).__init__()
+         self.config = config
+         self.num_tscblocks = num_tsblocks
+         self.dense_encoder = DenseEncoder(config, in_channel=2)
+
+         self.TSTransformer = nn.ModuleList([])
+         for i in range(num_tsblocks):
+             self.TSTransformer.append(TSTransformerBlock(config))
+
+         self.mask_decoder = MaskDecoder(config, out_channel=1)
+         self.phase_decoder = PhaseDecoder(config, out_channel=1)
+
+     def forward(self, noisy_amp, noisy_pha):  # [B, F, T]
+
+         x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1)  # [B, 2, T, F]
+         x = self.dense_encoder(x)
+
+         for i in range(self.num_tscblocks):
+             x = self.TSTransformer[i](x)
+
+         denoised_amp = noisy_amp * self.mask_decoder(x)
+         denoised_pha = self.phase_decoder(x)
+         denoised_com = torch.stack(
+             tensors=(
+                 denoised_amp * torch.cos(denoised_pha),
+                 denoised_amp * torch.sin(denoised_pha)
+             ),
+             dim=-1
+         )
+
+         return denoised_amp, denoised_pha, denoised_com
+
+
+ MODEL_FILE = "model.pt"
+
+
+ class MPNetPretrainedModel(MPNet):
+     def __init__(self,
+                  config: MPNetConfig,
+                  ):
+         super(MPNetPretrainedModel, self).__init__(
+             config=config,
+         )
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+         model = cls(config)
+
+         if os.path.isdir(pretrained_model_name_or_path):
+             ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+         else:
+             ckpt_file = pretrained_model_name_or_path
+
+         with open(ckpt_file, "rb") as f:
+             state_dict = torch.load(f, map_location="cpu", weights_only=True)
+         model.load_state_dict(state_dict, strict=True)
+         return model
+
+     def save_pretrained(self,
+                         save_directory: Union[str, os.PathLike],
+                         state_dict: Optional[dict] = None,
+                         ):
+
+         model = self
+
+         if state_dict is None:
+             state_dict = model.state_dict()
+
+         os.makedirs(save_directory, exist_ok=True)
+
+         # save state dict
+         model_file = os.path.join(save_directory, MODEL_FILE)
+         torch.save(state_dict, model_file)
+
+         # save config
+         config_file = os.path.join(save_directory, CONFIG_FILE)
+         self.config.to_yaml_file(config_file)
+         return save_directory
+
+
+ def phase_losses(phase_r, phase_g):
+
+     ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
+     gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
+     iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
+
+     return ip_loss, gd_loss, iaf_loss
+
+
+ def anti_wrapping_function(x):
+
+     return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
+
+
+ def pesq_score(utts_r, utts_g, h):
+
+     pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
+         utts_r[i].squeeze().cpu().numpy(),
+         utts_g[i].squeeze().cpu().numpy(),
+         h.sample_rate)
+         for i in range(len(utts_r)))
+     pesq_score = np.mean(pesq_score)
+
+     return pesq_score
+
+
+ def eval_pesq(clean_utt, esti_utt, sr):
+     try:
+         pesq_score = pesq(sr, clean_utt, esti_utt)
+     except:
+         pesq_score = -1
+
+     return pesq_score
+
+
+ def main():
+     import torchaudio
+
+     config = MPNetConfig()
+     model = MPNet(config=config)
+
+     transformer = torchaudio.transforms.Spectrogram(
+         n_fft=config.n_fft,
+         win_length=config.win_size,
+         hop_length=config.hop_size,
+         window_fn=torch.hamming_window,
+     )
+
+     inputs = torch.randn(size=(1, 32000), dtype=torch.float32)
+     spec = transformer.forward(inputs)
+     print(spec.shape)
+
+     denoised_amp, denoised_pha, denoised_com = model.forward(spec, spec)
+     print(denoised_amp.shape)
+     print(denoised_pha.shape)
+     print(denoised_com.shape)
+
+     return
 
 
  if __name__ == '__main__':
-     pass
+     main()
toolbox/torchaudio/models/mpnet/transformers.py ADDED
@@ -0,0 +1,70 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import MultiheadAttention, GRU, Linear, LayerNorm, Dropout
+
+
+ class FFN(nn.Module):
+     def __init__(self, d_model, bidirectional=True, dropout=0):
+         super(FFN, self).__init__()
+         self.gru = GRU(d_model, d_model * 2, 1, bidirectional=bidirectional)
+         if bidirectional:
+             self.linear = Linear(d_model * 2 * 2, d_model)
+         else:
+             self.linear = Linear(d_model * 2, d_model)
+         self.dropout = Dropout(dropout)
+
+     def forward(self, x):
+         self.gru.flatten_parameters()
+         x, _ = self.gru(x)
+         x = F.leaky_relu(x)
+         x = self.dropout(x)
+         x = self.linear(x)
+
+         return x
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, d_model, n_heads, bidirectional=True, dropout=0):
+         super(TransformerBlock, self).__init__()
+
+         self.norm1 = LayerNorm(d_model)
+         self.attention = MultiheadAttention(d_model, n_heads, dropout=dropout)
+         self.dropout1 = Dropout(dropout)
+
+         self.norm2 = LayerNorm(d_model)
+         self.ffn = FFN(d_model, bidirectional=bidirectional)
+         self.dropout2 = Dropout(dropout)
+
+         self.norm3 = LayerNorm(d_model)
+
+     def forward(self, x, attn_mask=None, key_padding_mask=None):
+         xt = self.norm1(x)
+         xt, _ = self.attention(xt, xt, xt,
+                                attn_mask=attn_mask,
+                                key_padding_mask=key_padding_mask)
+         x = x + self.dropout1(xt)
+
+         xt = self.norm2(x)
+         xt = self.ffn(xt)
+         x = x + self.dropout2(xt)
+
+         x = self.norm3(x)
+
+         return x
+
+
+ def main():
+     x = torch.randn(4, 64, 401, 201)
+     b, c, t, f = x.size()
+     x = x.permute(0, 3, 2, 1).contiguous().view(b, f * t, c)
+     transformer = TransformerBlock(d_model=64, n_heads=4)
+     x = transformer(x)
+     x = x.view(b, f, t, c).permute(0, 3, 2, 1)
+     print(x.size())
+
+
+ if __name__ == '__main__':
+     main()
toolbox/torchaudio/models/mpnet/utils.py ADDED
@@ -0,0 +1,106 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ from einops.layers.torch import Rearrange
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from pesq import pesq
+ from joblib import Parallel, delayed
+
+
+ def phase_losses(phase_r, phase_g):
+
+     ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
+     gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
+     iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
+
+     return ip_loss, gd_loss, iaf_loss
+
+
+ def anti_wrapping_function(x):
+
+     return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
+
+
+ def pesq_score(utts_r, utts_g, h):
+
+     pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
+         utts_r[i].squeeze().cpu().numpy(),
+         utts_g[i].squeeze().cpu().numpy(),
+         h.sample_rate)
+         for i in range(len(utts_r)))
+     pesq_score = np.mean(pesq_score)
+
+     return pesq_score
+
+
+ def eval_pesq(clean_utt, esti_utt, sr):
+     try:
+         pesq_score = pesq(sr, clean_utt, esti_utt)
+     except:
+         pesq_score = -1
+
+     return pesq_score
+
+
+ def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
+
+     hann_window = torch.hann_window(win_size).to(y.device)
+     stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
+                            center=center, pad_mode='reflect', normalized=False, return_complex=True)
+     stft_spec = torch.view_as_real(stft_spec)
+     mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
+     pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
+     # Magnitude Compression
+     mag = torch.pow(mag, compress_factor)
+     com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
+
+     return mag, pha, com
+
+
+ def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
+     # Magnitude Decompression
+     mag = torch.pow(mag, (1.0/compress_factor))
+     com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
+     hann_window = torch.hann_window(win_size).to(com.device)
+     wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
+
+     return wav
+
+
+ class LearnableSigmoid1d(nn.Module):
+     def __init__(self, in_features, beta=1):
+         super().__init__()
+         self.beta = beta
+         self.slope = nn.Parameter(torch.ones(in_features))
+         self.slope.requires_grad = True
+
+     def forward(self, x):
+         # x shape: [batch_size, time_steps, spec_bins]
+         return self.beta * torch.sigmoid(self.slope * x)
+
+
+ class LearnableSigmoid2d(nn.Module):
+     def __init__(self, in_features, beta=1):
+         super().__init__()
+         self.beta = beta
+         self.slope = nn.Parameter(torch.ones(in_features, 1))
+         self.slope.requires_grad = True
+
+     def forward(self, x):
+         return self.beta * torch.sigmoid(self.slope * x)
+
+
+ def main():
+     learnable_sigmoid = LearnableSigmoid1d(201)
+     a = torch.randn(4, 100, 201)
+
+     result = learnable_sigmoid.forward(a)
+     print(result.shape)
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
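
A quick round-trip sketch for mag_pha_stft / mag_pha_istft, using the same n_fft / hop_size / win_size / compress_factor values as the yaml config below; the tolerance check is only illustrative:

    import torch

    wav = torch.randn(1, 32000)
    mag, pha, com = mag_pha_stft(wav, n_fft=400, hop_size=100, win_size=400, compress_factor=0.3)
    # mag/pha shape: [batch, n_fft//2+1, frames]; com adds a trailing real/imag dimension
    rec = mag_pha_istft(mag, pha, n_fft=400, hop_size=100, win_size=400, compress_factor=0.3)
    print(mag.shape, pha.shape, com.shape, rec.shape)
    print(torch.allclose(wav[..., :rec.shape[-1]], rec, atol=1e-3))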
toolbox/torchaudio/models/mpnet/yaml/config.yaml ADDED
@@ -0,0 +1,27 @@
+ model_name: "mpnet"
+
+ num_gpus: 0
+ batch_size: 4
+ learning_rate: 0.0005
+ adam_b1: 0.8
+ adam_b2: 0.99
+ lr_decay: 0.99
+ seed: 1234
+
+ dense_channel: 64
+ compress_factor: 0.3
+ num_tsconformers: 4
+ beta: 2.0
+
+ sample_rate: 16000
+ segment_size: 32000
+ n_fft: 400
+ hop_size: 100
+ win_size: 400
+
+ num_workers: 4
+
+ dist_config:
+   dist_backend: nccl
+   dist_url: tcp://localhost:54321
+   world_size: 1