HoneyTian committed
Commit 294430e · 1 Parent(s): cb8eb69
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -28,7 +28,7 @@ from torch.utils.data.dataloader import DataLoader
28
  from tqdm import tqdm
29
 
30
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
31
- from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
32
  from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
33
  from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
34
  from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
@@ -112,7 +112,7 @@ collate_fn = CollateFunction()
112
  def main():
113
  args = get_args()
114
 
115
- config = CleanUnetConfig.from_pretrained(
116
  pretrained_model_name_or_path=args.config_file,
117
  )
118
 
@@ -186,7 +186,7 @@ def main():
186
  model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
187
  optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
188
 
189
- logger.info(f"load state dict for generator.")
190
  with open(model_pt.as_posix(), "rb") as f:
191
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
  model.load_state_dict(state_dict, strict=True)
@@ -317,7 +317,7 @@ def main():
317
  enhanced_audios = model.forward(noisy_audios)
318
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
319
 
320
- ae_loss = ae_loss_fn(enhanced_audios, enhanced_audios)
321
  sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
322
 
323
  loss = ae_loss + sc_loss + mag_loss
 
28
  from tqdm import tqdm
29
 
30
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
31
+ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig
32
  from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
33
  from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
34
  from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
 
112
  def main():
113
  args = get_args()
114
 
115
+ config = CleanUNetConfig.from_pretrained(
116
  pretrained_model_name_or_path=args.config_file,
117
  )
118
 
 
186
  model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
187
  optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
188
 
189
+ logger.info(f"load state dict for model.")
190
  with open(model_pt.as_posix(), "rb") as f:
191
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
  model.load_state_dict(state_dict, strict=True)
 
317
  enhanced_audios = model.forward(noisy_audios)
318
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
319
 
320
+ ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
321
  sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
322
 
323
  loss = ae_loss + sc_loss + mag_loss
examples/nx_clean_unet/run.sh ADDED
@@ -0,0 +1,166 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
7
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
9
+
10
+
11
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
12
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
+
15
+ sh run.sh --stage 5 --stop_stage 5 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
16
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
+
19
+
20
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
21
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
22
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
23
+ --max_epochs 1
24
+
25
+
26
+ END
27
+
28
+
29
+ # params
30
+ system_version="windows";
31
+ verbose=true;
32
+ stage=0 # start from 0 if you need to start from data preparation
33
+ stop_stage=9
34
+
35
+ work_dir="$(pwd)"
36
+ file_folder_name=file_folder_name
37
+ final_model_name=final_model_name
38
+ config_file="yaml/config.yaml"
39
+ limit=10
40
+
41
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
42
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
43
+
44
+ nohup_name=nohup.out
45
+
46
+ # model params
47
+ batch_size=64
48
+ max_epochs=200
49
+ save_top_k=10
50
+ patience=5
51
+
52
+
53
+ # parse options
54
+ while true; do
55
+ [ -z "${1:-}" ] && break; # break if there are no arguments
56
+ case "$1" in
57
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
58
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
59
+ old_value="$(eval echo \$$name)";
60
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
61
+ was_bool=true;
62
+ else
63
+ was_bool=false;
64
+ fi
65
+
66
+ # Set the variable to the right value-- the escaped quotes make it work if
67
+ # the option had spaces, like --cmd "queue.pl -sync y"
68
+ eval "${name}=\"$2\"";
69
+
70
+ # Check that Boolean-valued arguments are really Boolean.
71
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
72
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
73
+ exit 1;
74
+ fi
75
+ shift 2;
76
+ ;;
77
+
78
+ *) break;
79
+ esac
80
+ done
81
+
82
+ file_dir="${work_dir}/${file_folder_name}"
83
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
84
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
85
+
86
+ dataset="${file_dir}/dataset.xlsx"
87
+ train_dataset="${file_dir}/train.xlsx"
88
+ valid_dataset="${file_dir}/valid.xlsx"
89
+
90
+ $verbose && echo "system_version: ${system_version}"
91
+ $verbose && echo "file_folder_name: ${file_folder_name}"
92
+
93
+ if [ $system_version == "windows" ]; then
94
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
95
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
96
+ #source /data/local/bin/nx_denoise/bin/activate
97
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
98
+ fi
99
+
100
+
101
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
102
+ $verbose && echo "stage 1: prepare data"
103
+ cd "${work_dir}" || exit 1
104
+ python3 step_1_prepare_data.py \
105
+ --file_dir "${file_dir}" \
106
+ --noise_dir "${noise_dir}" \
107
+ --speech_dir "${speech_dir}" \
108
+ --train_dataset "${train_dataset}" \
109
+ --valid_dataset "${valid_dataset}" \
110
+
111
+ fi
112
+
113
+
114
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
115
+ $verbose && echo "stage 2: train model"
116
+ cd "${work_dir}" || exit 1
117
+ python3 step_2_train_model.py \
118
+ --train_dataset "${train_dataset}" \
119
+ --valid_dataset "${valid_dataset}" \
120
+ --serialization_dir "${file_dir}" \
121
+ --config_file "${config_file}" \
122
+
123
+ fi
124
+
125
+
126
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
127
+ $verbose && echo "stage 3: test model"
128
+ cd "${work_dir}" || exit 1
129
+ python3 step_3_evaluation.py \
130
+ --valid_dataset "${valid_dataset}" \
131
+ --model_dir "${file_dir}/best" \
132
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
133
+ --limit "${limit}" \
134
+
135
+ fi
136
+
137
+
138
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
139
+ $verbose && echo "stage 4: collect files"
140
+ cd "${work_dir}" || exit 1
141
+
142
+ mkdir -p ${final_model_dir}
143
+
144
+ cp "${file_dir}/best"/* "${final_model_dir}"
145
+ cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
146
+
147
+ cd "${final_model_dir}/.." || exit 1;
148
+
149
+ if [ -e "${final_model_name}.zip" ]; then
150
+ rm -rf "${final_model_name}_backup.zip"
151
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
152
+ fi
153
+
154
+ zip -r "${final_model_name}.zip" "${final_model_name}"
155
+ rm -rf "${final_model_name}"
156
+
157
+ fi
158
+
159
+
160
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
161
+ $verbose && echo "stage 5: clear file_dir"
162
+ cd "${work_dir}" || exit 1
163
+
164
+ rm -rf "${file_dir}";
165
+
166
+ fi
examples/nx_clean_unet/step_1_prepare_data.py ADDED
@@ -0,0 +1,201 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_snr_db", default=-10, type=float)
41
+ parser.add_argument("--max_snr_db", default=20, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ parser.add_argument("--max_count", default=10000, type=int)
46
+
47
+ args = parser.parse_args()
48
+ return args
49
+
50
+
51
+ def filename_generator(data_dir: str):
52
+ data_dir = Path(data_dir)
53
+ for filename in data_dir.glob("**/*.wav"):
54
+ yield filename.as_posix()
55
+
56
+
57
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
58
+ data_dir = Path(data_dir)
59
+ for filename in data_dir.glob("**/*.wav"):
60
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
61
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
62
+
63
+ if raw_duration < duration:
64
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
65
+ continue
66
+ if signal.ndim != 1:
67
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
68
+
69
+ signal_length = len(signal)
70
+ win_size = int(duration * sample_rate)
71
+ for begin in range(0, signal_length - win_size, win_size):
72
+ row = {
73
+ "filename": filename.as_posix(),
74
+ "raw_duration": round(raw_duration, 4),
75
+ "offset": round(begin / sample_rate, 4),
76
+ "duration": round(duration, 4),
77
+ }
78
+ yield row
79
+
80
+
81
+ def get_dataset(args):
82
+ file_dir = Path(args.file_dir)
83
+ file_dir.mkdir(exist_ok=True)
84
+
85
+ noise_dir = Path(args.noise_dir)
86
+ speech_dir = Path(args.speech_dir)
87
+
88
+ noise_generator = target_second_signal_generator(
89
+ noise_dir.as_posix(),
90
+ duration=args.duration,
91
+ sample_rate=args.target_sample_rate
92
+ )
93
+ speech_generator = target_second_signal_generator(
94
+ speech_dir.as_posix(),
95
+ duration=args.duration,
96
+ sample_rate=args.target_sample_rate
97
+ )
98
+
99
+ dataset = list()
100
+
101
+ count = 0
102
+ process_bar = tqdm(desc="build dataset excel")
103
+ for noise, speech in zip(noise_generator, speech_generator):
104
+ if count >= args.max_count:
105
+ break
106
+
107
+ noise_filename = noise["filename"]
108
+ noise_raw_duration = noise["raw_duration"]
109
+ noise_offset = noise["offset"]
110
+ noise_duration = noise["duration"]
111
+
112
+ speech_filename = speech["filename"]
113
+ speech_raw_duration = speech["raw_duration"]
114
+ speech_offset = speech["offset"]
115
+ speech_duration = speech["duration"]
116
+
117
+ random1 = random.random()
118
+ random2 = random.random()
119
+
120
+ row = {
121
+ "noise_filename": noise_filename,
122
+ "noise_raw_duration": noise_raw_duration,
123
+ "noise_offset": noise_offset,
124
+ "noise_duration": noise_duration,
125
+
126
+ "speech_filename": speech_filename,
127
+ "speech_raw_duration": speech_raw_duration,
128
+ "speech_offset": speech_offset,
129
+ "speech_duration": speech_duration,
130
+
131
+ "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
132
+
133
+ "random1": random1,
134
+ "random2": random2,
135
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
136
+ }
137
+ dataset.append(row)
138
+ count += 1
139
+ duration_seconds = count * args.duration
140
+ duration_hours = duration_seconds / 3600
141
+
142
+ process_bar.update(n=1)
143
+ process_bar.set_postfix({
144
+ # "duration_seconds": round(duration_seconds, 4),
145
+ "duration_hours": round(duration_hours, 4),
146
+
147
+ })
148
+
149
+ dataset = pd.DataFrame(dataset)
150
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
151
+ dataset.to_excel(
152
+ file_dir / "dataset.xlsx",
153
+ index=False,
154
+ )
155
+ return
156
+
157
+
158
+
159
+ def split_dataset(args):
160
+ """Split the dataset into training and test sets."""
161
+ file_dir = Path(args.file_dir)
162
+ file_dir.mkdir(exist_ok=True)
163
+
164
+ df = pd.read_excel(file_dir / "dataset.xlsx")
165
+
166
+ train = list()
167
+ test = list()
168
+
169
+ for i, row in df.iterrows():
170
+ flag = row["flag"]
171
+ if flag == "TRAIN":
172
+ train.append(row)
173
+ else:
174
+ test.append(row)
175
+
176
+ train = pd.DataFrame(train)
177
+ train.to_excel(
178
+ args.train_dataset,
179
+ index=False,
180
+ # encoding="utf_8_sig"
181
+ )
182
+ test = pd.DataFrame(test)
183
+ test.to_excel(
184
+ args.valid_dataset,
185
+ index=False,
186
+ # encoding="utf_8_sig"
187
+ )
188
+
189
+ return
190
+
191
+
192
+ def main():
193
+ args = get_args()
194
+
195
+ get_dataset(args)
196
+ split_dataset(args)
197
+ return
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
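The spreadsheet written above only records file paths, offsets, durations and a per-row snr_db; the actual mixing is done later by DenoiseExcelDataset. A hedged sketch of how one row could be turned into a (clean, noisy) pair at the requested SNR (the helper names are illustrative, not the dataset's real API):

    import librosa
    import numpy as np

    def load_segment(filename: str, offset: float, duration: float, sample_rate: int = 8000) -> np.ndarray:
        # load the audio window described by one spreadsheet row
        signal, _ = librosa.load(filename, sr=sample_rate, offset=offset, duration=duration)
        return signal

    def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
        # scale the noise so that 10 * log10(P_speech / P_noise) equals snr_db, then add
        speech_power = float(np.mean(speech ** 2)) + 1e-8
        noise_power = float(np.mean(noise ** 2)) + 1e-8
        scale = np.sqrt(speech_power / (noise_power * 10.0 ** (snr_db / 10.0)))
        return speech + scale * noise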
examples/nx_clean_unet/step_2_train_model.py ADDED
@@ -0,0 +1,438 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/yxlu-0102/MP-SENet/blob/main/train.py
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch.nn import functional as F
24
+ from torch.utils.data.dataloader import DataLoader
25
+ from tqdm import tqdm
26
+
27
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
+ from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
29
+ from toolbox.torchaudio.models.nx_clean_unet.discriminator import MetricDiscriminator, MetricDiscriminatorPretrainedModel
30
+ from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNet, NXCleanUNetPretrainedModel
31
+ from toolbox.torchaudio.models.nx_clean_unet.metrics import run_batch_pesq, run_pesq_score
32
+ from toolbox.torchaudio.models.nx_clean_unet.utils import mag_pha_stft, mag_pha_istft
33
+ from toolbox.torchaudio.models.nx_clean_unet.loss import phase_losses
34
+
35
+
36
+ def get_args():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
40
+
41
+ parser.add_argument("--max_epochs", default=100, type=int)
42
+
43
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
+ parser.add_argument("--patience", default=5, type=int)
45
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self):
76
+ pass
77
+
78
+ def __call__(self, batch: List[dict]):
79
+ clean_audios = list()
80
+ noisy_audios = list()
81
+
82
+ for sample in batch:
83
+ # noise_wave: torch.Tensor = sample["noise_wave"]
84
+ clean_audio: torch.Tensor = sample["speech_wave"]
85
+ noisy_audio: torch.Tensor = sample["mix_wave"]
86
+ # snr_db: float = sample["snr_db"]
87
+
88
+ clean_audios.append(clean_audio)
89
+ noisy_audios.append(noisy_audio)
90
+
91
+ clean_audios = torch.stack(clean_audios)
92
+ noisy_audios = torch.stack(noisy_audios)
93
+
94
+ # assert
95
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
96
+ raise AssertionError("nan or inf in clean_audios")
97
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
98
+ raise AssertionError("nan or inf in noisy_audios")
99
+ return clean_audios, noisy_audios
100
+
101
+
102
+ collate_fn = CollateFunction()
103
+
104
+
105
+ def main():
106
+ args = get_args()
107
+
108
+ config = NXCleanUNetConfig.from_pretrained(
109
+ pretrained_model_name_or_path=args.config_file,
110
+ )
111
+
112
+ serialization_dir = Path(args.serialization_dir)
113
+ serialization_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ logger = logging_config(serialization_dir)
116
+
117
+ random.seed(config.seed)
118
+ np.random.seed(config.seed)
119
+ torch.manual_seed(config.seed)
120
+ logger.info(f"set seed: {config.seed}")
121
+
122
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+ n_gpu = torch.cuda.device_count()
124
+ logger.info(f"GPU available count: {n_gpu}; device: {device}")
125
+
126
+ # datasets
127
+ train_dataset = DenoiseExcelDataset(
128
+ excel_file=args.train_dataset,
129
+ expected_sample_rate=8000,
130
+ max_wave_value=32768.0,
131
+ )
132
+ valid_dataset = DenoiseExcelDataset(
133
+ excel_file=args.valid_dataset,
134
+ expected_sample_rate=8000,
135
+ max_wave_value=32768.0,
136
+ )
137
+ train_data_loader = DataLoader(
138
+ dataset=train_dataset,
139
+ batch_size=config.batch_size,
140
+ shuffle=True,
141
+ sampler=None,
142
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
143
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
144
+ collate_fn=collate_fn,
145
+ pin_memory=False,
146
+ # prefetch_factor=64,
147
+ )
148
+ valid_data_loader = DataLoader(
149
+ dataset=valid_dataset,
150
+ batch_size=config.batch_size,
151
+ shuffle=True,
152
+ sampler=None,
153
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
154
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
155
+ collate_fn=collate_fn,
156
+ pin_memory=False,
157
+ # prefetch_factor=64,
158
+ )
159
+
160
+ # models
161
+ logger.info(f"prepare models. config_file: {args.config_file}")
162
+ generator = NXCleanUNetPretrainedModel(config).to(device)
163
+ discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
164
+
165
+ # optimizer
166
+ logger.info("prepare optimizer, lr_scheduler")
167
+ num_params = 0
168
+ for p in generator.parameters():
169
+ num_params += p.numel()
170
+ logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
171
+
172
+ optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
173
+ optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
174
+
175
+ # resume training
176
+ last_epoch = -1
177
+ for epoch_i in serialization_dir.glob("epoch-*"):
178
+ epoch_i = Path(epoch_i)
179
+ epoch_idx = epoch_i.stem.split("-")[1]
180
+ epoch_idx = int(epoch_idx)
181
+ if epoch_idx > last_epoch:
182
+ last_epoch = epoch_idx
183
+
184
+ if last_epoch != -1:
185
+ logger.info(f"resume from epoch-{last_epoch}.")
186
+ generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
187
+ discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
188
+ optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
189
+ optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
190
+
191
+ logger.info(f"load state dict for generator.")
192
+ with open(generator_pt.as_posix(), "rb") as f:
193
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
194
+ generator.load_state_dict(state_dict, strict=True)
195
+ logger.info(f"load state dict for discriminator.")
196
+ with open(discriminator_pt.as_posix(), "rb") as f:
197
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
198
+ discriminator.load_state_dict(state_dict, strict=True)
199
+
200
+ logger.info(f"load state dict for optim_g.")
201
+ with open(optim_g_pth.as_posix(), "rb") as f:
202
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
203
+ optim_g.load_state_dict(state_dict)
204
+ logger.info(f"load state dict for optim_d.")
205
+ with open(optim_d_pth.as_posix(), "rb") as f:
206
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
207
+ optim_d.load_state_dict(state_dict)
208
+
209
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
210
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
211
+
212
+ # training loop
213
+
214
+ # state
215
+ loss_d = 10000000000
216
+ loss_g = 10000000000
217
+ pesq_metric = 10000000000
218
+ mag_err = 10000000000
219
+ pha_err = 10000000000
220
+ com_err = 10000000000
221
+
222
+ model_list = list()
223
+ best_idx_epoch = None
224
+ best_metric = None
225
+ patience_count = 0
226
+
227
+ logger.info("training")
228
+ for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
229
+ # train
230
+ generator.train()
231
+ discriminator.train()
232
+
233
+ total_loss_d = 0.
234
+ total_loss_g = 0.
235
+ total_batches = 0.
236
+ progress_bar = tqdm(
237
+ total=len(train_data_loader),
238
+ desc="Training; epoch: {}".format(idx_epoch),
239
+ )
240
+ for batch in train_data_loader:
241
+ clean_audios, noisy_audios = batch
242
+ clean_audios = clean_audios.to(device)
243
+ noisy_audios = noisy_audios.to(device)
244
+ one_labels = torch.ones(clean_audios.shape[0]).to(device)
245
+
246
+ audio_g = generator.forward(noisy_audios)
247
+
248
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
249
+ mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
250
+
251
+ clean_audio_list = torch.split(clean_audios, 1, dim=0)
252
+ enhanced_audio_list = torch.split(audio_g, 1, dim=0)
253
+ clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
254
+ enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
255
+
256
+ pesq_score_list: List[float] = run_batch_pesq(clean_audio_list, enhanced_audio_list, sample_rate=config.sample_rate, mode="nb")
257
+
258
+ # Discriminator
259
+ optim_d.zero_grad()
260
+ metric_r = discriminator.forward(clean_audios, clean_audios)
261
+ metric_g = discriminator.forward(clean_audios, audio_g.detach())
262
+ loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
263
+
264
+ if -1 in pesq_score_list:
265
+ # print("-1 in batch_pesq_score!")
266
+ loss_disc_g = 0
267
+ else:
268
+ pesq_score_list: torch.FloatTensor = torch.tensor([(score - 1) / 3.5 for score in pesq_score_list], dtype=torch.float32)
269
+ loss_disc_g = F.mse_loss(pesq_score_list.to(device), metric_g.flatten())
270
+
271
+ loss_disc_all = loss_disc_r + loss_disc_g
272
+ loss_disc_all.backward()
273
+ optim_d.step()
274
+
275
+ # Generator
276
+ optim_g.zero_grad()
277
+ # L2 Magnitude Loss
278
+ loss_mag = F.mse_loss(clean_mag, mag_g)
279
+ # Anti-wrapping Phase Loss
280
+ loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
281
+ loss_pha = loss_ip + loss_gd + loss_iaf
282
+ # L2 Complex Loss
283
+ loss_com = F.mse_loss(clean_com, com_g) * 2
284
+ # L2 Consistency Loss
285
+ # Time Loss
286
+ loss_time = F.l1_loss(clean_audios, audio_g)
287
+ # Metric Loss
288
+ metric_g = discriminator.forward(clean_mag, mag_g)
289
+ loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
290
+
291
+ loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_metric * 0.05 + loss_time * 0.2
292
+
293
+ loss_gen_all.backward()
294
+ optim_g.step()
295
+
296
+ total_loss_d += loss_disc_all.item()
297
+ total_loss_g += loss_gen_all.item()
298
+ total_batches += 1
299
+
300
+ loss_d = round(total_loss_d / total_batches, 4)
301
+ loss_g = round(total_loss_g / total_batches, 4)
302
+
303
+ progress_bar.update(1)
304
+ progress_bar.set_postfix({
305
+ "loss_d": loss_d,
306
+ "loss_g": loss_g,
307
+ })
308
+
309
+ # evaluation
310
+ generator.eval()
311
+ discriminator.eval()
312
+
313
+ torch.cuda.empty_cache()
314
+ total_pesq_score = 0.
315
+ total_mag_err = 0.
316
+ total_pha_err = 0.
317
+ total_com_err = 0.
318
+ total_batches = 0.
319
+
320
+ progress_bar = tqdm(
321
+ total=len(valid_data_loader),
322
+ desc="Evaluation; epoch: {}".format(idx_epoch),
323
+ )
324
+ with torch.no_grad():
325
+ for batch in valid_data_loader:
326
+ clean_audios, noisy_audios = batch
327
+ clean_audios = clean_audios.to(device)
328
+ noisy_audios = noisy_audios.to(device)
329
+
330
+ audio_g = generator.forward(noisy_audios)
331
+
332
+ clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audios, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
333
+ mag_g, pha_g, com_g = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
334
+
335
+ clean_audio_list = torch.split(clean_audios, 1, dim=0)
336
+ enhanced_audio_list = torch.split(audio_g, 1, dim=0)
337
+ clean_audio_list = [t.squeeze().cpu().numpy() for t in clean_audio_list]
338
+ enhanced_audio_list = [t.squeeze().cpu().numpy() for t in enhanced_audio_list]
339
+ pesq_score = run_pesq_score(
340
+ clean_audio_list,
341
+ enhanced_audio_list,
342
+ sample_rate = config.sample_rate,
343
+ mode = "nb",
344
+ )
345
+ total_pesq_score += pesq_score
346
+ total_mag_err += F.mse_loss(clean_mag, mag_g).item()
347
+ val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
348
+ total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
349
+ total_com_err += F.mse_loss(clean_com, com_g).item()
350
+
351
+ total_batches += 1
352
+
353
+ pesq_metric = round(total_pesq_score / total_batches, 4)
354
+ mag_err = round(total_mag_err / total_batches, 4)
355
+ pha_err = round(total_pha_err / total_batches, 4)
356
+ com_err = round(total_com_err / total_batches, 4)
357
+
358
+ progress_bar.update(1)
359
+ progress_bar.set_postfix({
360
+ "pesq_metric": pesq_metric,
361
+ "mag_err": mag_err,
362
+ "pha_err": pha_err,
363
+ "com_err": com_err,
364
+ })
365
+
366
+ # scheduler
367
+ scheduler_g.step()
368
+ scheduler_d.step()
369
+
370
+ # save path
371
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
372
+ epoch_dir.mkdir(parents=True, exist_ok=False)
373
+
374
+ # save models
375
+ generator.save_pretrained(epoch_dir.as_posix())
376
+ discriminator.save_pretrained(epoch_dir.as_posix())
377
+
378
+ # save optim
379
+ torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
380
+ torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
381
+
382
+ model_list.append(epoch_dir)
383
+ if len(model_list) >= args.num_serialized_models_to_keep:
384
+ model_to_delete: Path = model_list.pop(0)
385
+ shutil.rmtree(model_to_delete.as_posix())
386
+
387
+ # save metric
388
+ if best_metric is None:
389
+ best_idx_epoch = idx_epoch
390
+ best_metric = pesq_metric
391
+ elif pesq_metric > best_metric:
392
+ # greater is better.
393
+ best_idx_epoch = idx_epoch
394
+ best_metric = pesq_metric
395
+ else:
396
+ pass
397
+
398
+ metrics = {
399
+ "idx_epoch": idx_epoch,
400
+ "best_idx_epoch": best_idx_epoch,
401
+ "loss_d": loss_d,
402
+ "loss_g": loss_g,
403
+
404
+ "pesq_metric": pesq_metric,
405
+ "mag_err": mag_err,
406
+ "pha_err": pha_err,
407
+ "com_err": com_err,
408
+
409
+ }
410
+ metrics_filename = epoch_dir / "metrics_epoch.json"
411
+ with open(metrics_filename, "w", encoding="utf-8") as f:
412
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
413
+
414
+ # save best
415
+ best_dir = serialization_dir / "best"
416
+ if best_idx_epoch == idx_epoch:
417
+ if best_dir.exists():
418
+ shutil.rmtree(best_dir)
419
+ shutil.copytree(epoch_dir, best_dir)
420
+
421
+ # early stop
422
+ early_stop_flag = False
423
+ if best_idx_epoch == idx_epoch:
424
+ patience_count = 0
425
+ else:
426
+ patience_count += 1
427
+ if patience_count >= args.patience:
428
+ early_stop_flag = True
429
+
430
+ # early stop
431
+ if early_stop_flag:
432
+ break
433
+
434
+ return
435
+
436
+
437
+ if __name__ == "__main__":
438
+ main()
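For context on the PESQ-based discriminator target used above: narrow-band PESQ scores lie roughly between 1.0 and 4.5, so the (score - 1) / 3.5 normalization maps them into approximately [0, 1], the same range as the one_labels target used for clean/clean pairs (the exact score bounds are an assumption about the PESQ implementation, not something stated in this file):

    # quick check of the normalization used in the training loop
    for score in (1.0, 2.5, 4.5):
        print(score, (score - 1) / 3.5)   # 1.0 -> 0.0, 2.5 -> ~0.43, 4.5 -> 1.0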
examples/nx_clean_unet/step_3_evaluation.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
examples/nx_clean_unet/yaml/config.yaml ADDED
@@ -0,0 +1,23 @@
1
+ model_name: "nx_clean_unet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 16000
5
+ n_fft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+
9
+ down_sampling_num_layers: 5
10
+ down_sampling_in_channels: 1
11
+ down_sampling_hidden_channels: 64
12
+ down_sampling_kernel_size: 4
13
+ down_sampling_stride: 2
14
+
15
+ tsfm_hidden_size: 256
16
+ tsfm_attention_heads: 4
17
+ tsfm_num_blocks: 6
18
+ tsfm_dropout_rate: 0.1
19
+
20
+ discriminator_dim: 32
21
+ discriminator_in_channel: 2
22
+
23
+ compress_factor: 0.3
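For reference, with sample_rate 8000 these STFT settings correspond to a 25 ms analysis window and a 10 ms hop; a quick sanity check of the arithmetic:

    sample_rate = 8000
    win_size, hop_size, n_fft = 200, 80, 512
    print(win_size / sample_rate * 1000)   # 25.0 ms analysis window
    print(hop_size / sample_rate * 1000)   # 10.0 ms hop
    print(n_fft // 2 + 1)                  # 257 frequency bins per frame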
toolbox/torchaudio/models/clean_unet/configuration_clean_unet.py CHANGED
@@ -3,7 +3,7 @@
3
  from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
 
5
 
6
- class CleanUnetConfig(PretrainedConfig):
7
  def __init__(self,
8
  channels_input: int = 1,
9
  channels_output: int = 1,
@@ -21,7 +21,7 @@ class CleanUnetConfig(PretrainedConfig):
21
 
22
  **kwargs
23
  ):
24
- super(CleanUnetConfig, self).__init__(**kwargs)
25
  self.channels_input = channels_input
26
  self.channels_output = channels_output
27
 
 
3
  from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
 
5
 
6
+ class CleanUNetConfig(PretrainedConfig):
7
  def __init__(self,
8
  channels_input: int = 1,
9
  channels_output: int = 1,
 
21
 
22
  **kwargs
23
  ):
24
+ super(CleanUNetConfig, self).__init__(**kwargs)
25
  self.channels_input = channels_input
26
  self.channels_output = channels_output
27
 
toolbox/torchaudio/models/clean_unet/inference_clean_unet.py ADDED
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile
7
+ import zipfile
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ from project_settings import project_path
15
+ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig
16
+ from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel, MODEL_FILE
17
+
18
+ logger = logging.getLogger("toolbox")
19
+
20
+
21
+ class InferenceCleanUNet(object):
22
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
23
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
24
+ self.device = torch.device(device)
25
+
26
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
27
+ config, model = self.load_models(self.pretrained_model_path_or_zip_file)
28
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
29
+
30
+ self.config = config
31
+ self.model = model
32
+ self.model.to(device)
33
+ self.model.eval()
34
+
35
+ def load_models(self, model_path: str):
36
+ model_path = Path(model_path)
37
+ if model_path.name.endswith(".zip"):
38
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
39
+ out_root = Path(tempfile.gettempdir()) / "nx_denoise"
40
+ out_root.mkdir(parents=True, exist_ok=True)
41
+ f_zip.extractall(path=out_root)
42
+ model_path = out_root / model_path.stem
43
+
44
+ config = CleanUNetConfig.from_pretrained(
45
+ pretrained_model_name_or_path=model_path.as_posix(),
46
+ )
47
+ model = CleanUNetPretrainedModel.from_pretrained(
48
+ pretrained_model_name_or_path=model_path.as_posix(),
49
+ )
50
+ model.to(self.device)
51
+ model.eval()
52
+
53
+ shutil.rmtree(model_path)
54
+ return config, model
55
+
56
+ def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
57
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
58
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
59
+
60
+ # noisy_audio shape: [batch_size, n_samples]
61
+ enhanced_audio = self.enhancement_by_tensor(noisy_audio)
62
+ # enhanced_audio shape: [channels, n_samples]
63
+ return enhanced_audio.cpu().numpy()
64
+
65
+ def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
66
+ if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
67
+ raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
68
+
69
+ # noisy_audio shape: [batch_size, num_samples]
70
+ noisy_audios = noisy_audio.to(self.device)
71
+
72
+ with torch.no_grad():
73
+ enhanced_audios = self.model.forward(noisy_audios)
74
+ # enhanced_audio shape: [batch_size, channels, num_samples]
75
+ # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
76
+
77
+ enhanced_audio = enhanced_audios[0]
78
+
79
+ # enhanced_audio shape: [channels, num_samples]
80
+ return enhanced_audio
81
+
82
+ def main():
83
+ model_zip_file = project_path / "trained_models/clean-unet-aishell-18-epoch.zip"
84
+ infer_mpnet = InferenceCleanUNet(model_zip_file)
85
+
86
+ sample_rate = 8000
87
+ noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
88
+ noisy_audio, _ = librosa.load(
89
+ noisy_audio_file.as_posix(),
90
+ sr=sample_rate,
91
+ )
92
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
93
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
94
+
95
+ enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio)
96
+
97
+ filename = "enhanced_audio.wav"
98
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
99
+
100
+ return
101
+
102
+
103
+ if __name__ == '__main__':
104
+ main()
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -22,7 +22,7 @@ import torch.nn.functional as F
22
 
23
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
24
  from toolbox.torchaudio.models.clean_unet.transformer import TransformerEncoder
25
- from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
26
 
27
 
28
  def weight_scaling_init(layer):
@@ -196,7 +196,7 @@ class CleanUNet(nn.Module):
196
 
197
  x = self.tsfm_conv1(x) # C 1024 -> 512
198
  x = x.permute(0, 2, 1)
199
- x = self.tsfm_encoder(x, src_mask=attn_mask)
200
  x = x.permute(0, 2, 1)
201
  x = self.tsfm_conv2(x) # C 512 -> 1024
202
 
@@ -215,7 +215,7 @@ MODEL_FILE = "model.pt"
215
 
216
  class CleanUNetPretrainedModel(CleanUNet):
217
  def __init__(self,
218
- config: CleanUnetConfig,
219
  ):
220
  super(CleanUNetPretrainedModel, self).__init__(
221
  channels_input=config.channels_input,
@@ -234,7 +234,7 @@ class CleanUNetPretrainedModel(CleanUNet):
234
 
235
  @classmethod
236
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
237
- config = CleanUnetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
238
 
239
  model = cls(config)
240
 
@@ -272,7 +272,7 @@ class CleanUNetPretrainedModel(CleanUNet):
272
 
273
  def main():
274
 
275
- config = CleanUnetConfig()
276
  model = CleanUNetPretrainedModel(config)
277
 
278
  print_size(model, keyword="tsfm")
 
22
 
23
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
24
  from toolbox.torchaudio.models.clean_unet.transformer import TransformerEncoder
25
+ from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUNetConfig
26
 
27
 
28
  def weight_scaling_init(layer):
 
196
 
197
  x = self.tsfm_conv1(x) # C 1024 -> 512
198
  x = x.permute(0, 2, 1)
199
+ x = self.tsfm_encoder.forward(x, src_mask=attn_mask)
200
  x = x.permute(0, 2, 1)
201
  x = self.tsfm_conv2(x) # C 512 -> 1024
202
 
 
215
 
216
  class CleanUNetPretrainedModel(CleanUNet):
217
  def __init__(self,
218
+ config: CleanUNetConfig,
219
  ):
220
  super(CleanUNetPretrainedModel, self).__init__(
221
  channels_input=config.channels_input,
 
234
 
235
  @classmethod
236
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
237
+ config = CleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
238
 
239
  model = cls(config)
240
 
 
272
 
273
  def main():
274
 
275
+ config = CleanUNetConfig()
276
  model = CleanUNetPretrainedModel(config)
277
 
278
  print_size(model, keyword="tsfm")
toolbox/torchaudio/models/nx_clean_unet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py ADDED
@@ -0,0 +1,56 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class NXCleanUNetConfig(PretrainedConfig):
7
+ """
8
+ https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
9
+ """
10
+ def __init__(self,
11
+ n_fft: int = 512,
12
+ win_length: int = 200,
13
+ hop_length: int = 80,
14
+
15
+ down_sampling_num_layers: int = 5,
16
+ down_sampling_in_channels: int = 1,
17
+ down_sampling_hidden_channels: int = 64,
18
+ down_sampling_kernel_size: int = 4,
19
+ down_sampling_stride: int = 2,
20
+
21
+ tsfm_hidden_size: int = 256,
22
+ tsfm_attention_heads: int = 4,
23
+ tsfm_num_blocks: int = 6,
24
+ tsfm_dropout_rate: float = 0.1,
25
+
26
+ discriminator_dim: int = 32,
27
+ discriminator_in_channel: int = 2,
28
+
29
+ compress_factor: float = 0.3,
30
+
31
+ **kwargs
32
+ ):
33
+ super(NXCleanUNetConfig, self).__init__(**kwargs)
34
+ self.n_fft = n_fft
35
+ self.win_length = win_length
36
+ self.hop_length = hop_length
37
+
38
+ self.down_sampling_num_layers = down_sampling_num_layers
39
+ self.down_sampling_in_channels = down_sampling_in_channels
40
+ self.down_sampling_hidden_channels = down_sampling_hidden_channels
41
+ self.down_sampling_kernel_size = down_sampling_kernel_size
42
+ self.down_sampling_stride = down_sampling_stride
43
+
44
+ self.tsfm_hidden_size = tsfm_hidden_size
45
+ self.tsfm_attention_heads = tsfm_attention_heads
46
+ self.tsfm_num_blocks = tsfm_num_blocks
47
+ self.tsfm_dropout_rate = tsfm_dropout_rate
48
+
49
+ self.discriminator_dim = discriminator_dim
50
+ self.discriminator_in_channel = discriminator_in_channel
51
+
52
+ self.compress_factor = compress_factor
53
+
54
+
55
+ if __name__ == '__main__':
56
+ pass
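With down_sampling_num_layers = 5 and down_sampling_stride = 2, the encoder defined in modeling_nx_clean_unet.py below reduces the time axis by roughly a factor of 2**5 = 32 before the transformer bottleneck. A small sketch of the resulting bottleneck length for a 2-second clip at 8 kHz, assuming the (length - kernel_size) // stride + 1 relation noted in that file:

    length = 2 * 8000                    # 16000 samples
    kernel_size, stride = 4, 2
    for _ in range(5):
        length = (length - kernel_size) // stride + 1
    print(length)                        # 498 frames, roughly 16000 / 32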
toolbox/torchaudio/models/nx_clean_unet/discriminator.py ADDED
@@ -0,0 +1,126 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchaudio
9
+
10
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
+ from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
12
+ from toolbox.torchaudio.models.nx_clean_unet.utils import LearnableSigmoid1d
13
+
14
+
15
+ class MetricDiscriminator(nn.Module):
16
+ def __init__(self, config: NXCleanUNetConfig):
17
+ super(MetricDiscriminator, self).__init__()
18
+ dim = config.discriminator_dim
19
+ self.in_channel = config.discriminator_in_channel
20
+
21
+ self.n_fft = config.n_fft
22
+ self.win_length = config.win_length
23
+ self.hop_length = config.hop_length
24
+
25
+ self.layers = nn.Sequential(
26
+ torchaudio.transforms.Spectrogram(
27
+ n_fft=self.n_fft,
28
+ win_length=self.win_length,
29
+ hop_length=self.hop_length,
30
+ power=1.0,
31
+ window_fn=torch.hamming_window,
32
+ # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
33
+ ),
34
+ nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
35
+ nn.InstanceNorm2d(dim, affine=True),
36
+ nn.PReLU(dim),
37
+ nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
38
+ nn.InstanceNorm2d(dim*2, affine=True),
39
+ nn.PReLU(dim*2),
40
+ nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
41
+ nn.InstanceNorm2d(dim*4, affine=True),
42
+ nn.PReLU(dim*4),
43
+ nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
44
+ nn.InstanceNorm2d(dim*8, affine=True),
45
+ nn.PReLU(dim*8),
46
+ nn.AdaptiveMaxPool2d(1),
47
+ nn.Flatten(),
48
+ nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
49
+ nn.Dropout(0.3),
50
+ nn.PReLU(dim*4),
51
+ nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
52
+ LearnableSigmoid1d(1)
53
+ )
54
+
55
+ def forward(self, x, y):
56
+ xy = torch.stack((x, y), dim=1)
57
+ return self.layers(xy)
58
+
59
+
60
+ MODEL_FILE = "discriminator.pt"
61
+
62
+
63
+ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
64
+ def __init__(self,
65
+ config: NXCleanUNetConfig,
66
+ ):
67
+ super(MetricDiscriminatorPretrainedModel, self).__init__(
68
+ config=config,
69
+ )
70
+ self.config = config
71
+
72
+ @classmethod
73
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
74
+ config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
75
+
76
+ model = cls(config)
77
+
78
+ if os.path.isdir(pretrained_model_name_or_path):
79
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
80
+ else:
81
+ ckpt_file = pretrained_model_name_or_path
82
+
83
+ with open(ckpt_file, "rb") as f:
84
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
85
+ model.load_state_dict(state_dict, strict=True)
86
+ return model
87
+
88
+ def save_pretrained(self,
89
+ save_directory: Union[str, os.PathLike],
90
+ state_dict: Optional[dict] = None,
91
+ ):
92
+
93
+ model = self
94
+
95
+ if state_dict is None:
96
+ state_dict = model.state_dict()
97
+
98
+ os.makedirs(save_directory, exist_ok=True)
99
+
100
+ # save state dict
101
+ model_file = os.path.join(save_directory, MODEL_FILE)
102
+ torch.save(state_dict, model_file)
103
+
104
+ # save config
105
+ config_file = os.path.join(save_directory, CONFIG_FILE)
106
+ self.config.to_yaml_file(config_file)
107
+ return save_directory
108
+
109
+
110
+ def main():
111
+ config = NXCleanUNetConfig()
112
+ discriminator = MetricDiscriminator(config=config)
113
+
114
+ # shape: [batch_size, num_samples]
115
+ x = torch.ones([4, int(4.5 * 16000)])
116
+ y = torch.ones([4, int(4.5 * 16000)])
117
+
118
+ output = discriminator.forward(x, y)
119
+ print(output.shape)
120
+ print(output)
121
+
122
+ return
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()
toolbox/torchaudio/models/nx_clean_unet/loss.py ADDED
@@ -0,0 +1,22 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def anti_wrapping_function(x):
8
+
9
+ return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
10
+
11
+
12
+ def phase_losses(phase_r, phase_g):
13
+
14
+ ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
15
+ gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
16
+ iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
17
+
18
+ return ip_loss, gd_loss, iaf_loss
19
+
20
+
21
+ if __name__ == '__main__':
22
+ pass
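A hedged usage sketch for phase_losses: the dim=1 and dim=2 differences above imply phase tensors shaped [batch, freq_bins, time_frames], as produced by an STFT front end such as mag_pha_stft in this package (the shapes below are an assumption for illustration):

    import math
    import torch

    phase_r = (torch.rand(2, 257, 201) - 0.5) * 2 * math.pi   # reference phase
    phase_g = (torch.rand(2, 257, 201) - 0.5) * 2 * math.pi   # generated phase

    ip_loss, gd_loss, iaf_loss = phase_losses(phase_r, phase_g)
    loss_pha = ip_loss + gd_loss + iaf_loss                    # combined as in the training script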
toolbox/torchaudio/models/nx_clean_unet/metrics.py ADDED
@@ -0,0 +1,80 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from joblib import Parallel, delayed
4
+ import numpy as np
5
+ from pesq import pesq
6
+ from typing import List
7
+
8
+ from pesq import cypesq
9
+
10
+
11
+ def run_pesq(clean_audio: np.ndarray,
12
+ noisy_audio: np.ndarray,
13
+ sample_rate: int = 16000,
14
+ mode: str = "wb",
15
+ ) -> float:
16
+ if sample_rate == 8000 and mode == "wb":
17
+ raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
+ try:
19
+ pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
+ except cypesq.NoUtterancesError as e:
21
+ pesq_score = -1
22
+ except Exception as e:
23
+ print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
+ pesq_score = -1
25
+ return pesq_score
26
+
27
+
28
+ def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
+ noisy_audio_list: List[np.ndarray],
30
+ sample_rate: int = 16000,
31
+ mode: str = "wb",
32
+ n_jobs: int = 4,
33
+ ) -> List[float]:
34
+ parallel = Parallel(n_jobs=n_jobs)
35
+
36
+ parallel_tasks = list()
37
+ for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
+ parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
+ parallel_tasks.append(parallel_task)
40
+
41
+ pesq_score_list = parallel.__call__(parallel_tasks)
42
+ return pesq_score_list
43
+
44
+
45
+ def run_pesq_score(clean_audio_list: List[np.ndarray],
46
+ noisy_audio_list: List[np.ndarray],
47
+ sample_rate: int = 16000,
48
+ mode: str = "wb",
49
+ n_jobs: int = 4,
50
+ ) -> List[float]:
51
+
52
+ pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
+ noisy_audio_list=noisy_audio_list,
54
+ sample_rate=sample_rate,
55
+ mode=mode,
56
+ n_jobs=n_jobs,
57
+ )
58
+
59
+ pesq_score = np.mean(pesq_score_list)
60
+ return pesq_score
61
+
62
+
63
+ def main():
64
+ clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
+ noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
+
67
+ clean_audio_list = list(clean_audio)
68
+ noisy_audio_list = list(noisy_audio)
69
+
70
+ pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
+ print(pesq_score_list)
72
+
73
+ pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
+ print(pesq_score)
75
+
76
+ return
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py ADDED
@@ -0,0 +1,326 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+
11
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
12
+ from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
13
+ from toolbox.torchaudio.models.nx_clean_unet.transformer.transformer import TransformerEncoder
14
+
15
+
16
+ class DownSamplingBlock(nn.Module):
17
+ def __init__(self,
18
+ in_channels: int,
19
+ hidden_channels: int,
20
+ kernel_size: int,
21
+ stride: int,
22
+ ):
23
+ super(DownSamplingBlock, self).__init__()
24
+ self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride)
25
+ self.relu = nn.ReLU()
26
+ self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
27
+ self.glu = nn.GLU(dim=1)
28
+
29
+ def forward(self, x: torch.Tensor):
30
+ # x shape: [batch_size, 1, num_samples]
31
+ x = self.conv1.forward(x)
32
+ # x shape: [batch_size, hidden_channels, new_num_samples]
33
+ x = self.relu(x)
34
+ x = self.conv2.forward(x)
35
+ # x shape: [batch_size, hidden_channels*2, new_num_samples]
36
+ x = self.glu(x)
37
+ # x shape: [batch_size, hidden_channels, new_num_samples]
38
+ # new_num_samples = (num_samples-kernel_size) // stride + 1
39
+ return x
40
+
41
+
42
+ class DownSampling(nn.Module):
43
+ def __init__(self,
44
+ num_layers: int,
45
+ in_channels: int,
46
+ hidden_channels: int,
47
+ kernel_size: int,
48
+ stride: int,
49
+ ):
50
+ super(DownSampling, self).__init__()
51
+ self.num_layers = num_layers
52
+
53
+ self.down_sampling_block_list = nn.ModuleList()  # ModuleList so the blocks' parameters are registered
54
+
55
+ for idx in range(self.num_layers):
56
+ down_sampling_block = DownSamplingBlock(
57
+ in_channels=in_channels,
58
+ hidden_channels=hidden_channels,
59
+ kernel_size=kernel_size,
60
+ stride=stride,
61
+ )
62
+ self.down_sampling_block_list.append(down_sampling_block)
63
+ in_channels = hidden_channels
64
+
65
+ def forward(self, x: torch.Tensor):
66
+ # x shape: [batch_size, channels, num_samples]
67
+ for down_sampling_block in self.down_sampling_block_list:
68
+ x = down_sampling_block.forward(x)
69
+ # x shape: [batch_size, hidden_channels, num_samples**]
70
+ return x
71
+
72
+
73
+ class UpSamplingBlock(nn.Module):
74
+ def __init__(self,
75
+ out_channels: int,
76
+ hidden_channels: int,
77
+ kernel_size: int,
78
+ stride: int,
79
+ do_relu: bool = True,
80
+ ):
81
+ super(UpSamplingBlock, self).__init__()
82
+ self.do_relu = do_relu
83
+
84
+ self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
85
+ self.glu = nn.GLU(dim=1)
86
+ self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride)
87
+ self.relu = nn.ReLU()
88
+
89
+ def forward(self, x: torch.Tensor):
90
+ # x shape: [batch_size, hidden_channels*2, num_samples]
91
+ x = self.conv1.forward(x)
92
+ # x shape: [batch_size, hidden_channels, num_samples]
93
+ x = self.glu(x)
94
+ # x shape: [batch_size, hidden_channels, num_samples]
95
+ x = self.convt.forward(x)
96
+ # x shape: [batch_size, hidden_channels, new_num_samples]
97
+ # new_num_samples = (num_samples - 1) * stride + kernel_size
98
+ if self.do_relu:
99
+ x = self.relu(x)
100
+ return x
101
+
102
+
103
+ class UpSampling(nn.Module):
104
+ def __init__(self,
105
+ num_layers: int,
106
+ out_channels: int,
107
+ hidden_channels: int,
108
+ kernel_size: int,
109
+ stride: int,
110
+ ):
111
+ super(UpSampling, self).__init__()
112
+ self.num_layers = num_layers
113
+
114
+ self.up_sampling_block_list = nn.ModuleList()  # ModuleList so the blocks' parameters are registered
115
+
116
+ for idx in range(self.num_layers-1):
117
+ up_sampling_block = UpSamplingBlock(
118
+ out_channels=hidden_channels,
119
+ hidden_channels=hidden_channels,
120
+ kernel_size=kernel_size,
121
+ stride=stride,
122
+ do_relu=True,
123
+ )
124
+ self.up_sampling_block_list.append(up_sampling_block)
125
+ else:
126
+ up_sampling_block = UpSamplingBlock(
127
+ out_channels=out_channels,
128
+ hidden_channels=hidden_channels,
129
+ kernel_size=kernel_size,
130
+ stride=stride,
131
+ do_relu=False,
132
+ )
133
+ self.up_sampling_block_list.append(up_sampling_block)
134
+
135
+ def forward(self, x: torch.Tensor):
136
+ # x shape: [batch_size, channels, num_samples]
137
+ for up_sampling_block in self.up_sampling_block_list:
138
+ x = up_sampling_block.forward(x)
139
+ return x
140
+
141
+
142
+ def get_padding_length(length, num_layers: int, kernel_size: int, stride: int):
143
+ for _ in range(num_layers):
144
+ if length < kernel_size:
145
+ length = 1
146
+ else:
147
+ length = 1 + np.ceil((length - kernel_size) / stride)
148
+
149
+ for _ in range(num_layers):
150
+ length = (length - 1) * stride + kernel_size
151
+
152
+ padded_length = int(length)
153
+ return padded_length
154
+
155
+
156
+ class NXCleanUNet(nn.Module):
157
+ def __init__(self, config):
158
+ super().__init__()
159
+ self.config = config
160
+
161
+ self.down_sampling = DownSampling(
162
+ num_layers=config.down_sampling_num_layers,
163
+ in_channels=config.down_sampling_in_channels,
164
+ hidden_channels=config.down_sampling_hidden_channels,
165
+ kernel_size=config.down_sampling_kernel_size,
166
+ stride=config.down_sampling_stride,
167
+ )
168
+ self.transformer = TransformerEncoder(
169
+ input_size=config.down_sampling_hidden_channels,
170
+ hidden_size=config.tsfm_hidden_size,
171
+ attention_heads=config.tsfm_attention_heads,
172
+ num_blocks=config.tsfm_num_blocks,
173
+ dropout_rate=config.tsfm_dropout_rate,
174
+ )
175
+ self.up_sampling = UpSampling(
176
+ num_layers=config.down_sampling_num_layers,
177
+ out_channels=config.down_sampling_in_channels,
178
+ hidden_channels=config.down_sampling_hidden_channels,
179
+ kernel_size=config.down_sampling_kernel_size,
180
+ stride=config.down_sampling_stride,
181
+ )
182
+
183
+ def forward(self, noisy_audios: torch.Tensor):
184
+ # noisy_audios shape: [batch_size, 1, n_samples]
185
+
186
+ n_samples = noisy_audios.shape[-1]
187
+ padded_length = get_padding_length(
188
+ n_samples,
189
+ num_layers=self.config.down_sampling_num_layers,
190
+ kernel_size=self.config.down_sampling_kernel_size,
191
+ stride=self.config.down_sampling_stride,
192
+ )
193
+ noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
194
+
195
+ bottle_neck = self.down_sampling.forward(noisy_audios_padded)
196
+ # bottle_neck shape: [batch_size, channels, time_steps]
197
+
198
+ bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
199
+ # bottle_neck shape: [batch_size, time_steps, input_size]
200
+
201
+ bottle_neck = self.transformer.forward(bottle_neck)
202
+ # bottle_neck shape: [batch_size, time_steps, input_size]
203
+
204
+ bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
205
+ # bottle_neck shape: [batch_size, channels, time_steps]
206
+
207
+ enhanced_audios = self.up_sampling.forward(bottle_neck)
208
+
209
+ enhanced_audios = enhanced_audios[:, :, :n_samples]
210
+ # enhanced_audios shape: [batch_size, 1, n_samples]
211
+ return enhanced_audios
212
+
213
+
214
+ MODEL_FILE = "generator.pt"
215
+
216
+
217
+ class NXCleanUNetPretrainedModel(NXCleanUNet):
218
+ def __init__(self,
219
+ config: NXCleanUNetConfig,
220
+ ):
221
+ super(NXCleanUNetPretrainedModel, self).__init__(
222
+ config=config,
223
+ )
224
+ self.config = config
225
+
226
+ @classmethod
227
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
228
+ config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
229
+
230
+ model = cls(config)
231
+
232
+ if os.path.isdir(pretrained_model_name_or_path):
233
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
234
+ else:
235
+ ckpt_file = pretrained_model_name_or_path
236
+
237
+ with open(ckpt_file, "rb") as f:
238
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
239
+ model.load_state_dict(state_dict, strict=True)
240
+ return model
241
+
242
+ def save_pretrained(self,
243
+ save_directory: Union[str, os.PathLike],
244
+ state_dict: Optional[dict] = None,
245
+ ):
246
+
247
+ model = self
248
+
249
+ if state_dict is None:
250
+ state_dict = model.state_dict()
251
+
252
+ os.makedirs(save_directory, exist_ok=True)
253
+
254
+ # save state dict
255
+ model_file = os.path.join(save_directory, MODEL_FILE)
256
+ torch.save(state_dict, model_file)
257
+
258
+ # save config
259
+ config_file = os.path.join(save_directory, CONFIG_FILE)
260
+ self.config.to_yaml_file(config_file)
261
+ return save_directory
262
+
263
+
264
+
265
+ def main2():
266
+
267
+ config = NXCleanUNetConfig()
268
+ down_sampling = DownSampling(
269
+ num_layers=config.down_sampling_num_layers,
270
+ in_channels=config.down_sampling_in_channels,
271
+ hidden_channels=config.down_sampling_hidden_channels,
272
+ kernel_size=config.down_sampling_kernel_size,
273
+ stride=config.down_sampling_stride,
274
+ )
275
+ up_sampling = UpSampling(
276
+ num_layers=config.down_sampling_num_layers,
277
+ out_channels=config.down_sampling_in_channels,
278
+ hidden_channels=config.down_sampling_hidden_channels,
279
+ kernel_size=config.down_sampling_kernel_size,
280
+ stride=config.down_sampling_stride,
281
+ )
282
+
283
+ # shape: [batch_size, channels, num_samples]
284
+ # min length: 94, stride: 32, 32 == 2**5
285
+ # x = torch.ones([4, 1, 94])
286
+ # x = torch.ones([4, 1, 126])
287
+ # x = torch.ones([4, 1, 158])
288
+ x = torch.ones([4, 1, 190])
289
+
290
+ length = x.shape[-1]
291
+ padded_length = get_padding_length(
292
+ length,
293
+ num_layers=config.down_sampling_num_layers,
294
+ kernel_size=config.down_sampling_kernel_size,
295
+ stride=config.down_sampling_stride,
296
+ )
297
+ x = F.pad(input=x, pad=(0, padded_length - length), mode="constant", value=0)
298
+ # print(x)
299
+ print(x.shape)
300
+ bottle_neck = down_sampling.forward(x)
301
+ print("-" * 150)
302
+ x = up_sampling.forward(bottle_neck)
303
+ print(x.shape)
304
+ return
305
+
306
+
307
+ def main():
308
+
309
+ config = NXCleanUNetConfig()
310
+
311
+ # shape: [batch_size, channels, num_samples]
312
+ # min length: 94, stride: 32, 32 == 2**5
313
+ # x = torch.ones([4, 1, 94])
314
+ # x = torch.ones([4, 1, 126])
315
+ # x = torch.ones([4, 1, 158])
316
+ # x = torch.ones([4, 1, 190])
317
+ x = torch.ones([4, 1, 16000])
318
+
319
+ model = NXCleanUNet(config)
320
+ enhanced_audios = model.forward(x)
321
+ print(enhanced_audios.shape)
322
+ return
323
+
324
+
325
+ if __name__ == "__main__":
326
+ main()
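For reference, a quick worked example of the padding arithmetic used by get_padding_length above, assuming the defaults from yaml/config.yaml further down (num_layers=5, kernel_size=4, stride=2); it reproduces the "min length: 94" figure noted in main2():

import math

num_layers, kernel_size, stride = 5, 4, 2

length = 1
for _ in range(num_layers):
    # encoder side: Conv1d output length per layer
    length = 1 if length < kernel_size else 1 + math.ceil((length - kernel_size) / stride)
for _ in range(num_layers):
    # decoder side: ConvTranspose1d output length per layer
    length = (length - 1) * stride + kernel_size

print(length)  # 94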
toolbox/torchaudio/models/nx_clean_unet/transformer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/nx_clean_unet/transformer/mask.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+
6
+ def make_pad_mask(lengths: torch.Tensor,
7
+ max_len: int = 0,
8
+ ) -> torch.Tensor:
9
+ batch_size = lengths.size(0)
10
+ max_len = max_len if max_len > 0 else lengths.max().item()
11
+ seq_range = torch.arange(
12
+ 0,
13
+ max_len,
14
+ dtype=torch.int64,
15
+ device=lengths.device
16
+ )
17
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
18
+ seq_length_expand = lengths.unsqueeze(-1)
19
+ mask = seq_range_expand >= seq_length_expand
20
+ return mask
21
+
22
+
23
+
24
+ def subsequent_chunk_mask(
25
+ size: int,
26
+ chunk_size: int,
27
+ num_left_chunks: int = -1,
28
+ device: torch.device = torch.device("cpu"),
29
+ ) -> torch.Tensor:
30
+ """
31
+ Create mask for subsequent steps (size, size) with chunk size,
32
+ this is for streaming encoder
33
+
34
+ Examples:
35
+ >>> subsequent_chunk_mask(4, 2)
36
+ [[1, 1, 0, 0],
37
+ [1, 1, 0, 0],
38
+ [1, 1, 1, 1],
39
+ [1, 1, 1, 1]]
40
+
41
+ :param size: int. size of mask.
42
+ :param chunk_size: int. size of chunk.
43
+ :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
44
+ :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
45
+ :return: torch.Tensor. mask
46
+ """
47
+
48
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
49
+ for i in range(size):
50
+ if num_left_chunks < 0:
51
+ start = 0
52
+ else:
53
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
54
+ ending = min((i // chunk_size + 1) * chunk_size, size)
55
+ ret[i, start:ending] = True
56
+ return ret
57
+
58
+
59
+ def main():
60
+ chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
61
+ print(chunk_mask)
62
+ return
63
+
64
+
65
+ if __name__ == '__main__':
66
+ main()
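A small usage sketch for mask.py (assuming the repository root is on PYTHONPATH); make_pad_mask marks the padded positions of each sequence with True:

import torch
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import make_pad_mask

lengths = torch.tensor([2, 3])
mask = make_pad_mask(lengths, max_len=4)
# positions at or beyond each sequence length are True (padding):
# [[False, False, True,  True],
#  [False, False, False, True]]
print(mask)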
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py ADDED
@@ -0,0 +1,577 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ from typing import Dict, Optional, Tuple, List, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as f
9
+
10
+ from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
11
+
12
+
13
+ class SinusoidalPositionalEncoding(nn.Module):
14
+ """
15
+ Positional Encoding
16
+
17
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
18
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
19
+ """
20
+
21
+ @staticmethod
22
+ def demo1():
23
+ batch_size = 2
24
+ time_steps = 10
25
+ embedding_dim = 64
26
+
27
+ pe = SinusoidalPositionalEncoding(
28
+ embedding_dim=embedding_dim,
29
+ dropout_rate=0.1,
30
+ )
31
+
32
+ x = torch.randn(size=(batch_size, time_steps, embedding_dim))
33
+
34
+ x, pos_emb = pe.forward(x)
35
+
36
+ # torch.Size([2, 10, 64])
37
+ print(x.shape)
38
+ # torch.Size([1, 10, 64])
39
+ print(pos_emb.shape)
40
+ return
41
+
42
+ @staticmethod
43
+ def demo2():
44
+ batch_size = 2
45
+ time_steps = 10
46
+ embedding_dim = 64
47
+
48
+ pe = SinusoidalPositionalEncoding(
49
+ embedding_dim=embedding_dim,
50
+ dropout_rate=0.1,
51
+ )
52
+
53
+ x = torch.randn(size=(batch_size, time_steps, embedding_dim))
54
+ offset = torch.randint(low=3, high=7, size=(batch_size,))
55
+ x, pos_emb = pe.forward(x, offset=offset)
56
+
57
+ # tensor([3, 4])
58
+ print(offset)
59
+ # torch.Size([2, 10, 64])
60
+ print(x.shape)
61
+ # torch.Size([2, 10, 64])
62
+ print(pos_emb.shape)
63
+ return
64
+
65
+ def __init__(self,
66
+ embedding_dim: int,
67
+ dropout_rate: float,
68
+ max_length: int = 5000,
69
+ reverse: bool = False
70
+ ):
71
+ super().__init__()
72
+ self.embedding_dim = embedding_dim
73
+ self.dropout_rate = dropout_rate
74
+ self.max_length = max_length
75
+ self.reverse = reverse
76
+
77
+ self.x_scale = math.sqrt(self.embedding_dim)
78
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
79
+
80
+ self.pe = torch.zeros(self.max_length, self.embedding_dim)
81
+ position = torch.arange(0, self.max_length, dtype=torch.float32).unsqueeze(1)
82
+
83
+ div_term = torch.exp(
84
+ torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) *
85
+ - (math.log(10000.0) / self.embedding_dim)
86
+ )
87
+ self.pe[:, 0::2] = torch.sin(position * div_term)
88
+ self.pe[:, 1::2] = torch.cos(position * div_term)
89
+ self.pe = self.pe.unsqueeze(0)
90
+
91
+ def forward(self,
92
+ x: torch.Tensor,
93
+ offset: Union[int, torch.Tensor] = 0
94
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
95
+ """
96
+ Add positional encoding.
97
+ :param x: torch.Tensor. Input. shape=(batch_size, time_steps, ...).
98
+ :param offset: int or torch.Tensor. position offset.
99
+ :return:
100
+ torch.Tensor. Encoded tensor. shape=(batch_size, time_steps, ...).
101
+ torch.Tensor. for compatibility to RelPositionalEncoding. shape=(1, time_steps, ...).
102
+ """
103
+ self.pe = self.pe.to(x.device)
104
+ pos_emb = self.position_encoding(offset, x.size(1), False)
105
+ x = x * self.x_scale + pos_emb
106
+ return self.dropout(x), self.dropout(pos_emb)
107
+
108
+ def position_encoding(self,
109
+ offset: Union[int, torch.Tensor],
110
+ size: int,
111
+ apply_dropout: bool = True
112
+ ) -> torch.Tensor:
113
+ """
114
+ For getting encoding in a streaming fashion.
115
+
116
+ Attention!!!!!
117
+ we apply dropout only once at the whole utterance level in a none
118
+ streaming way, but will call this function several times with
119
+ increasing input size in a streaming scenario, so the dropout will
120
+ be applied several times.
121
+
122
+ :param offset: int or torch.Tensor. start offset.
123
+ :param size: int. required size of position encoding.
124
+ :param apply_dropout:
125
+ :return: torch.Tensor. Corresponding encoding.
126
+ """
127
+ if isinstance(offset, int):
128
+ assert offset + size <= self.max_length
129
+ pos_emb = self.pe[:, offset:offset + size]
130
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
131
+ assert offset + size <= self.max_length
132
+ pos_emb = self.pe[:, offset:offset + size]
133
+ else: # for batched streaming decoding on GPU
134
+ # offset. shape=(batch_size,)
135
+ assert torch.max(offset) + size <= self.max_length
136
+
137
+ # shape=(batch_size, time_steps)
138
+ index = offset.unsqueeze(1) + torch.arange(0, size).to(offset.device)
139
+ flag = index > 0
140
+ # remove negative offset
141
+ index = index * flag
142
+ # shape=(batch_size, time_steps, embedding_dim)
143
+ pos_emb = f.embedding(index, self.pe[0])
144
+
145
+ if apply_dropout:
146
+ pos_emb = self.dropout(pos_emb)
147
+ return pos_emb
148
+
149
+
150
+ class RelPositionalEncoding(SinusoidalPositionalEncoding):
151
+ """
152
+ Relative positional encoding module.
153
+
154
+ See : Appendix B in https://arxiv.org/abs/1901.02860
155
+
156
+ """
157
+ def __init__(self,
158
+ embedding_dim: int,
159
+ dropout_rate: float,
160
+ max_length: int = 5000,
161
+ ):
162
+ super().__init__(embedding_dim, dropout_rate, max_length, reverse=True)
163
+
164
+ def forward(self,
165
+ x: torch.Tensor,
166
+ offset: Union[int, torch.Tensor] = 0
167
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
168
+ """
169
+ Compute positional encoding.
170
+ :param x: torch.Tensor. Input. shape=(batch_size, time_steps, ...).
171
+ :param offset:
172
+ :return:
173
+ torch.Tensor. Encoded tensor. shape=(batch_size, time_steps, ...).
174
+ torch.Tensor. Positional embedding tensor. shape=(1, time_steps, ...).
175
+ """
176
+ self.pe = self.pe.to(x.device)
177
+ x = x * self.x_scale
178
+ pos_emb = self.position_encoding(offset, x.size(1), False)
179
+ return self.dropout(x), self.dropout(pos_emb)
180
+
181
+
182
+ class PositionwiseFeedForward(nn.Module):
183
+ def __init__(self,
184
+ input_dim: int,
185
+ hidden_units: int,
186
+ dropout_rate: float,
187
+ activation: torch.nn.Module = torch.nn.ReLU()):
188
+ """
189
+ FeedForward are applied on each position of the sequence.
190
+ the output dim is same with the input dim.
191
+
192
+ :param input_dim: int. input dimension.
193
+ :param hidden_units: int. the number of hidden units.
194
+ :param dropout_rate: float. dropout rate.
195
+ :param activation: torch.nn.Module. activation function.
196
+ """
197
+ super(PositionwiseFeedForward, self).__init__()
198
+ self.w_1 = torch.nn.Linear(input_dim, hidden_units)
199
+ self.activation = activation
200
+ self.dropout = torch.nn.Dropout(dropout_rate)
201
+ self.w_2 = torch.nn.Linear(hidden_units, input_dim)
202
+
203
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
204
+ """
205
+ Forward function.
206
+ :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
207
+ :return: output tensor. shape=(batch_size, max_length, dim).
208
+ """
209
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
210
+
211
+
212
+ class MultiHeadedAttention(nn.Module):
213
+ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
214
+ """
215
+ :param n_head: int. the number of heads.
216
+ :param n_feat: int. the number of features.
217
+ :param dropout_rate: float. dropout rate.
218
+ """
219
+ super().__init__()
220
+ assert n_feat % n_head == 0
221
+ # We assume d_v always equals d_k
222
+ self.d_k = n_feat // n_head
223
+ self.h = n_head
224
+ self.linear_q = nn.Linear(n_feat, n_feat)
225
+ self.linear_k = nn.Linear(n_feat, n_feat)
226
+ self.linear_v = nn.Linear(n_feat, n_feat)
227
+ self.linear_out = nn.Linear(n_feat, n_feat)
228
+ self.dropout = nn.Dropout(p=dropout_rate)
229
+
230
+ def forward_qkv(self,
231
+ query: torch.Tensor,
232
+ key: torch.Tensor,
233
+ value: torch.Tensor
234
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
235
+ """
236
+ transform query, key and value.
237
+ :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
238
+ :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
239
+ :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
240
+ :return:
241
+ """
242
+ n_batch = query.size(0)
243
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
244
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
245
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
246
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
247
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
248
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
249
+
250
+ return q, k, v
251
+
252
+ def forward_attention(self,
253
+ value: torch.Tensor,
254
+ scores: torch.Tensor,
255
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
256
+ ) -> torch.Tensor:
257
+ """
258
+ compute attention context vector.
259
+ :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
260
+ :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
261
+ :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
262
+ (batch_size, time1, time2), (0, 0, 0) means fake mask.
263
+ :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
264
+ weighted by the attention score (batch_size, time1, time2).
265
+ """
266
+ n_batch = value.size(0)
267
+ # NOTE: When will `if mask.size(2) > 0` be True?
268
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
269
+ # 1st chunk to ease the onnx export.]
270
+ # 2. pytorch training
271
+ if mask.size(2) > 0: # time2 > 0
272
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
273
+ # For last chunk, time2 might be larger than scores.size(-1)
274
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
275
+ scores = scores.masked_fill(mask, -float('inf'))
276
+ attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
277
+
278
+ # NOTE: When will `if mask.size(2) > 0` be False?
279
+ # 1. onnx(16/-1, -1/-1, 16/0)
280
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
281
+ else:
282
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
283
+
284
+ p_attn = self.dropout(attn)
285
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
286
+ x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
287
+
288
+ return self.linear_out(x) # (batch, time1, n_feat)
289
+
290
+ def forward(self,
291
+ query: torch.Tensor,
292
+ key: torch.Tensor,
293
+ value: torch.Tensor,
294
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
295
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
296
+ **kwargs,
297
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
298
+ """
299
+ compute scaled dot product attention.
300
+ :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
301
+ :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
302
+ :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
303
+ :param mask: torch.Tensor. mask tensor (batch_size, 1, time2) or
304
+ (batch_size, time1, time2).
305
+ :param cache: torch.Tensor. cache tensor. shape=(1, head, cache_t, d_k * 2),
306
+ where `cache_t == chunk_size * num_decoding_left_chunks`
307
+ and `head * d_k == n_feat`
308
+ :return:
309
+ torch.Tensor. output tensor. shape=(batch_size, time1, n_feat).
310
+ torch.Tensor. cache tensor. (1, head, cache_t + time1, d_k * 2)
311
+ where `cache_t == chunk_size * num_decoding_left_chunks`
312
+ and `head * d_k == n_feat`
313
+ """
314
+ q, k, v = self.forward_qkv(query, key, value)
315
+
316
+ # NOTE:
317
+ # when export onnx model, for 1st chunk, we feed
318
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
319
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
320
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
321
+ # and we will always do splitting and
322
+ # concatnation(this will simplify onnx export). Note that
323
+ # it's OK to concat & split zero-shaped tensors(see code below).
324
+ # when export jit model, for 1st chunk, we always feed
325
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
326
+ # >>> a = torch.ones((1, 2, 0, 4))
327
+ # >>> b = torch.ones((1, 2, 3, 4))
328
+ # >>> c = torch.cat((a, b), dim=2)
329
+ # >>> torch.equal(b, c) # True
330
+ # >>> d = torch.split(a, 2, dim=-1)
331
+ # >>> torch.equal(d[0], d[1]) # True
332
+ if cache.size(0) > 0:
333
+ key_cache, value_cache = torch.split(
334
+ cache, cache.size(-1) // 2, dim=-1)
335
+ k = torch.cat([key_cache, k], dim=2)
336
+ v = torch.cat([value_cache, v], dim=2)
337
+ # NOTE: We do cache slicing in encoder.forward_chunk, since it's
338
+ # non-trivial to calculate `next_cache_start` here.
339
+ new_cache = torch.cat((k, v), dim=-1)
340
+
341
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
342
+ return self.forward_attention(v, scores, mask), new_cache
343
+
344
+
345
+ class TransformerEncoderLayer(nn.Module):
346
+ def __init__(self,
347
+ input_dim: int,
348
+ dropout_rate: float = 0.1,
349
+ n_heads: int = 4,
350
+ ):
351
+ super().__init__()
352
+ self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
353
+ self.attention = MultiHeadedAttention(
354
+ n_head=n_heads,
355
+ n_feat=input_dim,
356
+ dropout_rate=dropout_rate
357
+ )
358
+
359
+ self.dropout1 = nn.Dropout(dropout_rate)
360
+ self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
361
+ self.ffn = PositionwiseFeedForward(
362
+ input_dim=input_dim,
363
+ hidden_units=input_dim,
364
+ dropout_rate=dropout_rate
365
+ )
366
+ self.dropout2 = nn.Dropout(dropout_rate)
367
+ self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
368
+
369
+ def forward(
370
+ self,
371
+ x: torch.Tensor,
372
+ mask: torch.Tensor,
373
+ position_embedding: torch.Tensor,
374
+ attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
375
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
376
+ """
377
+
378
+ :param x: torch.Tensor. shape=(batch_size, time, input_dim).
379
+ :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
380
+ :param position_embedding: torch.Tensor.
381
+ :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
382
+ shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
383
+ :return:
384
+ torch.Tensor: Output tensor (batch_size, time, input_dim).
385
+ torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
386
+ """
387
+
388
+ xt = self.norm1(x)
389
+
390
+ x_att, new_att_cache = self.attention.forward(
391
+ xt, xt, xt, mask=mask, cache=attention_cache, position_embedding=position_embedding
392
+ )
393
+ x = x + self.dropout1(xt)
394
+ xt = self.norm2(x)
395
+ xt = self.ffn.forward(xt)
396
+ x = x + self.dropout2(xt)
397
+
398
+ x = self.norm3(x)
399
+
400
+ return x, new_att_cache
401
+
402
+
403
+ class TransformerEncoder(nn.Module):
404
+ """
405
+ https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
406
+ """
407
+ def __init__(self,
408
+ input_size: int = 64,
409
+ hidden_size: int = 256,
410
+ attention_heads: int = 4,
411
+ num_blocks: int = 6,
412
+ dropout_rate: float = 0.1,
413
+ max_length: int = 512,
414
+ chunk_size: int = 1,
415
+ num_left_chunks: int = 128,
416
+ ):
417
+ super().__init__()
418
+ self.input_size = input_size
419
+ self.hidden_size = hidden_size
420
+
421
+ self.max_length = max_length
422
+ self.chunk_size = chunk_size
423
+ self.num_left_chunks = num_left_chunks
424
+
425
+ self.input_linear = nn.Linear(
426
+ in_features=self.input_size,
427
+ out_features=self.hidden_size,
428
+ )
429
+
430
+ self.positional_encoding = RelPositionalEncoding(
431
+ embedding_dim=hidden_size,
432
+ dropout_rate=dropout_rate,
433
+ max_length=max_length,
434
+ )
435
+
436
+ self.encoder_layer_list = torch.nn.ModuleList([
437
+ TransformerEncoderLayer(
438
+ input_dim=hidden_size,
439
+ n_heads=attention_heads,
440
+ dropout_rate=dropout_rate,
441
+ ) for _ in range(num_blocks)
442
+ ])
443
+
444
+ self.output_linear = nn.Linear(
445
+ in_features=self.hidden_size,
446
+ out_features=self.input_size,
447
+ )
448
+
449
+ def forward(self,
450
+ xs: torch.Tensor,
451
+ ):
452
+ """
453
+ :param xs: Tensor, shape: [batch_size, time_steps, input_size]
454
+ :return: Tensor, shape: [batch_size, time_steps, input_size]
455
+ """
456
+ batch_size, time_steps, _ = xs.shape
457
+ # xs shape: [batch_size, time_steps, input_size]
458
+ xs = self.input_linear.forward(xs)
459
+ # xs shape: [batch_size, time_steps, hidden_size]
460
+
461
+ xs, position_embedding = self.positional_encoding.forward(xs)
462
+ # xs shape: [batch_size, time_steps, hidden_size]
463
+ # position_embedding shape: [1, time_steps, hidden_size]
464
+
465
+ chunk_masks = subsequent_chunk_mask(
466
+ size=time_steps,
467
+ chunk_size=self.chunk_size,
468
+ num_left_chunks=self.num_left_chunks, device=xs.device
469
+ )
470
+ # chunk_masks shape: [time_steps, time_steps]
471
+ chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
472
+ # chunk_masks shape: [batch_size, time_steps, time_steps]
473
+
474
+ for encoder_layer in self.encoder_layer_list:
475
+ xs, _ = encoder_layer.forward(xs, chunk_masks, position_embedding)
476
+
477
+ # xs shape: [batch_size, time_steps, hidden_size]
478
+ xs = self.output_linear.forward(xs)
479
+ # xs shape: [batch_size, time_steps, input_size]
480
+
481
+ return xs
482
+
483
+ def forward_chunk(self,
484
+ xs: torch.Tensor,
485
+ offset: int,
486
+ attention_mask: torch.Tensor = torch.zeros(0, 0, 0),
487
+ attention_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
488
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
489
+ """
490
+ Forward just one chunk.
491
+ :param xs: torch.Tensor. chunk input, with shape (b=1, time, input_size),
492
+ where `time == chunk_size` (no subsampling is applied in this encoder).
493
+ :param offset: int. current offset in encoder output timestamp.
494
+ :param attention_mask:
495
+ :param attention_cache: torch.Tensor. cache tensor for KEY & VALUE in
496
+ transformer/conformer attention, with shape
497
+ (elayers, head, cache_t1, d_k * 2), where
498
+ `head * d_k == hidden-dim` and
499
+ `cache_t1 == chunk_size * num_decoding_left_chunks`.
500
+ :return:
501
+ """
502
+ # xs shape: [batch_size, time_steps, input_size]
503
+ xs = self.input_linear.forward(xs)
504
+ # xs shape: [batch_size, time_steps, hidden_size]
505
+
506
+ xs, position_embedding = self.positional_encoding.forward(xs, offset=offset)
507
+ # xs shape: [batch_size, time_steps, hidden_size]
508
+ # position_embedding shape: [1, time_steps, hidden_size]
509
+
510
+ r_att_cache = []
511
+ for encoder_layer in self.encoder_layer_list:
512
+ xs, new_att_cache = encoder_layer.forward(
513
+ x=xs, mask=attention_mask,
514
+ position_embedding=position_embedding,
515
+ attention_cache=attention_cache,
516
+ )
517
+ r_att_cache.append(new_att_cache[:, :, self.chunk_size:, :])
518
+
519
+ r_att_cache = torch.cat(r_att_cache, dim=0)
520
+
521
+ return xs, r_att_cache
522
+
523
+ def forward_chunk_by_chunk(
524
+ self,
525
+ xs: torch.Tensor,
526
+ ) -> torch.Tensor:
527
+
528
+ batch_size, time_steps, _ = xs.shape
529
+
530
+ offset = 0
531
+ attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
532
+ attention_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
533
+
534
+ outputs = []
535
+ for idx in range(0, time_steps - self.chunk_size + 1, self.chunk_size):
536
+ begin = idx  # range() above already steps by chunk_size
537
+ end = begin + self.chunk_size
538
+ chunk_xs = xs[:, begin:end, :]
539
+
540
+ ys, att_cache = self.forward_chunk(
541
+ xs=chunk_xs, attention_mask=attention_mask,
542
+ offset=offset, attention_cache=attention_cache
543
+ )
544
+ # xs shape: [batch_size, chunk_size, hidden_size]
545
+ ys = self.output_linear.forward(ys)
546
+ # xs shape: [batch_size, chunk_size, input_size]
547
+
548
+ offset += self.chunk_size
549
+ outputs.append(ys)
550
+
551
+ ys = torch.cat(outputs, 1)
552
+ return ys
553
+
554
+
555
+ def main():
556
+
557
+ encoder = TransformerEncoder(
558
+ input_size=64,
559
+ hidden_size=256,
560
+ attention_heads=4,
561
+ num_blocks=6,
562
+ dropout_rate=0.1,
563
+ )
564
+
565
+ x = torch.ones([4, 200, 64])
566
+
567
+ y = encoder.forward(xs=x)
568
+ print(y.shape)
569
+
570
+ # y = encoder.forward_chunk_by_chunk(xs=x)
571
+ # print(y.shape)
572
+
573
+ return
574
+
575
+
576
+ if __name__ == '__main__':
577
+ main()
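With the defaults used by TransformerEncoder above (chunk_size=1, num_left_chunks=128), forward() applies a causal, bounded-context attention mask: each time step attends to itself and to at most num_left_chunks previous steps. A small sketch with a shorter look-back shows the band structure (assuming the repository root is on PYTHONPATH):

import torch
from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask

mask = subsequent_chunk_mask(size=6, chunk_size=1, num_left_chunks=2)
# row i allows attention to steps max(i - 2, 0) .. i:
# [[1, 0, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0, 0],
#  [1, 1, 1, 0, 0, 0],
#  [0, 1, 1, 1, 0, 0],
#  [0, 0, 1, 1, 1, 0],
#  [0, 0, 0, 1, 1, 1]]
print(mask.int())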
toolbox/torchaudio/models/nx_clean_unet/utils.py ADDED
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class LearnableSigmoid1d(nn.Module):
8
+ def __init__(self, in_features, beta=1):
9
+ super().__init__()
10
+ self.beta = beta
11
+ self.slope = nn.Parameter(torch.ones(in_features))
12
+ self.slope.requires_grad = True
13
+
14
+ def forward(self, x):
15
+ # x shape: [batch_size, time_steps, spec_bins]
16
+ return self.beta * torch.sigmoid(self.slope * x)
17
+
18
+
19
+ def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
20
+
21
+ hann_window = torch.hann_window(win_size).to(y.device)
22
+ stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
23
+ center=center, pad_mode='reflect', normalized=False, return_complex=True)
24
+ stft_spec = torch.view_as_real(stft_spec)
25
+ mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
26
+ pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
27
+ # Magnitude Compression
28
+ mag = torch.pow(mag, compress_factor)
29
+ com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
30
+
31
+ return mag, pha, com
32
+
33
+
34
+ def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
35
+ # Magnitude Decompression
36
+ mag = torch.pow(mag, (1.0/compress_factor))
37
+ com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
38
+ hann_window = torch.hann_window(win_size).to(com.device)
39
+ wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
40
+
41
+ return wav
42
+
43
+
44
+ if __name__ == '__main__':
45
+ pass
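A usage sketch for the STFT helpers above, with the analysis settings from yaml/config.yaml (n_fft=512, win_size=200, hop_size=80, compress_factor=0.3); the magnitude/phase round trip should reconstruct the waveform up to small numerical error (assuming the repository root is on PYTHONPATH):

import torch
from toolbox.torchaudio.models.nx_clean_unet.utils import mag_pha_stft, mag_pha_istft

y = torch.randn(1, 8000)  # 1 second at 8 kHz
mag, pha, com = mag_pha_stft(y, n_fft=512, hop_size=80, win_size=200, compress_factor=0.3)
y_hat = mag_pha_istft(mag, pha, n_fft=512, hop_size=80, win_size=200, compress_factor=0.3)

n = min(y.shape[-1], y_hat.shape[-1])
print(mag.shape, pha.shape)                   # both [1, 257, num_frames]
print((y[:, :n] - y_hat[:, :n]).abs().max())  # expected to be small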
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml ADDED
@@ -0,0 +1,23 @@
1
+ model_name: "nx_clean_unet"
2
+
3
+ sample_rate: 8000
4
+ segment_size: 16000
5
+ n_fft: 512
6
+ win_size: 200
7
+ hop_size: 80
8
+
9
+ down_sampling_num_layers: 5
10
+ down_sampling_in_channels: 1
11
+ down_sampling_hidden_channels: 64
12
+ down_sampling_kernel_size: 4
13
+ down_sampling_stride: 2
14
+
15
+ tsfm_hidden_size: 256
16
+ tsfm_attention_heads: 4
17
+ tsfm_num_blocks: 6
18
+ tsfm_dropout_rate: 0.1
19
+
20
+ discriminator_dim: 32
21
+ discriminator_in_channel: 2
22
+
23
+ compress_factor: 0.3
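A minimal end-to-end sketch of how this config is meant to be consumed; the configuration/modeling module names below are assumed to mirror the clean_unet naming convention and may differ from the actual file names:

import torch
# assumed module paths, mirroring the clean_unet package layout
from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNet

config = NXCleanUNetConfig.from_pretrained("toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml")
model = NXCleanUNet(config)

# segment_size=16000 samples at sample_rate=8000 is a 2 second training segment
noisy = torch.randn(1, 1, 16000)
enhanced = model.forward(noisy)
print(enhanced.shape)  # torch.Size([1, 1, 16000])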