HoneyTian committed
Commit bd94e77 · 0 Parent(s)

first commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +21 -0
  3. Dockerfile +21 -0
  4. README.md +26 -0
  5. examples/simple_linear_irm_aishell/run.sh +172 -0
  6. examples/simple_linear_irm_aishell/step_1_prepare_data.py +196 -0
  7. examples/simple_linear_irm_aishell/step_2_train_model.py +348 -0
  8. examples/simple_linear_irm_aishell/step_3_evaluation.py +239 -0
  9. examples/simple_linear_irm_aishell/yaml/config.yaml +13 -0
  10. examples/simple_lstm_irm_aishell/run.sh +172 -0
  11. examples/simple_lstm_irm_aishell/step_1_prepare_data.py +197 -0
  12. examples/simple_lstm_irm_aishell/step_2_train_model.py +348 -0
  13. examples/simple_lstm_irm_aishell/step_3_evaluation.py +239 -0
  14. examples/spectrum_unet_irm_aishell/run.sh +174 -0
  15. examples/spectrum_unet_irm_aishell/step_1_prepare_data.py +197 -0
  16. examples/spectrum_unet_irm_aishell/step_2_train_model.py +371 -0
  17. examples/spectrum_unet_irm_aishell/step_3_evaluation.py +270 -0
  18. examples/spectrum_unet_irm_aishell/yaml/config.yaml +35 -0
  19. install.sh +64 -0
  20. main.py +45 -0
  21. project_settings.py +25 -0
  22. requirements-python-3-9-9.txt +10 -0
  23. requirements.txt +10 -0
  24. toolbox/__init__.py +6 -0
  25. toolbox/json/__init__.py +6 -0
  26. toolbox/json/misc.py +63 -0
  27. toolbox/os/__init__.py +6 -0
  28. toolbox/os/command.py +59 -0
  29. toolbox/os/environment.py +114 -0
  30. toolbox/os/other.py +9 -0
  31. toolbox/torch/__init__.py +6 -0
  32. toolbox/torch/utils/__init__.py +6 -0
  33. toolbox/torch/utils/data/__init__.py +6 -0
  34. toolbox/torch/utils/data/dataset/__init__.py +6 -0
  35. toolbox/torch/utils/data/dataset/denoise_excel_dataset.py +131 -0
  36. toolbox/torchaudio/__init__.py +5 -0
  37. toolbox/torchaudio/configuration_utils.py +63 -0
  38. toolbox/torchaudio/models/__init__.py +5 -0
  39. toolbox/torchaudio/models/clean_unet/__init__.py +6 -0
  40. toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py +9 -0
  41. toolbox/torchaudio/models/dfnet3/__init__.py +6 -0
  42. toolbox/torchaudio/models/dfnet3/configuration_dfnet3.py +89 -0
  43. toolbox/torchaudio/models/dfnet3/features.py +192 -0
  44. toolbox/torchaudio/models/dfnet3/modeling_dfnet3.py +835 -0
  45. toolbox/torchaudio/models/dfnet3/multiframes.py +145 -0
  46. toolbox/torchaudio/models/dfnet3/utils.py +17 -0
  47. toolbox/torchaudio/models/ehnet/__init__.py +6 -0
  48. toolbox/torchaudio/models/ehnet/modeling_ehnet.py +132 -0
  49. toolbox/torchaudio/models/percepnet/__init__.py +6 -0
  50. toolbox/torchaudio/models/percepnet/modeling_percetnet.py +11 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
+
+ .git/
+ .idea/
+
+ **/evaluation_audio/
+ **/file_dir/
+ **/flagged/
+ **/log/
+ **/logs/
+ **/__pycache__/
+
+ /data/
+ /docs/
+ /dotenv/
+ /hub_datasets/
+ /thirdparty/
+ /trained_models/
+ /temp/
+
+ #**/*.wav
+ **/*.xlsx
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.12
+
+ WORKDIR /code
+
+ COPY . /code
+
+ RUN pip install --upgrade pip
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ RUN useradd -m -u 1000 user
+
+ USER user
+
+ ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "main.py"]
README.md ADDED
@@ -0,0 +1,26 @@
+ ---
+ title: VM Sound Classification
+ emoji: 🐢
+ colorFrom: purple
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## NX Denoise
+
+
+ ### speech datasets
+
+ ```text
+
+ AISHELL (15G)
+ https://openslr.trmal.net/resources/33/
+
+ AISHELL-3 (19G)
+ http://www.openslr.org/93/
+
+ ```
+
examples/simple_linear_irm_aishell/run.sh ADDED
@@ -0,0 +1,172 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir
+
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir
+
+ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
+
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0 # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ config_file="yaml/config.yaml"
+ limit=10
+
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
+
+ nohup_name=nohup.out
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+ [ -z "${1:-}" ] && break; # break if there are no arguments
+ case "$1" in
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+ old_value="(eval echo \\$$name)";
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+ was_bool=true;
+ else
+ was_bool=false;
+ fi
+
+ # Set the variable to the right value-- the escaped quotes make it work if
+ # the option had spaces, like --cmd "queue.pl -sync y"
+ eval "${name}=\"$2\"";
+
+ # Check that Boolean-valued arguments are really Boolean.
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+ exit 1;
+ fi
+ shift 2;
+ ;;
+
+ *) break;
+ esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
+
+ dataset="${file_dir}/dataset.xlsx"
+ train_dataset="${file_dir}/train.xlsx"
+ valid_dataset="${file_dir}/valid.xlsx"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+ #source /data/local/bin/nx_denoise/bin/activate
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ $verbose && echo "stage 1: prepare data"
+ cd "${work_dir}" || exit 1
+ python3 step_1_prepare_data.py \
+ --file_dir "${file_dir}" \
+ --noise_dir "${noise_dir}" \
+ --speech_dir "${speech_dir}" \
+ --train_dataset "${train_dataset}" \
+ --valid_dataset "${valid_dataset}" \
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ $verbose && echo "stage 2: train model"
+ cd "${work_dir}" || exit 1
+ python3 step_2_train_model.py \
+ --train_dataset "${train_dataset}" \
+ --valid_dataset "${valid_dataset}" \
+ --serialization_dir "${file_dir}" \
+ --config_file "${config_file}" \
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ $verbose && echo "stage 3: test model"
+ cd "${work_dir}" || exit 1
+ python3 step_3_evaluation.py \
+ --valid_dataset "${valid_dataset}" \
+ --model_dir "${file_dir}/best" \
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
+ --limit "${limit}" \
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ $verbose && echo "stage 4: export model"
+ cd "${work_dir}" || exit 1
+ python3 step_5_export_models.py \
+ --vocabulary_dir "${vocabulary_dir}" \
+ --model_dir "${file_dir}/best" \
+ --serialization_dir "${file_dir}" \
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ $verbose && echo "stage 5: collect files"
+ cd "${work_dir}" || exit 1
+
+ mkdir -p ${final_model_dir}
+
+ cp "${file_dir}/best"/* "${final_model_dir}"
+ cp -r "${file_dir}/vocabulary" "${final_model_dir}"
+
+ cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
+
+ cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
+ cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
+ cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
+ cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
+
+ cd "${final_model_dir}/.." || exit 1;
+
+ if [ -e "${final_model_name}.zip" ]; then
+ rm -rf "${final_model_name}_backup.zip"
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+ fi
+
+ zip -r "${final_model_name}.zip" "${final_model_name}"
+ rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ $verbose && echo "stage 6: clear file_dir"
+ cd "${work_dir}" || exit 1
+
+ rm -rf "${file_dir}";
+
+ fi
examples/simple_linear_irm_aishell/step_1_prepare_data.py ADDED
@@ -0,0 +1,196 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_nsr_db", default=-20, type=float)
41
+ parser.add_argument("--max_nsr_db", default=5, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def filename_generator(data_dir: str):
50
+ data_dir = Path(data_dir)
51
+ for filename in data_dir.glob("**/*.wav"):
52
+ yield filename.as_posix()
53
+
54
+
55
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
56
+ data_dir = Path(data_dir)
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ row = {
71
+ "filename": filename.as_posix(),
72
+ "raw_duration": round(raw_duration, 4),
73
+ "offset": round(begin / sample_rate, 4),
74
+ "duration": round(duration, 4),
75
+ }
76
+ yield row
77
+
78
+
79
+ def get_dataset(args):
80
+ file_dir = Path(args.file_dir)
81
+ file_dir.mkdir(exist_ok=True)
82
+
83
+ noise_dir = Path(args.noise_dir)
84
+ speech_dir = Path(args.speech_dir)
85
+
86
+ noise_generator = target_second_signal_generator(
87
+ noise_dir.as_posix(),
88
+ duration=args.duration,
89
+ sample_rate=args.target_sample_rate
90
+ )
91
+ speech_generator = target_second_signal_generator(
92
+ speech_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate
95
+ )
96
+
97
+ dataset = list()
98
+
99
+ count = 0
100
+ process_bar = tqdm(desc="build dataset excel")
101
+ for noise, speech in zip(noise_generator, speech_generator):
102
+
103
+ noise_filename = noise["filename"]
104
+ noise_raw_duration = noise["raw_duration"]
105
+ noise_offset = noise["offset"]
106
+ noise_duration = noise["duration"]
107
+
108
+ speech_filename = speech["filename"]
109
+ speech_raw_duration = speech["raw_duration"]
110
+ speech_offset = speech["offset"]
111
+ speech_duration = speech["duration"]
112
+
113
+ random1 = random.random()
114
+ random2 = random.random()
115
+
116
+ row = {
117
+ "noise_filename": noise_filename,
118
+ "noise_raw_duration": noise_raw_duration,
119
+ "noise_offset": noise_offset,
120
+ "noise_duration": noise_duration,
121
+
122
+ "speech_filename": speech_filename,
123
+ "speech_raw_duration": speech_raw_duration,
124
+ "speech_offset": speech_offset,
125
+ "speech_duration": speech_duration,
126
+
127
+ "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db),
128
+
129
+ "random1": random1,
130
+ "random2": random2,
131
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
132
+ }
133
+ dataset.append(row)
134
+ count += 1
135
+ duration_seconds = count * args.duration
136
+ duration_hours = duration_seconds / 3600
137
+
138
+ process_bar.update(n=1)
139
+ process_bar.set_postfix({
140
+ # "duration_seconds": round(duration_seconds, 4),
141
+ "duration_hours": round(duration_hours, 4),
142
+ })
143
+
144
+ dataset = pd.DataFrame(dataset)
145
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
146
+ dataset.to_excel(
147
+ file_dir / "dataset.xlsx",
148
+ index=False,
149
+ )
150
+ return
151
+
152
+
153
+
154
+ def split_dataset(args):
155
+ """分割训练集, 测试集"""
156
+ file_dir = Path(args.file_dir)
157
+ file_dir.mkdir(exist_ok=True)
158
+
159
+ df = pd.read_excel(file_dir / "dataset.xlsx")
160
+
161
+ train = list()
162
+ test = list()
163
+
164
+ for i, row in df.iterrows():
165
+ flag = row["flag"]
166
+ if flag == "TRAIN":
167
+ train.append(row)
168
+ else:
169
+ test.append(row)
170
+
171
+ train = pd.DataFrame(train)
172
+ train.to_excel(
173
+ args.train_dataset,
174
+ index=False,
175
+ # encoding="utf_8_sig"
176
+ )
177
+ test = pd.DataFrame(test)
178
+ test.to_excel(
179
+ args.valid_dataset,
180
+ index=False,
181
+ # encoding="utf_8_sig"
182
+ )
183
+
184
+ return
185
+
186
+
187
+ def main():
188
+ args = get_args()
189
+
190
+ get_dataset(args)
191
+ split_dataset(args)
192
+ return
193
+
194
+
195
+ if __name__ == "__main__":
196
+ main()
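
As a side note, the "flag" column written above (TRAIN when random2 < 0.8, TEST otherwise) is what drives the later 80/20 split; a minimal equivalent of split_dataset using boolean indexing (illustrative only, not part of the commit) would be:

```python
import pandas as pd

df = pd.read_excel("dataset.xlsx")     # written by get_dataset above
train = df[df["flag"] == "TRAIN"]      # random2 < 0.8, roughly 80% of rows
test = df[df["flag"] != "TRAIN"]       # remaining ~20%
train.to_excel("train.xlsx", index=False)
test.to_excel("valid.xlsx", index=False)
```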
examples/simple_linear_irm_aishell/step_2_train_model.py ADDED
@@ -0,0 +1,348 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ from torch import dtype
19
+
20
+ pwd = os.path.abspath(os.path.dirname(__file__))
21
+ sys.path.append(os.path.join(pwd, "../../"))
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.data.dataloader import DataLoader
27
+ import torchaudio
28
+ from tqdm import tqdm
29
+
30
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
31
+ from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig
32
+ from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
38
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
39
+
40
+ parser.add_argument("--max_epochs", default=100, type=int)
41
+
42
+ parser.add_argument("--batch_size", default=64, type=int)
43
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
44
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
45
+ parser.add_argument("--patience", default=5, type=int)
46
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
47
+ parser.add_argument("--seed", default=0, type=int)
48
+
49
+ parser.add_argument("--config_file", default="config.yaml", type=str)
50
+
51
+ args = parser.parse_args()
52
+ return args
53
+
54
+
55
+ def logging_config(file_dir: str):
56
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
57
+
58
+ logging.basicConfig(format=fmt,
59
+ datefmt="%m/%d/%Y %H:%M:%S",
60
+ level=logging.INFO)
61
+ file_handler = TimedRotatingFileHandler(
62
+ filename=os.path.join(file_dir, "main.log"),
63
+ encoding="utf-8",
64
+ when="D",
65
+ interval=1,
66
+ backupCount=7
67
+ )
68
+ file_handler.setLevel(logging.INFO)
69
+ file_handler.setFormatter(logging.Formatter(fmt))
70
+ logger = logging.getLogger(__name__)
71
+ logger.addHandler(file_handler)
72
+
73
+ return logger
74
+
75
+
76
+ class CollateFunction(object):
77
+ def __init__(self,
78
+ n_fft: int = 512,
79
+ win_length: int = 200,
80
+ hop_length: int = 80,
81
+ window_fn: str = "hamming",
82
+ irm_beta: float = 1.0,
83
+ epsilon: float = 1e-8,
84
+ ):
85
+ self.n_fft = n_fft
86
+ self.win_length = win_length
87
+ self.hop_length = hop_length
88
+ self.window_fn = window_fn
89
+ self.irm_beta = irm_beta
90
+ self.epsilon = epsilon
91
+
92
+ self.transform = torchaudio.transforms.Spectrogram(
93
+ n_fft=self.n_fft,
94
+ win_length=self.win_length,
95
+ hop_length=self.hop_length,
96
+ power=2.0,
97
+ window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
98
+ )
99
+
100
+ def __call__(self, batch: List[dict]):
101
+ mix_spec_list = list()
102
+ speech_irm_list = list()
103
+ snr_db_list = list()
104
+ for sample in batch:
105
+ noise_wave: torch.Tensor = sample["noise_wave"]
106
+ speech_wave: torch.Tensor = sample["speech_wave"]
107
+ mix_wave: torch.Tensor = sample["mix_wave"]
108
+ snr_db: float = sample["snr_db"]
109
+
110
+ noise_spec = self.transform.forward(noise_wave)
111
+ speech_spec = self.transform.forward(speech_wave)
112
+ mix_spec = self.transform.forward(mix_wave)
113
+
114
+ # noise_irm = noise_spec / (noise_spec + speech_spec)
115
+ speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
116
+ speech_irm = torch.pow(speech_irm, self.irm_beta)
117
+
118
+ mix_spec_list.append(mix_spec)
119
+ speech_irm_list.append(speech_irm)
120
+ snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32))
121
+
122
+ mix_spec_list = torch.stack(mix_spec_list)
123
+ speech_irm_list = torch.stack(speech_irm_list)
124
+ snr_db_list = torch.stack(snr_db_list) # shape: (batch_size,)
125
+
126
+ # assert
127
+ if torch.any(torch.isnan(mix_spec_list)):
128
+ raise AssertionError("nan in mix_spec Tensor")
129
+ if torch.any(torch.isnan(speech_irm_list)):
130
+ raise AssertionError("nan in speech_irm Tensor")
131
+ if torch.any(torch.isnan(snr_db_list)):
132
+ raise AssertionError("nan in snr_db Tensor")
133
+
134
+ return mix_spec_list, speech_irm_list, snr_db_list
135
+
136
+
137
+ collate_fn = CollateFunction()
138
+
139
+
140
+ def main():
141
+ args = get_args()
142
+
143
+ serialization_dir = Path(args.serialization_dir)
144
+ serialization_dir.mkdir(parents=True, exist_ok=True)
145
+
146
+ logger = logging_config(serialization_dir)
147
+
148
+ random.seed(args.seed)
149
+ np.random.seed(args.seed)
150
+ torch.manual_seed(args.seed)
151
+ logger.info("set seed: {}".format(args.seed))
152
+
153
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
+ n_gpu = torch.cuda.device_count()
155
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
156
+
157
+ # datasets
158
+ logger.info("prepare datasets")
159
+ train_dataset = DenoiseExcelDataset(
160
+ excel_file=args.train_dataset,
161
+ expected_sample_rate=8000,
162
+ max_wave_value=32768.0,
163
+ )
164
+ valid_dataset = DenoiseExcelDataset(
165
+ excel_file=args.valid_dataset,
166
+ expected_sample_rate=8000,
167
+ max_wave_value=32768.0,
168
+ )
169
+ train_data_loader = DataLoader(
170
+ dataset=train_dataset,
171
+ batch_size=args.batch_size,
172
+ shuffle=True,
173
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
174
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
175
+ collate_fn=collate_fn,
176
+ pin_memory=False,
177
+ # prefetch_factor=64,
178
+ )
179
+ valid_data_loader = DataLoader(
180
+ dataset=valid_dataset,
181
+ batch_size=args.batch_size,
182
+ shuffle=True,
183
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
184
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
185
+ collate_fn=collate_fn,
186
+ pin_memory=False,
187
+ # prefetch_factor=64,
188
+ )
189
+
190
+ # models
191
+ logger.info(f"prepare models. config_file: {args.config_file}")
192
+ config = SimpleLinearIRMConfig.from_pretrained(
193
+ pretrained_model_name_or_path=args.config_file,
194
+ # num_labels=vocabulary.get_vocab_size(namespace="labels")
195
+ )
196
+ model = SimpleLinearIRMPretrainedModel(
197
+ config=config,
198
+ )
199
+ model.to(device)
200
+ model.train()
201
+
202
+ # optimizer
203
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
204
+ param_optimizer = model.parameters()
205
+ optimizer = torch.optim.Adam(
206
+ param_optimizer,
207
+ lr=args.learning_rate,
208
+ )
209
+ # lr_scheduler = torch.optim.lr_scheduler.StepLR(
210
+ # optimizer,
211
+ # step_size=2000
212
+ # )
213
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
214
+ optimizer,
215
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
216
+ )
217
+ mse_loss = nn.MSELoss(
218
+ reduction="mean",
219
+ )
220
+
221
+ # training loop
222
+ logger.info("training")
223
+
224
+ training_loss = 10000000000
225
+ evaluation_loss = 10000000000
226
+
227
+ model_list = list()
228
+ best_idx_epoch = None
229
+ best_metric = None
230
+ patience_count = 0
231
+
232
+ for idx_epoch in range(args.max_epochs):
233
+ total_loss = 0.
234
+ total_examples = 0.
235
+ progress_bar = tqdm(
236
+ total=len(train_data_loader),
237
+ desc="Training; epoch: {}".format(idx_epoch),
238
+ )
239
+
240
+ for batch in train_data_loader:
241
+ mix_spec, speech_irm, snr_db = batch
242
+ mix_spec = mix_spec.to(device)
243
+ speech_irm_target = speech_irm.to(device)
244
+ snr_db_target = snr_db.to(device)
245
+
246
+ speech_irm_prediction = model.forward(mix_spec)
247
+ loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
248
+
249
+ total_loss += loss.item()
250
+ total_examples += mix_spec.size(0)
251
+
252
+ optimizer.zero_grad()
253
+ loss.backward()
254
+ optimizer.step()
255
+ lr_scheduler.step()
256
+
257
+ training_loss = total_loss / total_examples
258
+ training_loss = round(training_loss, 4)
259
+
260
+ progress_bar.update(1)
261
+ progress_bar.set_postfix({
262
+ "training_loss": training_loss,
263
+ })
264
+
265
+ total_loss = 0.
266
+ total_examples = 0.
267
+ progress_bar = tqdm(
268
+ total=len(valid_data_loader),
269
+ desc="Evaluation; epoch: {}".format(idx_epoch),
270
+ )
271
+ for batch in valid_data_loader:
272
+ mix_spec, speech_irm, snr_db = batch
273
+ mix_spec = mix_spec.to(device)
274
+ speech_irm_target = speech_irm.to(device)
275
+ snr_db_target = snr_db.to(device)
276
+
277
+ with torch.no_grad():
278
+ speech_irm_prediction = model.forward(mix_spec)
279
+ loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
280
+
281
+ total_loss += loss.item()
282
+ total_examples += mix_spec.size(0)
283
+
284
+ evaluation_loss = total_loss / total_examples
285
+ evaluation_loss = round(evaluation_loss, 4)
286
+
287
+ progress_bar.update(1)
288
+ progress_bar.set_postfix({
289
+ "evaluation_loss": evaluation_loss,
290
+ })
291
+
292
+ # save path
293
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
294
+ epoch_dir.mkdir(parents=True, exist_ok=False)
295
+
296
+ # save models
297
+ model.save_pretrained(epoch_dir.as_posix())
298
+
299
+ model_list.append(epoch_dir)
300
+ if len(model_list) >= args.num_serialized_models_to_keep:
301
+ model_to_delete: Path = model_list.pop(0)
302
+ shutil.rmtree(model_to_delete.as_posix())
303
+
304
+ # save metric
305
+ if best_metric is None:
306
+ best_idx_epoch = idx_epoch
307
+ best_metric = evaluation_loss
308
+ elif evaluation_loss < best_metric:
309
+ best_idx_epoch = idx_epoch
310
+ best_metric = evaluation_loss
311
+ else:
312
+ pass
313
+
314
+ metrics = {
315
+ "idx_epoch": idx_epoch,
316
+ "best_idx_epoch": best_idx_epoch,
317
+ "training_loss": training_loss,
318
+ "evaluation_loss": evaluation_loss,
319
+ "learning_rate": optimizer.param_groups[0]["lr"],
320
+ }
321
+ metrics_filename = epoch_dir / "metrics_epoch.json"
322
+ with open(metrics_filename, "w", encoding="utf-8") as f:
323
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
324
+
325
+ # save best
326
+ best_dir = serialization_dir / "best"
327
+ if best_idx_epoch == idx_epoch:
328
+ if best_dir.exists():
329
+ shutil.rmtree(best_dir)
330
+ shutil.copytree(epoch_dir, best_dir)
331
+
332
+ # early stop
333
+ early_stop_flag = False
334
+ if best_idx_epoch == idx_epoch:
335
+ patience_count = 0
336
+ else:
337
+ patience_count += 1
338
+ if patience_count >= args.patience:
339
+ early_stop_flag = True
340
+
341
+ # early stop
342
+ if early_stop_flag:
343
+ break
344
+ return
345
+
346
+
347
+ if __name__ == '__main__':
348
+ main()
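
A minimal stand-alone sketch of the training target built by CollateFunction above, assuming power spectrograms (power=2.0) of speech and noise; it is illustrative only and not part of the commit, and the random tensors stand in for real 8 kHz clips:

```python
import torch
import torchaudio

# Same transform parameters as CollateFunction above.
transform = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    power=2.0,
    window_fn=torch.hamming_window,
)

speech_wave = torch.randn(1, 16000)   # placeholder for a real 2 s speech clip at 8 kHz
noise_wave = torch.randn(1, 16000)    # placeholder for a real 2 s noise clip at 8 kHz

speech_spec = transform(speech_wave)  # power spectrum, shape (1, 257, num_frames)
noise_spec = transform(noise_wave)

epsilon, irm_beta = 1e-8, 1.0
# Ideal ratio mask on power spectra: values in (0, 1), close to 1 where speech dominates.
speech_irm = (speech_spec / (speech_spec + noise_spec + epsilon)) ** irm_beta
```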
examples/simple_linear_irm_aishell/step_3_evaluation.py ADDED
@@ -0,0 +1,239 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+ import uuid
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import pandas as pd
16
+ from scipy.io import wavfile
17
+ import torch
18
+ import torch.nn as nn
19
+ import torchaudio
20
+ from tqdm import tqdm
21
+
22
+ from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel
23
+
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
28
+ parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
29
+ parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
30
+
31
+ parser.add_argument("--limit", default=10, type=int)
32
+
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+
37
+ def logging_config():
38
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
39
+
40
+ logging.basicConfig(format=fmt,
41
+ datefmt="%m/%d/%Y %H:%M:%S",
42
+ level=logging.INFO)
43
+ stream_handler = logging.StreamHandler()
44
+ stream_handler.setLevel(logging.INFO)
45
+ stream_handler.setFormatter(logging.Formatter(fmt))
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ return logger
50
+
51
+
52
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
53
+ l1 = len(speech)
54
+ l2 = len(noise)
55
+ l = min(l1, l2)
56
+ speech = speech[:l]
57
+ noise = noise[:l]
58
+
59
+ # np.float32, value between (-1, 1).
60
+
61
+ speech_power = np.mean(np.square(speech))
62
+ noise_power = speech_power / (10 ** (snr_db / 10))
63
+
64
+ noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
65
+
66
+ noisy_signal = speech + noise_adjusted
67
+
68
+ return noisy_signal
69
+
70
+
71
+ stft_power = torchaudio.transforms.Spectrogram(
72
+ n_fft=512,
73
+ win_length=200,
74
+ hop_length=80,
75
+ power=2.0,
76
+ window_fn=torch.hamming_window,
77
+ )
78
+
79
+
80
+ stft_complex = torchaudio.transforms.Spectrogram(
81
+ n_fft=512,
82
+ win_length=200,
83
+ hop_length=80,
84
+ power=None,
85
+ window_fn=torch.hamming_window,
86
+ )
87
+
88
+
89
+ istft = torchaudio.transforms.InverseSpectrogram(
90
+ n_fft=512,
91
+ win_length=200,
92
+ hop_length=80,
93
+ window_fn=torch.hamming_window,
94
+ )
95
+
96
+
97
+ def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
98
+ mix_spec_complex = mix_spec_complex.detach().cpu()
99
+ speech_irm_prediction = speech_irm_prediction.detach().cpu()
100
+
101
+ mask_speech = speech_irm_prediction
102
+ mask_noise = 1.0 - speech_irm_prediction
103
+
104
+ speech_spec = mix_spec_complex * mask_speech
105
+ noise_spec = mix_spec_complex * mask_noise
106
+
107
+ speech_wave = istft.forward(speech_spec)
108
+ noise_wave = istft.forward(noise_spec)
109
+
110
+ return speech_wave, noise_wave
111
+
112
+
113
+ def save_audios(noise_wave: torch.Tensor,
114
+ speech_wave: torch.Tensor,
115
+ mix_wave: torch.Tensor,
116
+ speech_wave_enhanced: torch.Tensor,
117
+ noise_wave_enhanced: torch.Tensor,
118
+ output_dir: str,
119
+ sample_rate: int = 8000,
120
+ ):
121
+ basename = uuid.uuid4().__str__()
122
+ output_dir = Path(output_dir) / basename
123
+ output_dir.mkdir(parents=True, exist_ok=True)
124
+
125
+ filename = output_dir / "noise_wave.wav"
126
+ torchaudio.save(filename, noise_wave, sample_rate)
127
+ filename = output_dir / "speech_wave.wav"
128
+ torchaudio.save(filename, speech_wave, sample_rate)
129
+ filename = output_dir / "mix_wave.wav"
130
+ torchaudio.save(filename, mix_wave, sample_rate)
131
+
132
+ filename = output_dir / "speech_wave_enhanced.wav"
133
+ torchaudio.save(filename, speech_wave_enhanced, sample_rate)
134
+ filename = output_dir / "noise_wave_enhanced.wav"
135
+ torchaudio.save(filename, noise_wave_enhanced, sample_rate)
136
+
137
+ return output_dir.as_posix()
138
+
139
+
140
+ def main():
141
+ args = get_args()
142
+
143
+ logger = logging_config()
144
+
145
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146
+ n_gpu = torch.cuda.device_count()
147
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
148
+
149
+ logger.info("prepare model")
150
+ model = SimpleLinearIRMPretrainedModel.from_pretrained(
151
+ pretrained_model_name_or_path=args.model_dir,
152
+ )
153
+ model.to(device)
154
+ model.eval()
155
+
156
+ # optimizer
157
+ logger.info("prepare loss_fn")
158
+ mse_loss = nn.MSELoss(
159
+ reduction="mean",
160
+ )
161
+
162
+ logger.info("read excel")
163
+ df = pd.read_excel(args.valid_dataset)
164
+
165
+ total_loss = 0.
166
+ total_examples = 0.
167
+ progress_bar = tqdm(total=len(df), desc="Evaluation")
168
+ for idx, row in df.iterrows():
169
+ noise_filename = row["noise_filename"]
170
+ noise_offset = row["noise_offset"]
171
+ noise_duration = row["noise_duration"]
172
+
173
+ speech_filename = row["speech_filename"]
174
+ speech_offset = row["speech_offset"]
175
+ speech_duration = row["speech_duration"]
176
+
177
+ snr_db = row["snr_db"]
178
+
179
+ noise_wave, _ = librosa.load(
180
+ noise_filename,
181
+ sr=8000,
182
+ offset=noise_offset,
183
+ duration=noise_duration,
184
+ )
185
+ speech_wave, _ = librosa.load(
186
+ speech_filename,
187
+ sr=8000,
188
+ offset=speech_offset,
189
+ duration=speech_duration,
190
+ )
191
+ mix_wave: np.ndarray = mix_speech_and_noise(
192
+ speech=speech_wave,
193
+ noise=noise_wave,
194
+ snr_db=snr_db,
195
+ )
196
+ noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
197
+ speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
198
+ mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
199
+
200
+ noise_wave = noise_wave.unsqueeze(dim=0)
201
+ speech_wave = speech_wave.unsqueeze(dim=0)
202
+ mix_wave = mix_wave.unsqueeze(dim=0)
203
+
204
+ noise_spec: torch.Tensor = stft_power.forward(noise_wave)
205
+ speech_spec: torch.Tensor = stft_power.forward(speech_wave)
206
+ mix_spec: torch.Tensor = stft_power.forward(mix_wave)
207
+ mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
208
+
209
+ speech_irm = speech_spec / (noise_spec + speech_spec)
210
+ speech_irm = torch.pow(speech_irm, 1.0)
211
+
212
+ mix_spec = mix_spec.to(device)
213
+ speech_irm_target = speech_irm.to(device)
214
+ with torch.no_grad():
215
+ speech_irm_prediction = model.forward(mix_spec)
216
+ loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
217
+
218
+ speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
219
+ save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
220
+
221
+ total_loss += loss.item()
222
+ total_examples += mix_spec.size(0)
223
+
224
+ evaluation_loss = total_loss / total_examples
225
+ evaluation_loss = round(evaluation_loss, 4)
226
+
227
+ progress_bar.update(1)
228
+ progress_bar.set_postfix({
229
+ "evaluation_loss": evaluation_loss,
230
+ })
231
+
232
+ if idx > args.limit:
233
+ break
234
+
235
+ return
236
+
237
+
238
+ if __name__ == '__main__':
239
+ main()
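
The mixing rule in mix_speech_and_noise above rescales the noise so that the speech-to-noise power ratio equals 10 ** (snr_db / 10). A small self-contained sketch of that rule (illustrative only, not part of the commit):

```python
import numpy as np

def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
    # Truncate both signals to the common length, then scale the noise so that
    # speech_power / noise_power == 10 ** (snr_db / 10), as in mix_speech_and_noise above.
    length = min(len(speech), len(noise))
    speech, noise = speech[:length], noise[:length]
    speech_power = np.mean(np.square(speech))
    noise_power = speech_power / (10 ** (snr_db / 10))
    noise_scaled = np.sqrt(noise_power) * noise / np.sqrt(np.mean(np.square(noise)))
    return speech + noise_scaled

# Synthetic example: a 440 Hz tone at 8 kHz mixed with white noise at 5 dB SNR.
t = np.arange(16000) / 8000.0
clean = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
noisy = mix_at_snr(clean, np.random.randn(16000).astype(np.float32), snr_db=5.0)
```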
examples/simple_linear_irm_aishell/yaml/config.yaml ADDED
@@ -0,0 +1,13 @@
+ model_name: "simple_linear_irm"
+
+ # spec
+ sample_rate: 8000
+ n_fft: 512
+ win_length: 200
+ hop_length: 80
+
+ # model
+ num_bins: 257
+ hidden_size: 2048
+ lookback: 3
+ lookahead: 3
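
These values are mutually consistent: with n_fft: 512 the spectrogram has 512 / 2 + 1 = 257 frequency bins (num_bins), and at sample_rate: 8000 the 200-sample win_length and 80-sample hop_length correspond to a 25 ms window with a 10 ms hop.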
examples/simple_lstm_irm_aishell/run.sh ADDED
@@ -0,0 +1,172 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir
6
+
7
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir
8
+
9
+ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
10
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
12
+
13
+
14
+ END
15
+
16
+
17
+ # params
18
+ system_version="windows";
19
+ verbose=true;
20
+ stage=0 # start from 0 if you need to start from data preparation
21
+ stop_stage=9
22
+
23
+ work_dir="$(pwd)"
24
+ file_folder_name=file_folder_name
25
+ final_model_name=final_model_name
26
+ config_file="yaml/config.yaml"
27
+ limit=10
28
+
29
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
30
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
31
+
32
+ nohup_name=nohup.out
33
+
34
+ # model params
35
+ batch_size=64
36
+ max_epochs=200
37
+ save_top_k=10
38
+ patience=5
39
+
40
+
41
+ # parse options
42
+ while true; do
43
+ [ -z "${1:-}" ] && break; # break if there are no arguments
44
+ case "$1" in
45
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
46
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
47
+ old_value="(eval echo \\$$name)";
48
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
49
+ was_bool=true;
50
+ else
51
+ was_bool=false;
52
+ fi
53
+
54
+ # Set the variable to the right value-- the escaped quotes make it work if
55
+ # the option had spaces, like --cmd "queue.pl -sync y"
56
+ eval "${name}=\"$2\"";
57
+
58
+ # Check that Boolean-valued arguments are really Boolean.
59
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
60
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
61
+ exit 1;
62
+ fi
63
+ shift 2;
64
+ ;;
65
+
66
+ *) break;
67
+ esac
68
+ done
69
+
70
+ file_dir="${work_dir}/${file_folder_name}"
71
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
72
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
73
+
74
+ dataset="${file_dir}/dataset.xlsx"
75
+ train_dataset="${file_dir}/train.xlsx"
76
+ valid_dataset="${file_dir}/valid.xlsx"
77
+
78
+ $verbose && echo "system_version: ${system_version}"
79
+ $verbose && echo "file_folder_name: ${file_folder_name}"
80
+
81
+ if [ $system_version == "windows" ]; then
82
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
83
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
84
+ #source /data/local/bin/nx_denoise/bin/activate
85
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
86
+ fi
87
+
88
+
89
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
90
+ $verbose && echo "stage 1: prepare data"
91
+ cd "${work_dir}" || exit 1
92
+ python3 step_1_prepare_data.py \
93
+ --file_dir "${file_dir}" \
94
+ --noise_dir "${noise_dir}" \
95
+ --speech_dir "${speech_dir}" \
96
+ --train_dataset "${train_dataset}" \
97
+ --valid_dataset "${valid_dataset}" \
98
+
99
+ fi
100
+
101
+
102
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
103
+ $verbose && echo "stage 2: train model"
104
+ cd "${work_dir}" || exit 1
105
+ python3 step_2_train_model.py \
106
+ --train_dataset "${train_dataset}" \
107
+ --valid_dataset "${valid_dataset}" \
108
+ --serialization_dir "${file_dir}" \
109
+ --config_file "${config_file}" \
110
+
111
+ fi
112
+
113
+
114
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
115
+ $verbose && echo "stage 3: test model"
116
+ cd "${work_dir}" || exit 1
117
+ python3 step_3_evaluation.py \
118
+ --valid_dataset "${valid_dataset}" \
119
+ --model_dir "${file_dir}/best" \
120
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
121
+ --limit "${limit}" \
122
+
123
+ fi
124
+
125
+
126
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
127
+ $verbose && echo "stage 4: export model"
128
+ cd "${work_dir}" || exit 1
129
+ python3 step_5_export_models.py \
130
+ --vocabulary_dir "${vocabulary_dir}" \
131
+ --model_dir "${file_dir}/best" \
132
+ --serialization_dir "${file_dir}" \
133
+
134
+ fi
135
+
136
+
137
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
138
+ $verbose && echo "stage 5: collect files"
139
+ cd "${work_dir}" || exit 1
140
+
141
+ mkdir -p ${final_model_dir}
142
+
143
+ cp "${file_dir}/best"/* "${final_model_dir}"
144
+ cp -r "${file_dir}/vocabulary" "${final_model_dir}"
145
+
146
+ cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
147
+
148
+ cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
149
+ cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
150
+ cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
151
+ cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
152
+
153
+ cd "${final_model_dir}/.." || exit 1;
154
+
155
+ if [ -e "${final_model_name}.zip" ]; then
156
+ rm -rf "${final_model_name}_backup.zip"
157
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
158
+ fi
159
+
160
+ zip -r "${final_model_name}.zip" "${final_model_name}"
161
+ rm -rf "${final_model_name}"
162
+
163
+ fi
164
+
165
+
166
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
167
+ $verbose && echo "stage 6: clear file_dir"
168
+ cd "${work_dir}" || exit 1
169
+
170
+ rm -rf "${file_dir}";
171
+
172
+ fi
examples/simple_lstm_irm_aishell/step_1_prepare_data.py ADDED
@@ -0,0 +1,197 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_nsr_db", default=-20, type=float)
41
+ parser.add_argument("--max_nsr_db", default=5, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def filename_generator(data_dir: str):
50
+ data_dir = Path(data_dir)
51
+ for filename in data_dir.glob("**/*.wav"):
52
+ yield filename.as_posix()
53
+
54
+
55
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
56
+ data_dir = Path(data_dir)
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ row = {
71
+ "filename": filename.as_posix(),
72
+ "raw_duration": round(raw_duration, 4),
73
+ "offset": round(begin / sample_rate, 4),
74
+ "duration": round(duration, 4),
75
+ }
76
+ yield row
77
+
78
+
79
+ def get_dataset(args):
80
+ file_dir = Path(args.file_dir)
81
+ file_dir.mkdir(exist_ok=True)
82
+
83
+ noise_dir = Path(args.noise_dir)
84
+ speech_dir = Path(args.speech_dir)
85
+
86
+ noise_generator = target_second_signal_generator(
87
+ noise_dir.as_posix(),
88
+ duration=args.duration,
89
+ sample_rate=args.target_sample_rate
90
+ )
91
+ speech_generator = target_second_signal_generator(
92
+ speech_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate
95
+ )
96
+
97
+ dataset = list()
98
+
99
+ count = 0
100
+ process_bar = tqdm(desc="build dataset excel")
101
+ for noise, speech in zip(noise_generator, speech_generator):
102
+
103
+ noise_filename = noise["filename"]
104
+ noise_raw_duration = noise["raw_duration"]
105
+ noise_offset = noise["offset"]
106
+ noise_duration = noise["duration"]
107
+
108
+ speech_filename = speech["filename"]
109
+ speech_raw_duration = speech["raw_duration"]
110
+ speech_offset = speech["offset"]
111
+ speech_duration = speech["duration"]
112
+
113
+ random1 = random.random()
114
+ random2 = random.random()
115
+
116
+ row = {
117
+ "noise_filename": noise_filename,
118
+ "noise_raw_duration": noise_raw_duration,
119
+ "noise_offset": noise_offset,
120
+ "noise_duration": noise_duration,
121
+
122
+ "speech_filename": speech_filename,
123
+ "speech_raw_duration": speech_raw_duration,
124
+ "speech_offset": speech_offset,
125
+ "speech_duration": speech_duration,
126
+
127
+ "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db),
128
+
129
+ "random1": random1,
130
+ "random2": random2,
131
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
132
+ }
133
+ dataset.append(row)
134
+ count += 1
135
+ duration_seconds = count * args.duration
136
+ duration_hours = duration_seconds / 3600
137
+
138
+ process_bar.update(n=1)
139
+ process_bar.set_postfix({
140
+ # "duration_seconds": round(duration_seconds, 4),
141
+ "duration_hours": round(duration_hours, 4),
142
+
143
+ })
144
+
145
+ dataset = pd.DataFrame(dataset)
146
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
147
+ dataset.to_excel(
148
+ file_dir / "dataset.xlsx",
149
+ index=False,
150
+ )
151
+ return
152
+
153
+
154
+
155
+ def split_dataset(args):
156
+ """分割训练集, 测试集"""
157
+ file_dir = Path(args.file_dir)
158
+ file_dir.mkdir(exist_ok=True)
159
+
160
+ df = pd.read_excel(file_dir / "dataset.xlsx")
161
+
162
+ train = list()
163
+ test = list()
164
+
165
+ for i, row in df.iterrows():
166
+ flag = row["flag"]
167
+ if flag == "TRAIN":
168
+ train.append(row)
169
+ else:
170
+ test.append(row)
171
+
172
+ train = pd.DataFrame(train)
173
+ train.to_excel(
174
+ args.train_dataset,
175
+ index=False,
176
+ # encoding="utf_8_sig"
177
+ )
178
+ test = pd.DataFrame(test)
179
+ test.to_excel(
180
+ args.valid_dataset,
181
+ index=False,
182
+ # encoding="utf_8_sig"
183
+ )
184
+
185
+ return
186
+
187
+
188
+ def main():
189
+ args = get_args()
190
+
191
+ get_dataset(args)
192
+ split_dataset(args)
193
+ return
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
examples/simple_lstm_irm_aishell/step_2_train_model.py ADDED
@@ -0,0 +1,348 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ from torch import dtype
19
+
20
+ pwd = os.path.abspath(os.path.dirname(__file__))
21
+ sys.path.append(os.path.join(pwd, "../../"))
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.data.dataloader import DataLoader
27
+ import torchaudio
28
+ from tqdm import tqdm
29
+
30
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
31
+ from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig
32
+ from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel
33
+
34
+
35
+ def get_args():
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
38
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
39
+
40
+ parser.add_argument("--max_epochs", default=100, type=int)
41
+
42
+ parser.add_argument("--batch_size", default=64, type=int)
43
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
44
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
45
+ parser.add_argument("--patience", default=5, type=int)
46
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
47
+ parser.add_argument("--seed", default=0, type=int)
48
+
49
+ parser.add_argument("--config_file", default="config.yaml", type=str)
50
+
51
+ args = parser.parse_args()
52
+ return args
53
+
54
+
55
+ def logging_config(file_dir: str):
56
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
57
+
58
+ logging.basicConfig(format=fmt,
59
+ datefmt="%m/%d/%Y %H:%M:%S",
60
+ level=logging.INFO)
61
+ file_handler = TimedRotatingFileHandler(
62
+ filename=os.path.join(file_dir, "main.log"),
63
+ encoding="utf-8",
64
+ when="D",
65
+ interval=1,
66
+ backupCount=7
67
+ )
68
+ file_handler.setLevel(logging.INFO)
69
+ file_handler.setFormatter(logging.Formatter(fmt))
70
+ logger = logging.getLogger(__name__)
71
+ logger.addHandler(file_handler)
72
+
73
+ return logger
74
+
75
+
76
+ class CollateFunction(object):
77
+ def __init__(self,
78
+ n_fft: int = 512,
79
+ win_length: int = 200,
80
+ hop_length: int = 80,
81
+ window_fn: str = "hamming",
82
+ irm_beta: float = 1.0,
83
+ epsilon: float = 1e-8,
84
+ ):
85
+ self.n_fft = n_fft
86
+ self.win_length = win_length
87
+ self.hop_length = hop_length
88
+ self.window_fn = window_fn
89
+ self.irm_beta = irm_beta
90
+ self.epsilon = epsilon
91
+
92
+ self.transform = torchaudio.transforms.Spectrogram(
93
+ n_fft=self.n_fft,
94
+ win_length=self.win_length,
95
+ hop_length=self.hop_length,
96
+ power=2.0,
97
+ window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
98
+ )
99
+
100
+ def __call__(self, batch: List[dict]):
101
+ mix_spec_list = list()
102
+ speech_irm_list = list()
103
+ snr_db_list = list()
104
+ for sample in batch:
105
+ noise_wave: torch.Tensor = sample["noise_wave"]
106
+ speech_wave: torch.Tensor = sample["speech_wave"]
107
+ mix_wave: torch.Tensor = sample["mix_wave"]
108
+ snr_db: float = sample["snr_db"]
109
+
110
+ noise_spec = self.transform.forward(noise_wave)
111
+ speech_spec = self.transform.forward(speech_wave)
112
+ mix_spec = self.transform.forward(mix_wave)
113
+
114
+ # noise_irm = noise_spec / (noise_spec + speech_spec)
115
+ speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
116
+ speech_irm = torch.pow(speech_irm, self.irm_beta)
117
+
118
+ mix_spec_list.append(mix_spec)
119
+ speech_irm_list.append(speech_irm)
120
+ snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32))
121
+
122
+ mix_spec_list = torch.stack(mix_spec_list)
123
+ speech_irm_list = torch.stack(speech_irm_list)
124
+ snr_db_list = torch.stack(snr_db_list) # shape: (batch_size,)
125
+
126
+ # assert
127
+ if torch.any(torch.isnan(mix_spec_list)):
128
+ raise AssertionError("nan in mix_spec Tensor")
129
+ if torch.any(torch.isnan(speech_irm_list)):
130
+ raise AssertionError("nan in speech_irm Tensor")
131
+ if torch.any(torch.isnan(snr_db_list)):
132
+ raise AssertionError("nan in snr_db Tensor")
133
+
134
+ return mix_spec_list, speech_irm_list, snr_db_list
135
+
136
+
137
+ collate_fn = CollateFunction()
138
+
139
+
140
+ def main():
141
+ args = get_args()
142
+
143
+ serialization_dir = Path(args.serialization_dir)
144
+ serialization_dir.mkdir(parents=True, exist_ok=True)
145
+
146
+ logger = logging_config(serialization_dir)
147
+
148
+ random.seed(args.seed)
149
+ np.random.seed(args.seed)
150
+ torch.manual_seed(args.seed)
151
+ logger.info("set seed: {}".format(args.seed))
152
+
153
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154
+ n_gpu = torch.cuda.device_count()
155
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
156
+
157
+ # datasets
158
+ logger.info("prepare datasets")
159
+ train_dataset = DenoiseExcelDataset(
160
+ excel_file=args.train_dataset,
161
+ expected_sample_rate=8000,
162
+ max_wave_value=32768.0,
163
+ )
164
+ valid_dataset = DenoiseExcelDataset(
165
+ excel_file=args.valid_dataset,
166
+ expected_sample_rate=8000,
167
+ max_wave_value=32768.0,
168
+ )
169
+ train_data_loader = DataLoader(
170
+ dataset=train_dataset,
171
+ batch_size=args.batch_size,
172
+ shuffle=True,
173
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
174
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
175
+ collate_fn=collate_fn,
176
+ pin_memory=False,
177
+ # prefetch_factor=64,
178
+ )
179
+ valid_data_loader = DataLoader(
180
+ dataset=valid_dataset,
181
+ batch_size=args.batch_size,
182
+ shuffle=True,
183
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
184
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
185
+ collate_fn=collate_fn,
186
+ pin_memory=False,
187
+ # prefetch_factor=64,
188
+ )
189
+
190
+ # models
191
+ logger.info(f"prepare models. config_file: {args.config_file}")
192
+ config = SimpleLstmIRMConfig.from_pretrained(
193
+ pretrained_model_name_or_path=args.config_file,
194
+ # num_labels=vocabulary.get_vocab_size(namespace="labels")
195
+ )
196
+ model = SimpleLstmIRMPretrainedModel(
197
+ config=config,
198
+ )
199
+ model.to(device)
200
+ model.train()
201
+
202
+ # optimizer
203
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
204
+ param_optimizer = model.parameters()
205
+ optimizer = torch.optim.Adam(
206
+ param_optimizer,
207
+ lr=args.learning_rate,
208
+ )
209
+ # lr_scheduler = torch.optim.lr_scheduler.StepLR(
210
+ # optimizer,
211
+ # step_size=2000
212
+ # )
213
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
214
+ optimizer,
215
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
216
+ )
217
+ mse_loss = nn.MSELoss(
218
+ reduction="mean",
219
+ )
220
+
221
+ # training loop
222
+ logger.info("training")
223
+
224
+ training_loss = 10000000000
225
+ evaluation_loss = 10000000000
226
+
227
+ model_list = list()
228
+ best_idx_epoch = None
229
+ best_metric = None
230
+ patience_count = 0
231
+
232
+ for idx_epoch in range(args.max_epochs):
233
+ total_loss = 0.
234
+ total_examples = 0.
235
+ progress_bar = tqdm(
236
+ total=len(train_data_loader),
237
+ desc="Training; epoch: {}".format(idx_epoch),
238
+ )
239
+
240
+ for batch in train_data_loader:
241
+ mix_spec, speech_irm, snr_db = batch
242
+ mix_spec = mix_spec.to(device)
243
+ speech_irm_target = speech_irm.to(device)
244
+ snr_db_target = snr_db.to(device)
245
+
246
+ speech_irm_prediction = model.forward(mix_spec)
247
+ loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
248
+
249
+ total_loss += loss.item()
250
+ total_examples += mix_spec.size(0)
251
+
252
+ optimizer.zero_grad()
253
+ loss.backward()
254
+ optimizer.step()
255
+ lr_scheduler.step()
256
+
257
+ training_loss = total_loss / total_examples
258
+ training_loss = round(training_loss, 4)
259
+
260
+ progress_bar.update(1)
261
+ progress_bar.set_postfix({
262
+ "training_loss": training_loss,
263
+ })
264
+
265
+ total_loss = 0.
266
+ total_examples = 0.
267
+ progress_bar = tqdm(
268
+ total=len(valid_data_loader),
269
+ desc="Evaluation; epoch: {}".format(idx_epoch),
270
+ )
271
+ for batch in valid_data_loader:
272
+ mix_spec, speech_irm, snr_db = batch
273
+ mix_spec = mix_spec.to(device)
274
+ speech_irm_target = speech_irm.to(device)
275
+ snr_db_target = snr_db.to(device)
276
+
277
+ with torch.no_grad():
278
+ speech_irm_prediction = model.forward(mix_spec)
279
+ loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
280
+
281
+ total_loss += loss.item()
282
+ total_examples += mix_spec.size(0)
283
+
284
+ evaluation_loss = total_loss / total_examples
285
+ evaluation_loss = round(evaluation_loss, 4)
286
+
287
+ progress_bar.update(1)
288
+ progress_bar.set_postfix({
289
+ "evaluation_loss": evaluation_loss,
290
+ })
291
+
292
+ # save path
293
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
294
+ epoch_dir.mkdir(parents=True, exist_ok=False)
295
+
296
+ # save models
297
+ model.save_pretrained(epoch_dir.as_posix())
298
+
299
+ model_list.append(epoch_dir)
300
+ if len(model_list) >= args.num_serialized_models_to_keep:
301
+ model_to_delete: Path = model_list.pop(0)
302
+ shutil.rmtree(model_to_delete.as_posix())
303
+
304
+ # save metric
305
+ if best_metric is None:
306
+ best_idx_epoch = idx_epoch
307
+ best_metric = evaluation_loss
308
+ elif evaluation_loss < best_metric:
309
+ best_idx_epoch = idx_epoch
310
+ best_metric = evaluation_loss
311
+ else:
312
+ pass
313
+
314
+ metrics = {
315
+ "idx_epoch": idx_epoch,
316
+ "best_idx_epoch": best_idx_epoch,
317
+ "training_loss": training_loss,
318
+ "evaluation_loss": evaluation_loss,
319
+ "learning_rate": optimizer.param_groups[0]["lr"],
320
+ }
321
+ metrics_filename = epoch_dir / "metrics_epoch.json"
322
+ with open(metrics_filename, "w", encoding="utf-8") as f:
323
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
324
+
325
+ # save best
326
+ best_dir = serialization_dir / "best"
327
+ if best_idx_epoch == idx_epoch:
328
+ if best_dir.exists():
329
+ shutil.rmtree(best_dir)
330
+ shutil.copytree(epoch_dir, best_dir)
331
+
332
+ # early stop
333
+ early_stop_flag = False
334
+ if best_idx_epoch == idx_epoch:
335
+ patience_count = 0
336
+ else:
337
+ patience_count += 1
338
+ if patience_count >= args.patience:
339
+ early_stop_flag = True
340
+
341
+ # early stop
342
+ if early_stop_flag:
343
+ break
344
+ return
345
+
346
+
347
+ if __name__ == '__main__':
348
+ main()
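
For reference, the collate function in this script builds its training target as an ideal ratio mask (IRM): the fraction of each time-frequency bin's power that belongs to the clean speech. A minimal standalone sketch of that computation, assuming the same 8 kHz audio and STFT settings as the script (random placeholder waveforms stand in for real data):

import torch
import torchaudio

transform = torchaudio.transforms.Spectrogram(
    n_fft=512, win_length=200, hop_length=80,
    power=2.0, window_fn=torch.hamming_window,
)

speech_wave = torch.randn(1, 16000)   # placeholder: 2 s of 8 kHz speech
noise_wave = torch.randn(1, 16000)    # placeholder: 2 s of noise
mix_wave = speech_wave + noise_wave

speech_spec = transform(speech_wave)  # power spectrogram, shape [1, 257, time_steps]
noise_spec = transform(noise_wave)
mix_spec = transform(mix_wave)

# IRM target: per-bin share of the total power that belongs to speech, bounded in [0, 1].
speech_irm = speech_spec / (speech_spec + noise_spec + 1e-8)

The training loop above then fits the model with an MSE loss between its prediction on mix_spec and this speech_irm target.
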
examples/simple_lstm_irm_aishell/step_3_evaluation.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+ import uuid
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import pandas as pd
16
+ from scipy.io import wavfile
17
+ import torch
18
+ import torch.nn as nn
19
+ import torchaudio
20
+ from tqdm import tqdm
21
+
22
+ from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel
23
+
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
28
+ parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
29
+ parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
30
+
31
+ parser.add_argument("--limit", default=10, type=int)
32
+
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+
37
+ def logging_config():
38
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
39
+
40
+ logging.basicConfig(format=fmt,
41
+ datefmt="%m/%d/%Y %H:%M:%S",
42
+ level=logging.INFO)
43
+ stream_handler = logging.StreamHandler()
44
+ stream_handler.setLevel(logging.INFO)
45
+ stream_handler.setFormatter(logging.Formatter(fmt))
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ return logger
50
+
51
+
52
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
53
+ l1 = len(speech)
54
+ l2 = len(noise)
55
+ l = min(l1, l2)
56
+ speech = speech[:l]
57
+ noise = noise[:l]
58
+
59
+ # np.float32, value between (-1, 1).
60
+
61
+ speech_power = np.mean(np.square(speech))
62
+ noise_power = speech_power / (10 ** (snr_db / 10))
63
+
64
+ noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
65
+
66
+ noisy_signal = speech + noise_adjusted
67
+
68
+ return noisy_signal
69
+
70
+
71
+ stft_power = torchaudio.transforms.Spectrogram(
72
+ n_fft=512,
73
+ win_length=200,
74
+ hop_length=80,
75
+ power=2.0,
76
+ window_fn=torch.hamming_window,
77
+ )
78
+
79
+
80
+ stft_complex = torchaudio.transforms.Spectrogram(
81
+ n_fft=512,
82
+ win_length=200,
83
+ hop_length=80,
84
+ power=None,
85
+ window_fn=torch.hamming_window,
86
+ )
87
+
88
+
89
+ istft = torchaudio.transforms.InverseSpectrogram(
90
+ n_fft=512,
91
+ win_length=200,
92
+ hop_length=80,
93
+ window_fn=torch.hamming_window,
94
+ )
95
+
96
+
97
+ def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
98
+ mix_spec_complex = mix_spec_complex.detach().cpu()
99
+ speech_irm_prediction = speech_irm_prediction.detach().cpu()
100
+
101
+ mask_speech = speech_irm_prediction
102
+ mask_noise = 1.0 - speech_irm_prediction
103
+
104
+ speech_spec = mix_spec_complex * mask_speech
105
+ noise_spec = mix_spec_complex * mask_noise
106
+
107
+ speech_wave = istft.forward(speech_spec)
108
+ noise_wave = istft.forward(noise_spec)
109
+
110
+ return speech_wave, noise_wave
111
+
112
+
113
+ def save_audios(noise_wave: torch.Tensor,
114
+ speech_wave: torch.Tensor,
115
+ mix_wave: torch.Tensor,
116
+ speech_wave_enhanced: torch.Tensor,
117
+ noise_wave_enhanced: torch.Tensor,
118
+ output_dir: str,
119
+ sample_rate: int = 8000,
120
+ ):
121
+ basename = uuid.uuid4().__str__()
122
+ output_dir = Path(output_dir) / basename
123
+ output_dir.mkdir(parents=True, exist_ok=True)
124
+
125
+ filename = output_dir / "noise_wave.wav"
126
+ torchaudio.save(filename, noise_wave, sample_rate)
127
+ filename = output_dir / "speech_wave.wav"
128
+ torchaudio.save(filename, speech_wave, sample_rate)
129
+ filename = output_dir / "mix_wave.wav"
130
+ torchaudio.save(filename, mix_wave, sample_rate)
131
+
132
+ filename = output_dir / "speech_wave_enhanced.wav"
133
+ torchaudio.save(filename, speech_wave_enhanced, sample_rate)
134
+ filename = output_dir / "noise_wave_enhanced.wav"
135
+ torchaudio.save(filename, noise_wave_enhanced, sample_rate)
136
+
137
+ return output_dir.as_posix()
138
+
139
+
140
+ def main():
141
+ args = get_args()
142
+
143
+ logger = logging_config()
144
+
145
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146
+ n_gpu = torch.cuda.device_count()
147
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
148
+
149
+ logger.info("prepare model")
150
+ model = SimpleLstmIRMPretrainedModel.from_pretrained(
151
+ pretrained_model_name_or_path=args.model_dir,
152
+ )
153
+ model.to(device)
154
+ model.eval()
155
+
156
+ # optimizer
157
+ logger.info("prepare loss_fn")
158
+ mse_loss = nn.MSELoss(
159
+ reduction="mean",
160
+ )
161
+
162
+ logger.info("read excel")
163
+ df = pd.read_excel(args.valid_dataset)
164
+
165
+ total_loss = 0.
166
+ total_examples = 0.
167
+ progress_bar = tqdm(total=len(df), desc="Evaluation")
168
+ for idx, row in df.iterrows():
169
+ noise_filename = row["noise_filename"]
170
+ noise_offset = row["noise_offset"]
171
+ noise_duration = row["noise_duration"]
172
+
173
+ speech_filename = row["speech_filename"]
174
+ speech_offset = row["speech_offset"]
175
+ speech_duration = row["speech_duration"]
176
+
177
+ snr_db = row["snr_db"]
178
+
179
+ noise_wave, _ = librosa.load(
180
+ noise_filename,
181
+ sr=8000,
182
+ offset=noise_offset,
183
+ duration=noise_duration,
184
+ )
185
+ speech_wave, _ = librosa.load(
186
+ speech_filename,
187
+ sr=8000,
188
+ offset=speech_offset,
189
+ duration=speech_duration,
190
+ )
191
+ mix_wave: np.ndarray = mix_speech_and_noise(
192
+ speech=speech_wave,
193
+ noise=noise_wave,
194
+ snr_db=snr_db,
195
+ )
196
+ noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
197
+ speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
198
+ mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
199
+
200
+ noise_wave = noise_wave.unsqueeze(dim=0)
201
+ speech_wave = speech_wave.unsqueeze(dim=0)
202
+ mix_wave = mix_wave.unsqueeze(dim=0)
203
+
204
+ noise_spec: torch.Tensor = stft_power.forward(noise_wave)
205
+ speech_spec: torch.Tensor = stft_power.forward(speech_wave)
206
+ mix_spec: torch.Tensor = stft_power.forward(mix_wave)
207
+ mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
208
+
209
+ speech_irm = speech_spec / (noise_spec + speech_spec + 1e-8)
210
+ speech_irm = torch.pow(speech_irm, 1.0)
211
+
212
+ mix_spec = mix_spec.to(device)
213
+ speech_irm_target = speech_irm.to(device)
214
+ with torch.no_grad():
215
+ speech_irm_prediction = model.forward(mix_spec)
216
+ loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
217
+
218
+ speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
219
+ save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
220
+
221
+ total_loss += loss.item()
222
+ total_examples += mix_spec.size(0)
223
+
224
+ evaluation_loss = total_loss / total_examples
225
+ evaluation_loss = round(evaluation_loss, 4)
226
+
227
+ progress_bar.update(1)
228
+ progress_bar.set_postfix({
229
+ "evaluation_loss": evaluation_loss,
230
+ })
231
+
232
+ if idx > args.limit:
233
+ break
234
+
235
+ return
236
+
237
+
238
+ if __name__ == '__main__':
239
+ main()
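
The mix_speech_and_noise helper in this script rescales the noise so that the speech-to-noise power ratio matches the requested snr_db, i.e. noise_power = speech_power / 10^(snr_db / 10). A small numeric check of that scaling on placeholder signals (hypothetical values, not taken from the dataset):

import numpy as np

speech = 0.1 * np.random.randn(16000).astype(np.float32)   # placeholder speech
noise = np.random.randn(16000).astype(np.float32)          # placeholder noise
snr_db = 10.0

speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (snr_db / 10))          # at 10 dB: one tenth of the speech power
noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

achieved_snr_db = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
print(round(float(achieved_snr_db), 2))                     # prints 10.0
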
examples/spectrum_unet_irm_aishell/run.sh ADDED
@@ -0,0 +1,174 @@
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
7
+ --noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
8
+ --speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
9
+
10
+
11
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
12
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
+
15
+
16
+ END
17
+
18
+
19
+ # params
20
+ system_version="windows";
21
+ verbose=true;
22
+ stage=0 # start from 0 if you need to start from data preparation
23
+ stop_stage=9
24
+
25
+ work_dir="$(pwd)"
26
+ file_folder_name=file_folder_name
27
+ final_model_name=final_model_name
28
+ config_file="yaml/config.yaml"
29
+ limit=10
30
+
31
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
32
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
33
+
34
+ nohup_name=nohup.out
35
+
36
+ # model params
37
+ batch_size=64
38
+ max_epochs=200
39
+ save_top_k=10
40
+ patience=5
41
+
42
+
43
+ # parse options
44
+ while true; do
45
+ [ -z "${1:-}" ] && break; # break if there are no arguments
46
+ case "$1" in
47
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
48
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
49
+ old_value="$(eval echo \$$name)";
50
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
51
+ was_bool=true;
52
+ else
53
+ was_bool=false;
54
+ fi
55
+
56
+ # Set the variable to the right value-- the escaped quotes make it work if
57
+ # the option had spaces, like --cmd "queue.pl -sync y"
58
+ eval "${name}=\"$2\"";
59
+
60
+ # Check that Boolean-valued arguments are really Boolean.
61
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
62
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
63
+ exit 1;
64
+ fi
65
+ shift 2;
66
+ ;;
67
+
68
+ *) break;
69
+ esac
70
+ done
71
+
72
+ file_dir="${work_dir}/${file_folder_name}"
73
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
74
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
75
+
76
+ dataset="${file_dir}/dataset.xlsx"
77
+ train_dataset="${file_dir}/train.xlsx"
78
+ valid_dataset="${file_dir}/valid.xlsx"
79
+
80
+ $verbose && echo "system_version: ${system_version}"
81
+ $verbose && echo "file_folder_name: ${file_folder_name}"
82
+
83
+ if [ $system_version == "windows" ]; then
84
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
85
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
86
+ #source /data/local/bin/nx_denoise/bin/activate
87
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
88
+ fi
89
+
90
+
91
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
92
+ $verbose && echo "stage 1: prepare data"
93
+ cd "${work_dir}" || exit 1
94
+ python3 step_1_prepare_data.py \
95
+ --file_dir "${file_dir}" \
96
+ --noise_dir "${noise_dir}" \
97
+ --speech_dir "${speech_dir}" \
98
+ --train_dataset "${train_dataset}" \
99
+ --valid_dataset "${valid_dataset}" \
100
+
101
+ fi
102
+
103
+
104
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
105
+ $verbose && echo "stage 2: train model"
106
+ cd "${work_dir}" || exit 1
107
+ python3 step_2_train_model.py \
108
+ --train_dataset "${train_dataset}" \
109
+ --valid_dataset "${valid_dataset}" \
110
+ --serialization_dir "${file_dir}" \
111
+ --config_file "${config_file}" \
112
+
113
+ fi
114
+
115
+
116
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
117
+ $verbose && echo "stage 3: test model"
118
+ cd "${work_dir}" || exit 1
119
+ python3 step_3_evaluation.py \
120
+ --valid_dataset "${valid_dataset}" \
121
+ --model_dir "${file_dir}/best" \
122
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
123
+ --limit "${limit}" \
124
+
125
+ fi
126
+
127
+
128
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
129
+ $verbose && echo "stage 4: export model"
130
+ cd "${work_dir}" || exit 1
131
+ python3 step_5_export_models.py \
132
+ --vocabulary_dir "${vocabulary_dir}" \
133
+ --model_dir "${file_dir}/best" \
134
+ --serialization_dir "${file_dir}" \
135
+
136
+ fi
137
+
138
+
139
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
140
+ $verbose && echo "stage 5: collect files"
141
+ cd "${work_dir}" || exit 1
142
+
143
+ mkdir -p ${final_model_dir}
144
+
145
+ cp "${file_dir}/best"/* "${final_model_dir}"
146
+ cp -r "${file_dir}/vocabulary" "${final_model_dir}"
147
+
148
+ cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
149
+
150
+ cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
151
+ cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
152
+ cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
153
+ cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
154
+
155
+ cd "${final_model_dir}/.." || exit 1;
156
+
157
+ if [ -e "${final_model_name}.zip" ]; then
158
+ rm -rf "${final_model_name}_backup.zip"
159
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
160
+ fi
161
+
162
+ zip -r "${final_model_name}.zip" "${final_model_name}"
163
+ rm -rf "${final_model_name}"
164
+
165
+ fi
166
+
167
+
168
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
169
+ $verbose && echo "stage 6: clear file_dir"
170
+ cd "${work_dir}" || exit 1
171
+
172
+ rm -rf "${file_dir}";
173
+
174
+ fi
examples/spectrum_unet_irm_aishell/step_1_prepare_data.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_nsr_db", default=-20, type=float)
41
+ parser.add_argument("--max_nsr_db", default=5, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def filename_generator(data_dir: str):
50
+ data_dir = Path(data_dir)
51
+ for filename in data_dir.glob("**/*.wav"):
52
+ yield filename.as_posix()
53
+
54
+
55
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
56
+ data_dir = Path(data_dir)
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ row = {
71
+ "filename": filename.as_posix(),
72
+ "raw_duration": round(raw_duration, 4),
73
+ "offset": round(begin / sample_rate, 4),
74
+ "duration": round(duration, 4),
75
+ }
76
+ yield row
77
+
78
+
79
+ def get_dataset(args):
80
+ file_dir = Path(args.file_dir)
81
+ file_dir.mkdir(exist_ok=True)
82
+
83
+ noise_dir = Path(args.noise_dir)
84
+ speech_dir = Path(args.speech_dir)
85
+
86
+ noise_generator = target_second_signal_generator(
87
+ noise_dir.as_posix(),
88
+ duration=args.duration,
89
+ sample_rate=args.target_sample_rate
90
+ )
91
+ speech_generator = target_second_signal_generator(
92
+ speech_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate
95
+ )
96
+
97
+ dataset = list()
98
+
99
+ count = 0
100
+ process_bar = tqdm(desc="build dataset excel")
101
+ for noise, speech in zip(noise_generator, speech_generator):
102
+
103
+ noise_filename = noise["filename"]
104
+ noise_raw_duration = noise["raw_duration"]
105
+ noise_offset = noise["offset"]
106
+ noise_duration = noise["duration"]
107
+
108
+ speech_filename = speech["filename"]
109
+ speech_raw_duration = speech["raw_duration"]
110
+ speech_offset = speech["offset"]
111
+ speech_duration = speech["duration"]
112
+
113
+ random1 = random.random()
114
+ random2 = random.random()
115
+
116
+ row = {
117
+ "noise_filename": noise_filename,
118
+ "noise_raw_duration": noise_raw_duration,
119
+ "noise_offset": noise_offset,
120
+ "noise_duration": noise_duration,
121
+
122
+ "speech_filename": speech_filename,
123
+ "speech_raw_duration": speech_raw_duration,
124
+ "speech_offset": speech_offset,
125
+ "speech_duration": speech_duration,
126
+
127
+ "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db),
128
+
129
+ "random1": random1,
130
+ "random2": random2,
131
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
132
+ }
133
+ dataset.append(row)
134
+ count += 1
135
+ duration_seconds = count * args.duration
136
+ duration_hours = duration_seconds / 3600
137
+
138
+ process_bar.update(n=1)
139
+ process_bar.set_postfix({
140
+ # "duration_seconds": round(duration_seconds, 4),
141
+ "duration_hours": round(duration_hours, 4),
142
+
143
+ })
144
+
145
+ dataset = pd.DataFrame(dataset)
146
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
147
+ dataset.to_excel(
148
+ file_dir / "dataset.xlsx",
149
+ index=False,
150
+ )
151
+ return
152
+
153
+
154
+
155
+ def split_dataset(args):
156
+ """分割训练集, 测试集"""
157
+ file_dir = Path(args.file_dir)
158
+ file_dir.mkdir(exist_ok=True)
159
+
160
+ df = pd.read_excel(file_dir / "dataset.xlsx")
161
+
162
+ train = list()
163
+ test = list()
164
+
165
+ for i, row in df.iterrows():
166
+ flag = row["flag"]
167
+ if flag == "TRAIN":
168
+ train.append(row)
169
+ else:
170
+ test.append(row)
171
+
172
+ train = pd.DataFrame(train)
173
+ train.to_excel(
174
+ args.train_dataset,
175
+ index=False,
176
+ # encoding="utf_8_sig"
177
+ )
178
+ test = pd.DataFrame(test)
179
+ test.to_excel(
180
+ args.valid_dataset,
181
+ index=False,
182
+ # encoding="utf_8_sig"
183
+ )
184
+
185
+ return
186
+
187
+
188
+ def main():
189
+ args = get_args()
190
+
191
+ get_dataset(args)
192
+ split_dataset(args)
193
+ return
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
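
Each row written to dataset.xlsx records only a file path plus an offset/duration window; the audio itself is cut out again at load time. A sketch of how one row maps back onto a two-second segment, assuming the 8 kHz target sample rate used above (the path below is hypothetical):

import librosa

row = {
    "speech_filename": "path/to/example.wav",  # hypothetical path
    "speech_offset": 2.0,                      # seconds into the source file
    "speech_duration": 2.0,                    # window length in seconds
}

segment, _ = librosa.load(
    row["speech_filename"],
    sr=8000,
    offset=row["speech_offset"],
    duration=row["speech_duration"],
)
# segment: float32 array of 16000 samples (2 s at 8 kHz), i.e. one training window.
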
examples/spectrum_unet_irm_aishell/step_2_train_model.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.utils.data.dataloader import DataLoader
25
+ import torchaudio
26
+ from tqdm import tqdm
27
+
28
+ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
29
+ from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig
30
+ from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
36
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
37
+
38
+ parser.add_argument("--max_epochs", default=100, type=int)
39
+
40
+ parser.add_argument("--batch_size", default=64, type=int)
41
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
42
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
43
+ parser.add_argument("--patience", default=5, type=int)
44
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
45
+ parser.add_argument("--seed", default=0, type=int)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self,
76
+ n_fft: int = 512,
77
+ win_length: int = 200,
78
+ hop_length: int = 80,
79
+ window_fn: str = "hamming",
80
+ irm_beta: float = 1.0,
81
+ epsilon: float = 1e-8,
82
+ ):
83
+ self.n_fft = n_fft
84
+ self.win_length = win_length
85
+ self.hop_length = hop_length
86
+ self.window_fn = window_fn
87
+ self.irm_beta = irm_beta
88
+ self.epsilon = epsilon
89
+
90
+ self.transform = torchaudio.transforms.Spectrogram(
91
+ n_fft=self.n_fft,
92
+ win_length=self.win_length,
93
+ hop_length=self.hop_length,
94
+ power=2.0,
95
+ window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
96
+ )
97
+
98
+ def __call__(self, batch: List[dict]):
99
+ mix_spec_list = list()
100
+ speech_irm_list = list()
101
+ snr_db_list = list()
102
+ for sample in batch:
103
+ noise_wave: torch.Tensor = sample["noise_wave"]
104
+ speech_wave: torch.Tensor = sample["speech_wave"]
105
+ mix_wave: torch.Tensor = sample["mix_wave"]
106
+ # snr_db: float = sample["snr_db"]
107
+
108
+ noise_spec = self.transform.forward(noise_wave)
109
+ speech_spec = self.transform.forward(speech_wave)
110
+ mix_spec = self.transform.forward(mix_wave)
111
+
112
+ # noise_irm = noise_spec / (noise_spec + speech_spec)
113
+ speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
114
+ speech_irm = torch.pow(speech_irm, self.irm_beta)
115
+
116
+ # noise_spec, speech_spec, mix_spec, speech_irm
117
+ # shape: [freq_dim, time_steps]
118
+
119
+ snr_db: torch.Tensor = 10 * torch.log10(
120
+ speech_spec / (noise_spec + self.epsilon)
121
+ )
122
+ snr_db = torch.mean(snr_db, dim=0, keepdim=True)
123
+ # snr_db shape: [1, time_steps]
124
+
125
+ mix_spec_list.append(mix_spec)
126
+ speech_irm_list.append(speech_irm)
127
+ snr_db_list.append(snr_db)
128
+
129
+ mix_spec_list = torch.stack(mix_spec_list)
130
+ speech_irm_list = torch.stack(speech_irm_list)
131
+ snr_db_list = torch.stack(snr_db_list)  # shape: [batch_size, 1, time_steps]
132
+
133
+ mix_spec_list = mix_spec_list[:, :-1, :]
134
+ speech_irm_list = speech_irm_list[:, :-1, :]
135
+
136
+ # mix_spec_list shape: [batch_size, freq_dim, time_steps]
137
+ # speech_irm_list shape: [batch_size, freq_dim, time_steps]
138
+ # snr_db shape: [batch_size, 1, time_steps]
139
+
140
+ # assert
141
+ if torch.any(torch.isnan(mix_spec_list)):
142
+ raise AssertionError("nan in mix_spec Tensor")
143
+ if torch.any(torch.isnan(speech_irm_list)):
144
+ raise AssertionError("nan in speech_irm Tensor")
145
+ if torch.any(torch.isnan(snr_db_list)):
146
+ raise AssertionError("nan in snr_db Tensor")
147
+
148
+ return mix_spec_list, speech_irm_list, snr_db_list
149
+
150
+
151
+ collate_fn = CollateFunction()
152
+
153
+
154
+ def main():
155
+ args = get_args()
156
+
157
+ serialization_dir = Path(args.serialization_dir)
158
+ serialization_dir.mkdir(parents=True, exist_ok=True)
159
+
160
+ logger = logging_config(serialization_dir)
161
+
162
+ random.seed(args.seed)
163
+ np.random.seed(args.seed)
164
+ torch.manual_seed(args.seed)
165
+ logger.info("set seed: {}".format(args.seed))
166
+
167
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
168
+ n_gpu = torch.cuda.device_count()
169
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
170
+
171
+ # datasets
172
+ logger.info("prepare datasets")
173
+ train_dataset = DenoiseExcelDataset(
174
+ excel_file=args.train_dataset,
175
+ expected_sample_rate=8000,
176
+ max_wave_value=32768.0,
177
+ )
178
+ valid_dataset = DenoiseExcelDataset(
179
+ excel_file=args.valid_dataset,
180
+ expected_sample_rate=8000,
181
+ max_wave_value=32768.0,
182
+ )
183
+ train_data_loader = DataLoader(
184
+ dataset=train_dataset,
185
+ batch_size=args.batch_size,
186
+ shuffle=True,
187
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
188
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
189
+ collate_fn=collate_fn,
190
+ pin_memory=False,
191
+ # prefetch_factor=64,
192
+ )
193
+ valid_data_loader = DataLoader(
194
+ dataset=valid_dataset,
195
+ batch_size=args.batch_size,
196
+ shuffle=True,
197
+ # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
198
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
199
+ collate_fn=collate_fn,
200
+ pin_memory=False,
201
+ # prefetch_factor=64,
202
+ )
203
+
204
+ # models
205
+ logger.info(f"prepare models. config_file: {args.config_file}")
206
+ config = SpectrumUnetIRMConfig.from_pretrained(
207
+ pretrained_model_name_or_path=args.config_file,
208
+ # num_labels=vocabulary.get_vocab_size(namespace="labels")
209
+ )
210
+ model = SpectrumUnetIRMPretrainedModel(
211
+ config=config,
212
+ )
213
+ model.to(device)
214
+ model.train()
215
+
216
+ # optimizer
217
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
218
+ param_optimizer = model.parameters()
219
+ optimizer = torch.optim.Adam(
220
+ param_optimizer,
221
+ lr=args.learning_rate,
222
+ )
223
+ # lr_scheduler = torch.optim.lr_scheduler.StepLR(
224
+ # optimizer,
225
+ # step_size=2000
226
+ # )
227
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
228
+ optimizer,
229
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
230
+ )
231
+ irm_mse_loss = nn.MSELoss(
232
+ reduction="mean",
233
+ )
234
+ snr_mse_loss = nn.MSELoss(
235
+ reduction="mean",
236
+ )
237
+
238
+ # training loop
239
+ logger.info("training")
240
+
241
+ training_loss = 10000000000
242
+ evaluation_loss = 10000000000
243
+
244
+ model_list = list()
245
+ best_idx_epoch = None
246
+ best_metric = None
247
+ patience_count = 0
248
+
249
+ for idx_epoch in range(args.max_epochs):
250
+ total_loss = 0.
251
+ total_examples = 0.
252
+ progress_bar = tqdm(
253
+ total=len(train_data_loader),
254
+ desc="Training; epoch: {}".format(idx_epoch),
255
+ )
256
+
257
+ for batch in train_data_loader:
258
+ mix_spec, speech_irm, snr_db = batch
259
+ mix_spec = mix_spec.to(device)
260
+ speech_irm_target = speech_irm.to(device)
261
+ snr_db_target = snr_db.to(device)
262
+
263
+ speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
264
+ irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
265
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
266
+ # loss = irm_loss + 0.1 * snr_loss
267
+ loss = irm_loss
268
+
269
+ total_loss += loss.item()
270
+ total_examples += mix_spec.size(0)
271
+
272
+ optimizer.zero_grad()
273
+ loss.backward()
274
+ optimizer.step()
275
+ lr_scheduler.step()
276
+
277
+ training_loss = total_loss / total_examples
278
+ training_loss = round(training_loss, 4)
279
+
280
+ progress_bar.update(1)
281
+ progress_bar.set_postfix({
282
+ "training_loss": training_loss,
283
+ })
284
+
285
+ total_loss = 0.
286
+ total_examples = 0.
287
+ progress_bar = tqdm(
288
+ total=len(valid_data_loader),
289
+ desc="Evaluation; epoch: {}".format(idx_epoch),
290
+ )
291
+ for batch in valid_data_loader:
292
+ mix_spec, speech_irm, snr_db = batch
293
+ mix_spec = mix_spec.to(device)
294
+ speech_irm_target = speech_irm.to(device)
295
+ snr_db_target = snr_db.to(device)
296
+
297
+ with torch.no_grad():
298
+ speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
299
+ irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
300
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
301
+ # loss = irm_loss + 0.1 * snr_loss
302
+ loss = irm_loss
303
+
304
+ total_loss += loss.item()
305
+ total_examples += mix_spec.size(0)
306
+
307
+ evaluation_loss = total_loss / total_examples
308
+ evaluation_loss = round(evaluation_loss, 4)
309
+
310
+ progress_bar.update(1)
311
+ progress_bar.set_postfix({
312
+ "evaluation_loss": evaluation_loss,
313
+ })
314
+
315
+ # save path
316
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
317
+ epoch_dir.mkdir(parents=True, exist_ok=False)
318
+
319
+ # save models
320
+ model.save_pretrained(epoch_dir.as_posix())
321
+
322
+ model_list.append(epoch_dir)
323
+ if len(model_list) >= args.num_serialized_models_to_keep:
324
+ model_to_delete: Path = model_list.pop(0)
325
+ shutil.rmtree(model_to_delete.as_posix())
326
+
327
+ # save metric
328
+ if best_metric is None:
329
+ best_idx_epoch = idx_epoch
330
+ best_metric = evaluation_loss
331
+ elif evaluation_loss < best_metric:
332
+ best_idx_epoch = idx_epoch
333
+ best_metric = evaluation_loss
334
+ else:
335
+ pass
336
+
337
+ metrics = {
338
+ "idx_epoch": idx_epoch,
339
+ "best_idx_epoch": best_idx_epoch,
340
+ "training_loss": training_loss,
341
+ "evaluation_loss": evaluation_loss,
342
+ "learning_rate": optimizer.param_groups[0]["lr"],
343
+ }
344
+ metrics_filename = epoch_dir / "metrics_epoch.json"
345
+ with open(metrics_filename, "w", encoding="utf-8") as f:
346
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
347
+
348
+ # save best
349
+ best_dir = serialization_dir / "best"
350
+ if best_idx_epoch == idx_epoch:
351
+ if best_dir.exists():
352
+ shutil.rmtree(best_dir)
353
+ shutil.copytree(epoch_dir, best_dir)
354
+
355
+ # early stop
356
+ early_stop_flag = False
357
+ if best_idx_epoch == idx_epoch:
358
+ patience_count = 0
359
+ else:
360
+ patience_count += 1
361
+ if patience_count >= args.patience:
362
+ early_stop_flag = True
363
+
364
+ # early stop
365
+ if early_stop_flag:
366
+ break
367
+ return
368
+
369
+
370
+ if __name__ == '__main__':
371
+ main()
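
Compared with the simple LSTM recipe, this collate function additionally builds a per-frame local SNR target, 10 * log10(speech_spec / noise_spec) averaged over frequency, and drops the highest of the 257 FFT bins so the input matches the 256 spec_bins expected by the U-Net. A minimal sketch of those two steps on placeholder spectrograms:

import torch

# placeholder power spectrograms, shape [batch, 257, time_steps] (n_fft = 512)
speech_spec = torch.rand(4, 257, 201) + 1e-3
noise_spec = torch.rand(4, 257, 201) + 1e-3
epsilon = 1e-8

# per-frame SNR in dB, averaged over the frequency axis
snr_db = 10 * torch.log10(speech_spec / (noise_spec + epsilon))
snr_db = torch.mean(snr_db, dim=1, keepdim=True)        # shape [batch, 1, time_steps]

# drop the last frequency bin so the network sees 256 bins
mix_spec = (speech_spec + noise_spec)[:, :-1, :]        # shape [batch, 256, time_steps]
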
examples/spectrum_unet_irm_aishell/step_3_evaluation.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import sys
8
+ import uuid
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import pandas as pd
16
+ from scipy.io import wavfile
17
+ import torch
18
+ import torch.nn as nn
19
+ import torchaudio
20
+ from tqdm import tqdm
21
+
22
+ from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel
23
+
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
28
+ parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
29
+ parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
30
+
31
+ parser.add_argument("--limit", default=10, type=int)
32
+
33
+ args = parser.parse_args()
34
+ return args
35
+
36
+
37
+ def logging_config():
38
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
39
+
40
+ logging.basicConfig(format=fmt,
41
+ datefmt="%m/%d/%Y %H:%M:%S",
42
+ level=logging.INFO)
43
+ stream_handler = logging.StreamHandler()
44
+ stream_handler.setLevel(logging.INFO)
45
+ stream_handler.setFormatter(logging.Formatter(fmt))
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ return logger
50
+
51
+
52
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
53
+ l1 = len(speech)
54
+ l2 = len(noise)
55
+ l = min(l1, l2)
56
+ speech = speech[:l]
57
+ noise = noise[:l]
58
+
59
+ # np.float32, value between (-1, 1).
60
+
61
+ speech_power = np.mean(np.square(speech))
62
+ noise_power = speech_power / (10 ** (snr_db / 10))
63
+
64
+ noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
65
+
66
+ noisy_signal = speech + noise_adjusted
67
+
68
+ return noisy_signal
69
+
70
+
71
+ stft_power = torchaudio.transforms.Spectrogram(
72
+ n_fft=512,
73
+ win_length=200,
74
+ hop_length=80,
75
+ power=2.0,
76
+ window_fn=torch.hamming_window,
77
+ )
78
+
79
+
80
+ stft_complex = torchaudio.transforms.Spectrogram(
81
+ n_fft=512,
82
+ win_length=200,
83
+ hop_length=80,
84
+ power=None,
85
+ window_fn=torch.hamming_window,
86
+ )
87
+
88
+
89
+ istft = torchaudio.transforms.InverseSpectrogram(
90
+ n_fft=512,
91
+ win_length=200,
92
+ hop_length=80,
93
+ window_fn=torch.hamming_window,
94
+ )
95
+
96
+
97
+ def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
98
+ mix_spec_complex = mix_spec_complex.detach().cpu()
99
+ speech_irm_prediction = speech_irm_prediction.detach().cpu()
100
+
101
+ mask_speech = speech_irm_prediction
102
+ mask_noise = 1.0 - speech_irm_prediction
103
+
104
+ speech_spec = mix_spec_complex * mask_speech
105
+ noise_spec = mix_spec_complex * mask_noise
106
+
107
+ speech_wave = istft.forward(speech_spec)
108
+ noise_wave = istft.forward(noise_spec)
109
+
110
+ return speech_wave, noise_wave
111
+
112
+
113
+ def save_audios(noise_wave: torch.Tensor,
114
+ speech_wave: torch.Tensor,
115
+ mix_wave: torch.Tensor,
116
+ speech_wave_enhanced: torch.Tensor,
117
+ noise_wave_enhanced: torch.Tensor,
118
+ output_dir: str,
119
+ sample_rate: int = 8000,
120
+ ):
121
+ basename = uuid.uuid4().__str__()
122
+ output_dir = Path(output_dir) / basename
123
+ output_dir.mkdir(parents=True, exist_ok=True)
124
+
125
+ filename = output_dir / "noise_wave.wav"
126
+ torchaudio.save(filename, noise_wave, sample_rate)
127
+ filename = output_dir / "speech_wave.wav"
128
+ torchaudio.save(filename, speech_wave, sample_rate)
129
+ filename = output_dir / "mix_wave.wav"
130
+ torchaudio.save(filename, mix_wave, sample_rate)
131
+
132
+ filename = output_dir / "speech_wave_enhanced.wav"
133
+ torchaudio.save(filename, speech_wave_enhanced, sample_rate)
134
+ filename = output_dir / "noise_wave_enhanced.wav"
135
+ torchaudio.save(filename, noise_wave_enhanced, sample_rate)
136
+
137
+ return output_dir.as_posix()
138
+
139
+
140
+ def main():
141
+ args = get_args()
142
+
143
+ logger = logging_config()
144
+
145
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146
+ n_gpu = torch.cuda.device_count()
147
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
148
+
149
+ logger.info("prepare model")
150
+ model = SpectrumUnetIRMPretrainedModel.from_pretrained(
151
+ pretrained_model_name_or_path=args.model_dir,
152
+ )
153
+ model.to(device)
154
+ model.eval()
155
+
156
+ # optimizer
157
+ logger.info("prepare loss_fn")
158
+ irm_mse_loss = nn.MSELoss(
159
+ reduction="mean",
160
+ )
161
+ snr_mse_loss = nn.MSELoss(
162
+ reduction="mean",
163
+ )
164
+
165
+ logger.info("read excel")
166
+ df = pd.read_excel(args.valid_dataset)
167
+
168
+ total_loss = 0.
169
+ total_examples = 0.
170
+ progress_bar = tqdm(total=len(df), desc="Evaluation")
171
+ for idx, row in df.iterrows():
172
+ noise_filename = row["noise_filename"]
173
+ noise_offset = row["noise_offset"]
174
+ noise_duration = row["noise_duration"]
175
+
176
+ speech_filename = row["speech_filename"]
177
+ speech_offset = row["speech_offset"]
178
+ speech_duration = row["speech_duration"]
179
+
180
+ snr_db = row["snr_db"]
181
+
182
+ noise_wave, _ = librosa.load(
183
+ noise_filename,
184
+ sr=8000,
185
+ offset=noise_offset,
186
+ duration=noise_duration,
187
+ )
188
+ speech_wave, _ = librosa.load(
189
+ speech_filename,
190
+ sr=8000,
191
+ offset=speech_offset,
192
+ duration=speech_duration,
193
+ )
194
+ mix_wave: np.ndarray = mix_speech_and_noise(
195
+ speech=speech_wave,
196
+ noise=noise_wave,
197
+ snr_db=snr_db,
198
+ )
199
+ noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
200
+ speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
201
+ mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
202
+
203
+ noise_wave = noise_wave.unsqueeze(dim=0)
204
+ speech_wave = speech_wave.unsqueeze(dim=0)
205
+ mix_wave = mix_wave.unsqueeze(dim=0)
206
+
207
+ noise_spec: torch.Tensor = stft_power.forward(noise_wave)
208
+ speech_spec: torch.Tensor = stft_power.forward(speech_wave)
209
+ mix_spec: torch.Tensor = stft_power.forward(mix_wave)
210
+
211
+ noise_spec = noise_spec[:, :-1, :]
212
+ speech_spec = speech_spec[:, :-1, :]
213
+ mix_spec = mix_spec[:, :-1, :]
214
+
215
+ mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
216
+ # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps], complex-valued
217
+
218
+ speech_irm = speech_spec / (noise_spec + speech_spec + 1e-8)
219
+ speech_irm = torch.pow(speech_irm, 1.0)
220
+
221
+ snr_db: torch.Tensor = 10 * torch.log10(
222
+ speech_spec / (noise_spec + 1e-8)
223
+ )
224
+ snr_db = torch.mean(snr_db, dim=1, keepdim=True)
225
+ # snr_db shape: [batch_size, 1, time_steps]
226
+
227
+ mix_spec = mix_spec.to(device)
228
+ speech_irm_target = speech_irm.to(device)
229
+ snr_db_target = snr_db.to(device)
230
+
231
+ with torch.no_grad():
232
+ speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
233
+ irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
234
+ # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
235
+ # loss = irm_loss + 0.1 * snr_loss
236
+ loss = irm_loss
237
+
238
+ # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps], complex-valued
239
+ # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
240
+ batch_size, _, time_steps = speech_irm_prediction.shape
241
+ speech_irm_prediction = torch.concat(
242
+ [
243
+ speech_irm_prediction,
244
+ 0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
245
+ ],
246
+ dim=1,
247
+ )
248
+ # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
249
+ speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
250
+ save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
251
+
252
+ total_loss += loss.item()
253
+ total_examples += mix_spec.size(0)
254
+
255
+ evaluation_loss = total_loss / total_examples
256
+ evaluation_loss = round(evaluation_loss, 4)
257
+
258
+ progress_bar.update(1)
259
+ progress_bar.set_postfix({
260
+ "evaluation_loss": evaluation_loss,
261
+ })
262
+
263
+ if idx > args.limit:
264
+ break
265
+
266
+ return
267
+
268
+
269
+ if __name__ == '__main__':
270
+ main()
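
Because the model predicts a 256-bin mask while the complex STFT has 257 bins, the evaluation above pads a constant 0.5 for the dropped bin before masking. The enhancement step then reduces to multiplying the complex mixture spectrogram by the mask and running the inverse STFT; a compact sketch with placeholder tensors:

import torch
import torchaudio

istft = torchaudio.transforms.InverseSpectrogram(
    n_fft=512, win_length=200, hop_length=80, window_fn=torch.hamming_window,
)

speech_irm_prediction = torch.rand(1, 256, 201)                     # placeholder model output
mix_spec_complex = torch.randn(1, 257, 201, dtype=torch.complex64)  # placeholder complex STFT

pad = 0.5 * torch.ones(1, 1, 201)                                   # constant value for the dropped bin
speech_mask = torch.cat([speech_irm_prediction, pad], dim=1)        # back to 257 bins
speech_wave_enhanced = istft(mix_spec_complex * speech_mask)        # enhanced waveform
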
examples/spectrum_unet_irm_aishell/yaml/config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
1
+ model_name: "spectrum_unet_irm"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+
9
+ spec_bins: 256
10
+
11
+ # model
12
+ conv_channels: 64
13
+ conv_kernel_size_input:
14
+ - 3
15
+ - 3
16
+ conv_kernel_size_inner:
17
+ - 1
18
+ - 3
19
+ conv_lookahead: 0
20
+
21
+ convt_kernel_size_inner:
22
+ - 1
23
+ - 3
24
+
25
+ encoder_emb_skip_op: "none"
26
+ encoder_emb_linear_groups: 16
27
+ encoder_emb_hidden_size: 256
28
+
29
+ lsnr_max: 20
30
+ lsnr_min: -10
31
+
32
+ decoder_emb_num_layers: 3
33
+ decoder_emb_skip_op: "none"
34
+ decoder_emb_linear_groups: 16
35
+ decoder_emb_hidden_size: 256
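
The spec block of this config mirrors the STFT used throughout the recipe: n_fft 512 at 8 kHz gives 512 / 2 + 1 = 257 frequency bins per frame, and the scripts drop the last bin so the network operates on the 256 spec_bins declared here; hop_length 80 corresponds to a 10 ms frame shift and win_length 200 to a 25 ms window. A quick arithmetic check:

n_fft, win_length, hop_length, sample_rate = 512, 200, 80, 8000
freq_bins = n_fft // 2 + 1                          # 257; minus the dropped bin -> spec_bins = 256
frame_shift_ms = hop_length / sample_rate * 1000    # 10.0 ms
window_ms = win_length / sample_rate * 1000         # 25.0 ms
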
install.sh ADDED
@@ -0,0 +1,64 @@
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # bash install.sh --stage 2 --stop_stage 2 --system_version centos
4
+
5
+
6
+ python_version=3.8.10
7
+ system_version="centos";
8
+
9
+ verbose=true;
10
+ stage=-1
11
+ stop_stage=0
12
+
13
+
14
+ # parse options
15
+ while true; do
16
+ [ -z "${1:-}" ] && break; # break if there are no arguments
17
+ case "$1" in
18
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
19
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
20
+ old_value="$(eval echo \$$name)";
21
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
22
+ was_bool=true;
23
+ else
24
+ was_bool=false;
25
+ fi
26
+
27
+ # Set the variable to the right value-- the escaped quotes make it work if
28
+ # the option had spaces, like --cmd "queue.pl -sync y"
29
+ eval "${name}=\"$2\"";
30
+
31
+ # Check that Boolean-valued arguments are really Boolean.
32
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
33
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
34
+ exit 1;
35
+ fi
36
+ shift 2;
37
+ ;;
38
+
39
+ *) break;
40
+ esac
41
+ done
42
+
43
+ work_dir="$(pwd)"
44
+
45
+
46
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
47
+ $verbose && echo "stage 1: install python"
48
+ cd "${work_dir}" || exit 1;
49
+
50
+ sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
51
+ fi
52
+
53
+
54
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
55
+ $verbose && echo "stage 2: create virtualenv"
56
+
57
+ # /usr/local/python-3.9.9/bin/virtualenv nx_denoise
58
+ # source /data/local/bin/nx_denoise/bin/activate
59
+ /usr/local/python-${python_version}/bin/pip3 install virtualenv
60
+ mkdir -p /data/local/bin
61
+ cd /data/local/bin || exit 1;
62
+ /usr/local/python-${python_version}/bin/virtualenv nx_denoise
63
+
64
+ fi
main.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import platform
5
+
6
+ import gradio as gr
7
+
8
+ from project_settings import environment, project_path
9
+
10
+
11
+ def get_args():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument(
14
+ "--hf_token",
15
+ default=environment.get("hf_token"),
16
+ type=str,
17
+ )
18
+ parser.add_argument(
19
+ "--server_port",
20
+ default=environment.get("server_port", 7860),
21
+ type=int
22
+ )
23
+
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+
28
+ def main():
29
+ args = get_args()
30
+
31
+ # ui
32
+ with gr.Blocks() as blocks:
33
+ gr.Markdown(value="in progress.")
34
+
35
+ # http://127.0.0.1:7864/
36
+ blocks.queue().launch(
37
+ share=False if platform.system() == "Windows" else False,
38
+ server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
39
+ server_port=args.server_port
40
+ )
41
+ return
42
+
43
+
44
+ if __name__ == "__main__":
45
+ main()
project_settings.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ from pathlib import Path
5
+
6
+ from toolbox.os.environment import EnvironmentManager
7
+
8
+
9
+ project_path = os.path.abspath(os.path.dirname(__file__))
10
+ project_path = Path(project_path)
11
+
12
+ log_directory = project_path / "logs"
13
+ log_directory.mkdir(parents=True, exist_ok=True)
14
+
15
+ temp_directory = project_path / "temp"
16
+ temp_directory.mkdir(parents=True, exist_ok=True)
17
+
18
+ environment = EnvironmentManager(
19
+ path=os.path.join(project_path, "dotenv"),
20
+ env=os.environ.get("environment", "dev"),
21
+ )
22
+
23
+
24
+ if __name__ == '__main__':
25
+ pass
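
EnvironmentManager loads dotenv/<env>.env relative to the project root, with <env> taken from the `environment` OS variable and defaulting to "dev". A short usage sketch matching the keys that main.py reads (values only resolve if the dotenv file or the OS environment actually defines them):

from project_settings import environment

hf_token = environment.get("hf_token")                          # None when the key is absent
server_port = environment.get("server_port", 7860, dtype=int)   # cast to int when present
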
requirements-python-3-9-9.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
1
+ gradio==4.44.1
2
+ datasets==3.2.0
3
+ python-dotenv==1.0.1
4
+ scipy==1.13.1
5
+ librosa==0.10.2.post1
6
+ pandas==2.2.3
7
+ openpyxl==3.1.5
8
+ torch==2.5.1
9
+ torchaudio==2.5.1
10
+ overrides==7.7.0
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
1
+ gradio==5.12.0
2
+ datasets==3.2.0
3
+ python-dotenv==1.0.1
4
+ scipy==1.15.1
5
+ librosa==0.10.2.post1
6
+ pandas==2.2.3
7
+ openpyxl==3.1.5
8
+ torch==2.5.1
9
+ torchaudio==2.5.1
10
+ overrides==7.7.0
toolbox/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/json/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/json/misc.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Callable
4
+
5
+
6
+ def traverse(js, callback: Callable, *args, **kwargs):
7
+ if isinstance(js, list):
8
+ result = list()
9
+ for l in js:
10
+ l = traverse(l, callback, *args, **kwargs)
11
+ result.append(l)
12
+ return result
13
+ elif isinstance(js, tuple):
14
+ result = list()
15
+ for l in js:
16
+ l = traverse(l, callback, *args, **kwargs)
17
+ result.append(l)
18
+ return tuple(result)
19
+ elif isinstance(js, dict):
20
+ result = dict()
21
+ for k, v in js.items():
22
+ k = traverse(k, callback, *args, **kwargs)
23
+ v = traverse(v, callback, *args, **kwargs)
24
+ result[k] = v
25
+ return result
26
+ elif isinstance(js, int):
27
+ return callback(js, *args, **kwargs)
28
+ elif isinstance(js, str):
29
+ return callback(js, *args, **kwargs)
30
+ else:
31
+ return js
32
+
33
+
34
+ def demo1():
35
+ d = {
36
+ "env": "ppe",
37
+ "mysql_connect": {
38
+ "host": "$mysql_connect_host",
39
+ "port": 3306,
40
+ "user": "callbot",
41
+ "password": "NxcloudAI2021!",
42
+ "database": "callbot_ppe",
43
+ "charset": "utf8"
44
+ },
45
+ "es_connect": {
46
+ "hosts": ["10.20.251.8"],
47
+ "http_auth": ["elastic", "ElasticAI2021!"],
48
+ "port": 9200
49
+ }
50
+ }
51
+
52
+ def callback(s):
53
+ if isinstance(s, str) and s.startswith('$'):
54
+ return s[1:]
55
+ return s
56
+
57
+ result = traverse(d, callback=callback)
58
+ print(result)
59
+ return
60
+
61
+
62
+ if __name__ == '__main__':
63
+ demo1()
toolbox/os/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/os/command.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+
5
+
6
+ class Command(object):
7
+ custom_command = [
8
+ "cd"
9
+ ]
10
+
11
+ @staticmethod
12
+ def _get_cmd(command):
13
+ command = str(command).strip()
14
+ if command == "":
15
+ return None
16
+ cmd_and_args = command.split(sep=" ")
17
+ cmd = cmd_and_args[0]
18
+ args = " ".join(cmd_and_args[1:])
19
+ return cmd, args
20
+
21
+ @classmethod
22
+ def popen(cls, command):
23
+ cmd, args = cls._get_cmd(command)
24
+ if cmd in cls.custom_command:
25
+ method = getattr(cls, cmd)
26
+ return method(args)
27
+ else:
28
+ resp = os.popen(command)
29
+ result = resp.read()
30
+ resp.close()
31
+ return result
32
+
33
+ @classmethod
34
+ def cd(cls, args):
35
+ if args.startswith("/"):
36
+ os.chdir(args)
37
+ else:
38
+ pwd = os.getcwd()
39
+ path = os.path.join(pwd, args)
40
+ os.chdir(path)
41
+
42
+ @classmethod
43
+ def system(cls, command):
44
+ return os.system(command)
45
+
46
+ def __init__(self):
47
+ pass
48
+
49
+
50
+ def ps_ef_grep(keyword: str):
51
+ cmd = "ps -ef | grep {}".format(keyword)
52
+ rows = Command.popen(cmd)
53
+ rows = str(rows).split("\n")
54
+ rows = [row for row in rows if row.__contains__(keyword) and not row.__contains__("grep")]
55
+ return rows
56
+
57
+
58
+ if __name__ == "__main__":
59
+ pass
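
The special-casing of "cd" above exists because os.popen runs its command in a child shell, so a plain "cd" would never change the interpreter's own working directory; Command.popen intercepts it and calls os.chdir in-process instead. A brief usage sketch (the /tmp path assumes a Unix-like system):

import os

from toolbox.os.command import Command

print(Command.popen("echo hello"))   # ordinary commands go through os.popen and return their stdout
Command.popen("cd /tmp")             # intercepted: os.chdir("/tmp") in the current process
print(os.getcwd())                   # /tmp
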
toolbox/os/environment.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import json
4
+ import os
5
+
6
+ from dotenv import load_dotenv
7
+ from dotenv.main import DotEnv
8
+
9
+ from toolbox.json.misc import traverse
10
+
11
+
12
+ class EnvironmentManager(object):
13
+ def __init__(self, path, env, override=False):
14
+ filename = os.path.join(path, '{}.env'.format(env))
15
+ self.filename = filename
16
+
17
+ load_dotenv(
18
+ dotenv_path=filename,
19
+ override=override
20
+ )
21
+
22
+ self._environ = dict()
23
+
24
+ def open_dotenv(self, filename: str = None):
25
+ filename = filename or self.filename
26
+ dotenv = DotEnv(
27
+ dotenv_path=filename,
28
+ stream=None,
29
+ verbose=False,
30
+ interpolate=False,
31
+ override=False,
32
+ encoding="utf-8",
33
+ )
34
+ result = dotenv.dict()
35
+ return result
36
+
37
+ def get(self, key, default=None, dtype=str):
38
+ result = os.environ.get(key)
39
+ if result is None:
40
+ if default is None:
41
+ result = None
42
+ else:
43
+ result = default
44
+ else:
45
+ result = dtype(result)
46
+ self._environ[key] = result
47
+ return result
48
+
49
+
50
+ _DEFAULT_DTYPE_MAP = {
51
+ 'int': int,
52
+ 'float': float,
53
+ 'str': str,
54
+ 'json.loads': json.loads
55
+ }
56
+
57
+
58
+ class JsonConfig(object):
59
+ """
60
+ 将 json 中, 形如 `$float:threshold` 的值, 处理为:
61
+ 从环境变量中查到 threshold, 再将其转换为 float 类型.
62
+ """
63
+ def __init__(self, dtype_map: dict = None, environment: EnvironmentManager = None):
64
+ self.dtype_map = dtype_map or _DEFAULT_DTYPE_MAP
65
+ self.environment = environment or os.environ
66
+
67
+ def sanitize_by_filename(self, filename: str):
68
+ with open(filename, 'r', encoding='utf-8') as f:
69
+ js = json.load(f)
70
+
71
+ return self.sanitize_by_json(js)
72
+
73
+ def sanitize_by_json(self, js):
74
+ js = traverse(
75
+ js,
76
+ callback=self.sanitize,
77
+ environment=self.environment
78
+ )
79
+ return js
80
+
81
+ def sanitize(self, string, environment):
82
+ """支持 $ 符开始的, 环境变量配置"""
83
+ if isinstance(string, str) and string.startswith('$'):
84
+ dtype, key = string[1:].split(':')
85
+ dtype = self.dtype_map[dtype]
86
+
87
+ value = environment.get(key)
88
+ if value is None:
89
+ raise AssertionError('environment not exist. key: {}'.format(key))
90
+
91
+ value = dtype(value)
92
+ result = value
93
+ else:
94
+ result = string
95
+ return result
96
+
97
+
98
+ def demo1():
99
+ import json
100
+
101
+ from project_settings import project_path
102
+
103
+ environment = EnvironmentManager(
104
+ path=os.path.join(project_path, 'server/callbot_server/dotenv'),
105
+ env='dev',
106
+ )
107
+ init_scenes = environment.get(key='init_scenes', dtype=json.loads)
108
+ print(init_scenes)
109
+ print(environment._environ)
110
+ return
111
+
112
+
113
+ if __name__ == '__main__':
114
+ demo1()
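
A minimal sketch of the `$dtype:key` convention that JsonConfig resolves (illustrative; the environment variable name `threshold` is an assumption made only for this example):

import os
from toolbox.os.environment import JsonConfig

os.environ["threshold"] = "0.75"
config = {"model": {"threshold": "$float:threshold", "name": "demo"}}
sanitized = JsonConfig().sanitize_by_json(config)
print(sanitized)  # {'model': {'threshold': 0.75, 'name': 'demo'}}
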
toolbox/os/other.py ADDED
@@ -0,0 +1,9 @@
1
+ import os
2
+ import inspect
3
+
4
+
5
+ def pwd():
6
+ """你在哪个文件调用此函数, 它就会返回那个文件所在的 dir 目标"""
7
+ frame = inspect.stack()[1]
8
+ module = inspect.getmodule(frame[0])
9
+ return os.path.dirname(os.path.abspath(module.__file__))
toolbox/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/utils/data/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/utils/data/dataset/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torch/utils/data/dataset/denoise_excel_dataset.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import pandas as pd
8
+ from scipy.io import wavfile
9
+ import torch
10
+ import torchaudio
11
+ from torch.utils.data import Dataset
12
+ from tqdm import tqdm
13
+
14
+
15
+ class DenoiseExcelDataset(Dataset):
16
+ def __init__(self,
17
+ excel_file: str,
18
+ expected_sample_rate: int,
19
+ resample: bool = False,
20
+ max_wave_value: float = 1.0,
21
+ ):
22
+ self.excel_file = excel_file
23
+ self.expected_sample_rate = expected_sample_rate
24
+ self.resample = resample
25
+ self.max_wave_value = max_wave_value
26
+
27
+ self.samples = self.load_samples(excel_file)
28
+
29
+ @staticmethod
30
+ def load_samples(filename: str):
31
+ df = pd.read_excel(filename)
32
+ samples = list()
33
+ for i, row in tqdm(df.iterrows(), total=len(df)):
34
+ noise_filename = row["noise_filename"]
35
+ noise_raw_duration = row["noise_raw_duration"]
36
+ noise_offset = row["noise_offset"]
37
+ noise_duration = row["noise_duration"]
38
+
39
+ speech_filename = row["speech_filename"]
40
+ speech_raw_duration = row["speech_raw_duration"]
41
+ speech_offset = row["speech_offset"]
42
+ speech_duration = row["speech_duration"]
43
+
44
+ snr_db = row["snr_db"]
45
+
46
+ row = {
47
+ "noise_filename": noise_filename,
48
+ "noise_raw_duration": noise_raw_duration,
49
+ "noise_offset": noise_offset,
50
+ "noise_duration": noise_duration,
51
+
52
+ "speech_filename": speech_filename,
53
+ "speech_raw_duration": speech_raw_duration,
54
+ "speech_offset": speech_offset,
55
+ "speech_duration": speech_duration,
56
+
57
+ "snr_db": snr_db,
58
+ }
59
+ samples.append(row)
60
+ return samples
61
+
62
+ def __getitem__(self, index):
63
+ sample = self.samples[index]
64
+ noise_filename = sample["noise_filename"]
65
+ noise_offset = sample["noise_offset"]
66
+ noise_duration = sample["noise_duration"]
67
+
68
+ speech_filename = sample["speech_filename"]
69
+ speech_offset = sample["speech_offset"]
70
+ speech_duration = sample["speech_duration"]
71
+
72
+ snr_db = sample["snr_db"]
73
+
74
+ noise_wave = self.filename_to_waveform(noise_filename, noise_offset, noise_duration)
75
+ speech_wave = self.filename_to_waveform(speech_filename, speech_offset, speech_duration)
76
+
77
+ mix_wave, noise_wave_adjusted = self.mix_speech_and_noise(
78
+ speech=speech_wave.numpy(),
79
+ noise=noise_wave.numpy(),
80
+ snr_db=snr_db,
81
+ )
82
+ mix_wave = torch.tensor(mix_wave, dtype=torch.float32)
83
+ noise_wave_adjusted = torch.tensor(noise_wave_adjusted, dtype=torch.float32)
84
+
85
+ result = {
86
+ "noise_wave": noise_wave_adjusted,
87
+ "speech_wave": speech_wave,
88
+ "mix_wave": mix_wave,
89
+ "snr_db": snr_db,
90
+ }
91
+ return result
92
+
93
+ def __len__(self):
94
+ return len(self.samples)
95
+
96
+ def filename_to_waveform(self, filename: str, offset: float, duration: float):
97
+ try:
98
+ waveform, sample_rate = librosa.load(
99
+ filename,
100
+ sr=self.expected_sample_rate,
101
+ offset=offset,
102
+ duration=duration,
103
+ )
104
+ except ValueError as e:
105
+ print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
106
+ raise e
107
+ waveform = torch.tensor(waveform, dtype=torch.float32)
108
+ return waveform
109
+
110
+ @staticmethod
111
+ def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
112
+ l1 = len(speech)
113
+ l2 = len(noise)
114
+ l = min(l1, l2)
115
+ speech = speech[:l]
116
+ noise = noise[:l]
117
+
118
+ # np.float32, value between (-1, 1).
119
+
120
+ speech_power = np.mean(np.square(speech))
121
+ noise_power = speech_power / (10 ** (snr_db / 10))
122
+
123
+ noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
124
+
125
+ noisy_signal = speech + noise_adjusted
126
+
127
+ return noisy_signal, noise_adjusted
128
+
129
+
130
+ if __name__ == '__main__':
131
+ pass
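
A standalone sketch of the SNR mixing rule used in mix_speech_and_noise (illustrative; the waveforms here are random stand-ins for real audio):

import numpy as np

rng = np.random.default_rng(0)
speech = rng.uniform(-0.5, 0.5, size=16000).astype(np.float32)
noise = rng.uniform(-0.5, 0.5, size=16000).astype(np.float32)
snr_db = 10.0

speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (snr_db / 10))                            # target noise power for the requested SNR
noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))  # rescale noise to that power
noisy = speech + noise_adjusted

achieved_snr = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
print(round(float(achieved_snr), 2))  # ~10.0
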
toolbox/torchaudio/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/configuration_utils.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import copy
4
+ import os
5
+ from typing import Any, Dict, Union
6
+
7
+ import yaml
8
+
9
+
10
+ CONFIG_FILE = "config.yaml"
11
+
12
+
13
+ class PretrainedConfig(object):
14
+ def __init__(self, **kwargs):
15
+ pass
16
+
17
+ @classmethod
18
+ def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]):
19
+ with open(yaml_file, encoding="utf-8") as f:
20
+ config_dict = yaml.safe_load(f)
21
+ return config_dict
22
+
23
+ @classmethod
24
+ def get_config_dict(
25
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike]
26
+ ) -> Dict[str, Any]:
27
+ if os.path.isdir(pretrained_model_name_or_path):
28
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
29
+ else:
30
+ config_file = pretrained_model_name_or_path
31
+ config_dict = cls._dict_from_yaml_file(config_file)
32
+ return config_dict
33
+
34
+ @classmethod
35
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
36
+ for k, v in kwargs.items():
37
+ if k in config_dict.keys():
38
+ config_dict[k] = v
39
+ config = cls(**config_dict)
40
+ return config
41
+
42
+ @classmethod
43
+ def from_pretrained(
44
+ cls,
45
+ pretrained_model_name_or_path: Union[str, os.PathLike],
46
+ **kwargs,
47
+ ):
48
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path)
49
+ return cls.from_dict(config_dict, **kwargs)
50
+
51
+ def to_dict(self):
52
+ output = copy.deepcopy(self.__dict__)
53
+ return output
54
+
55
+ def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
56
+ config_dict = self.to_dict()
57
+
58
+ with open(yaml_file_path, "w", encoding="utf-8") as writer:
59
+ yaml.safe_dump(config_dict, writer)
60
+
61
+
62
+ if __name__ == '__main__':
63
+ pass
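
A sketch of how a model configuration can subclass PretrainedConfig and round-trip through yaml (illustrative; SimpleConfig and its fields are hypothetical):

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class SimpleConfig(PretrainedConfig):
    def __init__(self, sample_rate: int = 8000, hop_size: int = 80, **kwargs):
        super().__init__(**kwargs)
        self.sample_rate = sample_rate
        self.hop_size = hop_size


config = SimpleConfig(sample_rate=16000)
config.to_yaml_file("config.yaml")                                    # serialize the attribute dict to yaml
restored = SimpleConfig.from_pretrained("config.yaml", hop_size=160)  # kwargs override values loaded from yaml
print(restored.sample_rate, restored.hop_size)                        # 16000 160
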
toolbox/torchaudio/models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ if __name__ == '__main__':
5
+ pass
toolbox/torchaudio/models/clean_unet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py ADDED
@@ -0,0 +1,9 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://arxiv.org/abs/2202.07790
5
+ """
6
+
7
+
8
+ if __name__ == '__main__':
9
+ pass
toolbox/torchaudio/models/dfnet3/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/dfnet3/configuration_dfnet3.py ADDED
@@ -0,0 +1,89 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Any, Dict, List, Tuple, Union
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class DfNetConfig(PretrainedConfig):
9
+ def __init__(self,
10
+ sample_rate: int,
11
+ fft_size: int,
12
+ hop_size: int,
13
+ df_bins: int,
14
+ erb_bins: int,
15
+ min_freq_bins_for_erb: int,
16
+ df_order: int,
17
+ df_lookahead: int,
18
+ norm_tau: int,
19
+ lsnr_max: int,
20
+ lsnr_min: int,
21
+ conv_channels: int,
22
+ conv_kernel_size_input: Tuple[int, int],
23
+ conv_kernel_size_inner: Tuple[int, int],
24
+ convt_kernel_size_inner: Tuple[int, int],
25
+ conv_lookahead: int,
26
+ emb_hidden_dim: int,
27
+ mask_post_filter: bool,
28
+ df_hidden_dim: int,
29
+ df_num_layers: int,
30
+ df_pathway_kernel_size_t: int,
31
+ df_gru_skip: str,
32
+ post_filter_beta: float,
33
+ df_n_iter: float,
34
+ lsnr_dropout: bool,
35
+ encoder_gru_skip_op: str,
36
+ encoder_linear_groups: int,
37
+ encoder_squeezed_gru_linear_groups: int,
38
+ encoder_concat: bool,
39
+ erb_decoder_gru_skip_op: str,
40
+ erb_decoder_linear_groups: int,
41
+ erb_decoder_emb_num_layers: int,
42
+ df_decoder_linear_groups: int,
43
+ **kwargs
44
+ ):
45
+ super(DfNetConfig, self).__init__(**kwargs)
46
+ if df_gru_skip not in ("none", "identity", "grouped_linear"):
47
+ raise AssertionError
48
+
49
+ self.sample_rate = sample_rate
50
+ self.fft_size = fft_size
51
+ self.hop_size = hop_size
52
+ self.df_bins = df_bins
53
+ self.erb_bins = erb_bins
54
+ self.min_freq_bins_for_erb = min_freq_bins_for_erb
55
+ self.df_order = df_order
56
+ self.df_lookahead = df_lookahead
57
+ self.norm_tau = norm_tau
58
+ self.lsnr_max = lsnr_max
59
+ self.lsnr_min = lsnr_min
60
+
61
+ self.conv_channels = conv_channels
62
+ self.conv_kernel_size_input = conv_kernel_size_input
63
+ self.conv_kernel_size_inner = conv_kernel_size_inner
64
+ self.convt_kernel_size_inner = convt_kernel_size_inner
65
+ self.conv_lookahead = conv_lookahead
66
+
67
+ self.emb_hidden_dim = emb_hidden_dim
68
+ self.mask_post_filter = mask_post_filter
69
+ self.df_hidden_dim = df_hidden_dim
70
+ self.df_num_layers = df_num_layers
71
+ self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
72
+ self.df_gru_skip = df_gru_skip
73
+ self.post_filter_beta = post_filter_beta
74
+ self.df_n_iter = df_n_iter
75
+ self.lsnr_dropout = lsnr_dropout
76
+ self.encoder_gru_skip_op = encoder_gru_skip_op
77
+ self.encoder_linear_groups = encoder_linear_groups
78
+ self.encoder_squeezed_gru_linear_groups = encoder_squeezed_gru_linear_groups
79
+ self.encoder_concat = encoder_concat
80
+
81
+ self.erb_decoder_gru_skip_op = erb_decoder_gru_skip_op
82
+ self.erb_decoder_linear_groups = erb_decoder_linear_groups
83
+ self.erb_decoder_emb_num_layers = erb_decoder_emb_num_layers
84
+
85
+ self.df_decoder_linear_groups = df_decoder_linear_groups
86
+
87
+
88
+ if __name__ == "__main__":
89
+ pass
toolbox/torchaudio/models/dfnet3/features.py ADDED
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+
5
+ import numpy as np
6
+
7
+
8
+ def freq2erb(freq_hz: float) -> float:
9
+ """
10
+ https://www.cnblogs.com/LXP-Never/p/16011229.html
11
+ 1 / (24.7 * 9.265) = 0.00436976
12
+ """
13
+ return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
14
+
15
+
16
+ def erb2freq(n_erb: float) -> float:
17
+ return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
18
+
19
+
20
+ def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
21
+ """
22
+ https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
23
+ :param sample_rate:
24
+ :param fft_size:
25
+ :param erb_bins: Number of ERB (Equivalent Rectangular Bandwidth) bands.
26
+ :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band
27
+ :return:
28
+ """
29
+ nyq_freq = sample_rate / 2.
30
+ freq_width: float = sample_rate / fft_size
31
+
32
+ min_erb: float = freq2erb(0.)
33
+ max_erb: float = freq2erb(nyq_freq)
34
+
35
+ erb = [0] * erb_bins
36
+ step = (max_erb - min_erb) / erb_bins
37
+
38
+ prev_freq_bin = 0
39
+ freq_over = 0
40
+ for i in range(1, erb_bins + 1):
41
+ f = erb2freq(min_erb + i * step)
42
+ freq_bin = int(round(f / freq_width))
43
+ freq_bins = freq_bin - prev_freq_bin - freq_over
44
+
45
+ if freq_bins < min_freq_bins_for_erb:
46
+ freq_over = min_freq_bins_for_erb - freq_bins
47
+ freq_bins = min_freq_bins_for_erb
48
+ else:
49
+ freq_over = 0
50
+ erb[i - 1] = freq_bins
51
+ prev_freq_bin = freq_bin
52
+
53
+ erb[erb_bins - 1] += 1
54
+ too_large = sum(erb) - (fft_size / 2 + 1)
55
+ if too_large > 0:
56
+ erb[erb_bins - 1] -= too_large
57
+ return np.array(erb, dtype=np.uint64)
58
+
59
+
60
+ def get_erb_filter_bank(erb_widths: np.ndarray,
61
+ sample_rate: int,
62
+ normalized: bool = True,
63
+ inverse: bool = False,
64
+ ):
65
+ num_freq_bins = int(np.sum(erb_widths))
66
+ num_erb_bins = len(erb_widths)
67
+
68
+ fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
69
+
70
+ points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
71
+ for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
72
+ fb[b: b + w, i] = 1
73
+
74
+ if inverse:
75
+ fb = fb.T
76
+ if not normalized:
77
+ fb /= np.sum(fb, axis=1, keepdims=True)
78
+ else:
79
+ if normalized:
80
+ fb /= np.sum(fb, axis=0)
81
+ return fb
82
+
83
+
84
+ def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
85
+ """
86
+ ERB filterbank and transform to decibel scale.
87
+
88
+ :param spec: Spectrum of shape [B, C, T, F].
89
+ :param erb_fb: ERB filter bank matrix of shape [F, E] mapping F frequency bins
90
+ to E ERB bands (as produced by `get_erb_filter_bank`).
91
+ :param db: Whether to transform the output into decibel scale. Defaults to `True`.
92
+ :return:
93
+ """
94
+ # complex spec to power spec. (real * real + image * image)
95
+ spec_ = np.abs(spec) ** 2
96
+
97
+ # spec to erb feature.
98
+ erb_feat = np.matmul(spec_, erb_fb)
99
+
100
+ if db:
101
+ erb_feat = 10 * np.log10(erb_feat + 1e-10)
102
+
103
+ erb_feat = np.array(erb_feat, dtype=np.float32)
104
+ return erb_feat
105
+
106
+
107
+ def _calculate_norm_alpha(sample_rate: int, hop_size: int, tau: float):
108
+ """Exponential decay factor alpha for a given tau (decay window size [s])."""
109
+ dt = hop_size / sample_rate
110
+ result = math.exp(-dt / tau)
111
+ return result
112
+
113
+
114
+ def get_norm_alpha(sample_rate: int, hop_size: int, norm_tau: float) -> float:
115
+ a_ = _calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)
116
+
117
+ precision = 3
118
+ a = 1.0
119
+ while a >= 1.0:
120
+ a = round(a_, precision)
121
+ precision += 1
122
+
123
+ return a
124
+
125
+
126
+ MEAN_NORM_INIT = [-60., -90.]
127
+
128
+
129
+ def make_erb_norm_state(erb_bins: int, channels: int) -> np.ndarray:
130
+ state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)
131
+ state = np.expand_dims(state, axis=0)
132
+ state = np.repeat(state, channels, axis=0)
133
+
134
+ # state shape: (audio_channels, erb_bins)
135
+ return state
136
+
137
+
138
+ def erb_normalize(erb_feat: np.ndarray, alpha: float, state: np.ndarray = None):
139
+ erb_feat = np.copy(erb_feat)
140
+ batch_size, time_steps, erb_bins = erb_feat.shape
141
+
142
+ if state is None:
143
+ state = make_erb_norm_state(erb_bins, erb_feat.shape[0])
144
+ # state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)
145
+ # state = np.expand_dims(state, axis=0)
146
+ # state = np.repeat(state, erb_feat.shape[0], axis=0)
147
+
148
+ for i in range(batch_size):
149
+ for j in range(time_steps):
150
+ for k in range(erb_bins):
151
+ x = erb_feat[i][j][k]
152
+ s = state[i][k]
153
+
154
+ state[i][k] = x * (1. - alpha) + s * alpha
155
+ erb_feat[i][j][k] -= state[i][k]
156
+ erb_feat[i][j][k] /= 40.
157
+
158
+ return erb_feat
159
+
160
+
161
+ UNIT_NORM_INIT = [0.001, 0.0001]
162
+
163
+
164
+ def make_spec_norm_state(df_bins: int, channels: int) -> np.ndarray:
165
+ state = np.linspace(UNIT_NORM_INIT[0], UNIT_NORM_INIT[1], df_bins)
166
+ state = np.expand_dims(state, axis=0)
167
+ state = np.repeat(state, channels, axis=0)
168
+
169
+ # state shape: (audio_channels, df_bins)
170
+ return state
171
+
172
+
173
+ def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None):
174
+ spec_feat = np.copy(spec_feat)
175
+ batch_size, time_steps, df_bins = spec_feat.shape
176
+
177
+ if state is None:
178
+ state = make_spec_norm_state(df_bins, spec_feat.shape[0])
179
+
180
+ for i in range(batch_size):
181
+ for j in range(time_steps):
182
+ for k in range(df_bins):
183
+ x = spec_feat[i][j][k]
184
+ s = state[i][k]
185
+
186
+ state[i][k] = np.abs(x) * (1. - alpha) + s * alpha
187
+ spec_feat[i][j][k] /= np.sqrt(state[i][k])
188
+ return spec_feat
189
+
190
+
191
+ if __name__ == '__main__':
192
+ pass
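
An illustrative run of the ERB helpers above (the parameter values are typical DeepFilterNet settings, chosen here only as an example):

import numpy as np
from toolbox.torchaudio.models.dfnet3.features import (
    get_erb_widths, get_erb_filter_bank, spec2erb, get_norm_alpha,
)

sample_rate, fft_size, erb_bins = 16000, 320, 32
widths = get_erb_widths(sample_rate, fft_size, erb_bins, min_freq_bins_for_erb=2)
erb_fb = get_erb_filter_bank(widths, sample_rate)      # [freq_bins, erb_bins]
print(int(widths.sum()), erb_fb.shape)                 # 161 (161, 32)

# A random complex spectrogram of shape [B, C, T, F] stands in for a real STFT.
rng = np.random.default_rng(0)
shape = (1, 1, 10, fft_size // 2 + 1)
spec = rng.normal(size=shape) + 1j * rng.normal(size=shape)
erb_feat = spec2erb(spec, erb_fb, db=True)             # [1, 1, 10, 32], in dB
print(erb_feat.shape, get_norm_alpha(sample_rate, hop_size=160, norm_tau=1.0))
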
toolbox/torchaudio/models/dfnet3/modeling_dfnet3.py ADDED
@@ -0,0 +1,835 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ import math
5
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from toolbox.torchaudio.models.dfnet3.configuration_dfnet3 import DfNetConfig
12
+ from toolbox.torchaudio.models.dfnet3 import multiframes as MF
13
+ from toolbox.torchaudio.models.dfnet3 import utils
14
+
15
+ logger = logging.getLogger("toolbox")
16
+
17
+ PI = 3.1415926535897932384626433
18
+
19
+
20
+ norm_layer_dict = {
21
+ "batch_norm_2d": torch.nn.BatchNorm2d
22
+ }
23
+
24
+ activation_layer_dict = {
25
+ "relu": torch.nn.ReLU,
26
+ "identity": torch.nn.Identity,
27
+ "sigmoid": torch.nn.Sigmoid,
28
+ }
29
+
30
+
31
+ class CausalConv2d(nn.Sequential):
32
+ def __init__(self,
33
+ in_channels: int,
34
+ out_channels: int,
35
+ kernel_size: Union[int, Iterable[int]],
36
+ fstride: int = 1,
37
+ dilation: int = 1,
38
+ fpad: bool = True,
39
+ bias: bool = True,
40
+ separable: bool = False,
41
+ norm_layer: str = "batch_norm_2d",
42
+ activation_layer: str = "relu",
43
+ ):
44
+ """
45
+ Causal Conv2d by delaying the signal for any lookahead.
46
+
47
+ Expected input format: [B, C, T, F]
48
+
49
+ :param in_channels:
50
+ :param out_channels:
51
+ :param kernel_size:
52
+ :param fstride:
53
+ :param dilation:
54
+ :param fpad:
55
+ """
56
+ super(CausalConv2d, self).__init__()
57
+ lookahead = 0
58
+
59
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
60
+
61
+ if fpad:
62
+ fpad_ = kernel_size[1] // 2 + dilation - 1
63
+ else:
64
+ fpad_ = 0
65
+
66
+ # for last 2 dim, pad (left, right, top, bottom).
67
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
68
+
69
+ layers = []
70
+ if any(x > 0 for x in pad):
71
+ layers.append(nn.ConstantPad2d(pad, 0.0))
72
+
73
+ groups = math.gcd(in_channels, out_channels) if separable else 1
74
+ if groups == 1:
75
+ separable = False
76
+ if max(kernel_size) == 1:
77
+ separable = False
78
+
79
+ layers.append(
80
+ nn.Conv2d(
81
+ in_channels,
82
+ out_channels,
83
+ kernel_size=kernel_size,
84
+ padding=(0, fpad_),
85
+ stride=(1, fstride), # stride over time is always 1
86
+ dilation=(1, dilation), # dilation over time is always 1
87
+ groups=groups,
88
+ bias=bias,
89
+ )
90
+ )
91
+
92
+ if separable:
93
+ layers.append(
94
+ nn.Conv2d(
95
+ out_channels,
96
+ out_channels,
97
+ kernel_size=1,
98
+ bias=False,
99
+ )
100
+ )
101
+
102
+ if norm_layer is not None:
103
+ norm_layer = norm_layer_dict[norm_layer]
104
+ layers.append(norm_layer(out_channels))
105
+
106
+ if activation_layer is not None:
107
+ activation_layer = activation_layer_dict[activation_layer]
108
+ layers.append(activation_layer())
109
+
110
+ super().__init__(*layers)
111
+
112
+
113
+ class CausalConvTranspose2d(nn.Sequential):
114
+ def __init__(self,
115
+ in_channels: int,
116
+ out_channels: int,
117
+ kernel_size: Union[int, Iterable[int]],
118
+ fstride: int = 1,
119
+ dilation: int = 1,
120
+ fpad: bool = True,
121
+ bias: bool = True,
122
+ separable: bool = False,
123
+ norm_layer: str = "batch_norm_2d",
124
+ activation_layer: str = "relu",
125
+ ):
126
+ """
127
+ Causal ConvTranspose2d.
128
+
129
+ Expected input format: [B, C, T, F]
130
+ """
131
+ super(CausalConvTranspose2d, self).__init__()
132
+ lookahead = 0
133
+
134
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
135
+
136
+ if fpad:
137
+ fpad_ = kernel_size[1] // 2
138
+ else:
139
+ fpad_ = 0
140
+
141
+ # for last 2 dim, pad (left, right, top, bottom).
142
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
143
+
144
+ layers = []
145
+ if any(x > 0 for x in pad):
146
+ layers.append(nn.ConstantPad2d(pad, 0.0))
147
+
148
+ groups = math.gcd(in_channels, out_channels) if separable else 1
149
+ if groups == 1:
150
+ separable = False
151
+
152
+ layers.append(
153
+ nn.ConvTranspose2d(
154
+ in_channels,
155
+ out_channels,
156
+ kernel_size=kernel_size,
157
+ padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
158
+ output_padding=(0, fpad_),
159
+ stride=(1, fstride), # stride over time is always 1
160
+ dilation=(1, dilation), # dilation over time is always 1
161
+ groups=groups,
162
+ bias=bias,
163
+ )
164
+ )
165
+
166
+ if separable:
167
+ layers.append(
168
+ nn.Conv2d(
169
+ out_channels,
170
+ out_channels,
171
+ kernel_size=1,
172
+ bias=False,
173
+ )
174
+ )
175
+
176
+ if norm_layer is not None:
177
+ norm_layer = norm_layer_dict[norm_layer]
178
+ layers.append(norm_layer(out_channels))
179
+
180
+ if activation_layer is not None:
181
+ activation_layer = activation_layer_dict[activation_layer]
182
+ layers.append(activation_layer())
183
+
184
+ super().__init__(*layers)
185
+
186
+
187
+ class GroupedLinear(nn.Module):
188
+
189
+ def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
190
+ super().__init__()
191
+ # self.weight: Tensor
192
+ self.input_size = input_size
193
+ self.hidden_size = hidden_size
194
+ self.groups = groups
195
+ assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
196
+ assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
197
+ self.ws = input_size // groups
198
+ self.register_parameter(
199
+ "weight",
200
+ torch.nn.Parameter(
201
+ torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
202
+ ),
203
+ )
204
+ self.reset_parameters()
205
+
206
+ def reset_parameters(self):
207
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
208
+
209
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
210
+ # x: [..., I]
211
+ b, t, _ = x.shape
212
+ # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
213
+ new_shape = (b, t, self.groups, self.ws)
214
+ x = x.view(new_shape)
215
+ # The better way, but not supported by torchscript
216
+ # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
217
+ x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
218
+ x = x.flatten(2, 3) # [B, T, H]
219
+ return x
220
+
221
+ def __repr__(self):
222
+ cls = self.__class__.__name__
223
+ return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
224
+
225
+
226
+ class SqueezedGRU_S(nn.Module):
227
+ """
228
+ SGE net: Video object detection with squeezed GRU and information entropy map
229
+ https://arxiv.org/abs/2106.07224
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ input_size: int,
235
+ hidden_size: int,
236
+ output_size: Optional[int] = None,
237
+ num_layers: int = 1,
238
+ linear_groups: int = 8,
239
+ batch_first: bool = True,
240
+ skip_op: str = "none",
241
+ activation_layer: str = "identity",
242
+ ):
243
+ super().__init__()
244
+ self.input_size = input_size
245
+ self.hidden_size = hidden_size
246
+
247
+ self.linear_in = nn.Sequential(
248
+ GroupedLinear(
249
+ input_size=input_size,
250
+ hidden_size=hidden_size,
251
+ groups=linear_groups,
252
+ ),
253
+ activation_layer_dict[activation_layer](),
254
+ )
255
+
256
+ # gru skip operator
257
+ self.gru_skip_op = None
258
+
259
+ if skip_op == "none":
260
+ self.gru_skip_op = None
261
+ elif skip_op == "identity":
262
+ if input_size != output_size:
263
+ raise AssertionError("Dimensions do not match")
264
+ self.gru_skip_op = nn.Identity()
265
+ elif skip_op == "grouped_linear":
266
+ self.gru_skip_op = GroupedLinear(
267
+ input_size=hidden_size,
268
+ hidden_size=hidden_size,
269
+ groups=linear_groups,
270
+ )
271
+ else:
272
+ raise NotImplementedError()
273
+
274
+ self.gru = nn.GRU(
275
+ input_size=hidden_size,
276
+ hidden_size=hidden_size,
277
+ num_layers=num_layers,
278
+ batch_first=batch_first,
279
+ )
280
+
281
+ if output_size is not None:
282
+ self.linear_out = nn.Sequential(
283
+ GroupedLinear(
284
+ input_size=hidden_size,
285
+ hidden_size=output_size,
286
+ groups=linear_groups,
287
+ ),
288
+ activation_layer_dict[activation_layer](),
289
+ )
290
+ else:
291
+ self.linear_out = nn.Identity()
292
+
293
+ def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
294
+ x = self.linear_in(inputs)
295
+
296
+ x, h = self.gru(x, h)
297
+
298
+ x = self.linear_out(x)
299
+
300
+ if self.gru_skip_op is not None:
301
+ x = x + self.gru_skip_op(inputs)
302
+
303
+ return x, h
304
+
305
+
306
+ class Add(nn.Module):
307
+ def forward(self, a, b):
308
+ return a + b
309
+
310
+
311
+ class Concat(nn.Module):
312
+ def forward(self, a, b):
313
+ return torch.cat((a, b), dim=-1)
314
+
315
+
316
+ class Encoder(nn.Module):
317
+ def __init__(self, config: DfNetConfig):
318
+ super(Encoder, self).__init__()
319
+ self.emb_in_dim = config.conv_channels * config.erb_bins // 4
320
+ self.emb_out_dim = config.conv_channels * config.erb_bins // 4
321
+ self.emb_hidden_dim = config.emb_hidden_dim
322
+
323
+ self.erb_conv0 = CausalConv2d(
324
+ in_channels=1,
325
+ out_channels=config.conv_channels,
326
+ kernel_size=config.conv_kernel_size_input,
327
+ bias=False,
328
+ separable=True,
329
+ )
330
+ self.erb_conv1 = CausalConv2d(
331
+ in_channels=config.conv_channels,
332
+ out_channels=config.conv_channels,
333
+ kernel_size=config.conv_kernel_size_inner,
334
+ bias=False,
335
+ separable=True,
336
+ fstride=2,
337
+ )
338
+ self.erb_conv2 = CausalConv2d(
339
+ in_channels=config.conv_channels,
340
+ out_channels=config.conv_channels,
341
+ kernel_size=config.conv_kernel_size_inner,
342
+ bias=False,
343
+ separable=True,
344
+ fstride=2,
345
+ )
346
+ self.erb_conv3 = CausalConv2d(
347
+ in_channels=config.conv_channels,
348
+ out_channels=config.conv_channels,
349
+ kernel_size=config.conv_kernel_size_inner,
350
+ bias=False,
351
+ separable=True,
352
+ fstride=1,
353
+ )
354
+
355
+ self.df_conv0 = CausalConv2d(
356
+ in_channels=2,
357
+ out_channels=config.conv_channels,
358
+ kernel_size=config.conv_kernel_size_input,
359
+ bias=False,
360
+ separable=True,
361
+ )
362
+ self.df_conv1 = CausalConv2d(
363
+ in_channels=config.conv_channels,
364
+ out_channels=config.conv_channels,
365
+ kernel_size=config.conv_kernel_size_inner,
366
+ bias=False,
367
+ separable=True,
368
+ fstride=2,
369
+ )
370
+
371
+ self.df_fc_emb = nn.Sequential(
372
+ GroupedLinear(
373
+ config.conv_channels * config.df_bins // 2,
374
+ self.emb_in_dim,
375
+ groups=config.encoder_linear_groups
376
+ ),
377
+ nn.ReLU(inplace=True)
378
+ )
379
+
380
+ if config.encoder_concat:
381
+ self.emb_in_dim *= 2
382
+ self.combine = Concat()
383
+ else:
384
+ self.combine = Add()
385
+
386
+ self.emb_gru = SqueezedGRU_S(
387
+ self.emb_in_dim,
388
+ self.emb_hidden_dim,
389
+ output_size=self.emb_out_dim,
390
+ num_layers=1,
391
+ batch_first=True,
392
+ skip_op=config.encoder_gru_skip_op,
393
+ linear_groups=config.encoder_squeezed_gru_linear_groups,
394
+ activation_layer="relu",
395
+ )
396
+
397
+ self.lsnr_fc = nn.Sequential(
398
+ nn.Linear(self.emb_out_dim, 1),
399
+ nn.Sigmoid()
400
+ )
401
+ self.lsnr_scale = config.lsnr_max - config.lsnr_min
402
+ self.lsnr_offset = config.lsnr_min
403
+
404
+ def forward(self,
405
+ feat_erb: torch.Tensor,
406
+ feat_spec: torch.Tensor,
407
+ h: torch.Tensor = None,
408
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
409
+ # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands.
410
+ # erb: [B, 1, T, Fe]
411
+ # spec: [B, 2, T, Fc]
412
+ # b, _, t, _ = feat_erb.shape
413
+ e0 = self.erb_conv0(feat_erb) # [B, C, T, F]
414
+ e1 = self.erb_conv1(e0) # [B, C*2, T, F/2]
415
+ e2 = self.erb_conv2(e1) # [B, C*4, T, F/4]
416
+ e3 = self.erb_conv3(e2) # [B, C*4, T, F/4]
417
+ c0 = self.df_conv0(feat_spec) # [B, C, T, Fc]
418
+ c1 = self.df_conv1(c0) # [B, C*2, T, Fc/2]
419
+ cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1]
420
+ cemb = self.df_fc_emb(cemb) # [B, T, C * F/4]
421
+ emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F]
422
+ emb = self.combine(emb, cemb)
423
+ emb, h = self.emb_gru(emb, h) # [B, T, -1]
424
+
425
+ lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
426
+ return e0, e1, e2, e3, emb, c0, lsnr, h
427
+
428
+
429
+ class ErbDecoder(nn.Module):
430
+ def __init__(self,
431
+ config: DfNetConfig,
432
+ ):
433
+ super(ErbDecoder, self).__init__()
434
+ if config.erb_bins % 8 != 0:
435
+ raise AssertionError("erb_bins should be divisible by 8")
436
+
437
+ self.emb_in_dim = config.conv_channels * config.erb_bins // 4
438
+ self.emb_out_dim = config.conv_channels * config.erb_bins // 4
439
+ self.emb_hidden_dim = config.emb_hidden_dim
440
+
441
+ self.emb_gru = SqueezedGRU_S(
442
+ self.emb_in_dim,
443
+ self.emb_hidden_dim,
444
+ output_size=self.emb_out_dim,
445
+ num_layers=config.erb_decoder_emb_num_layers - 1,
446
+ batch_first=True,
447
+ skip_op=config.erb_decoder_gru_skip_op,
448
+ linear_groups=config.erb_decoder_linear_groups,
449
+ activation_layer="relu",
450
+ )
451
+
452
+ # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions
453
+ self.conv3p = CausalConv2d(
454
+ in_channels=config.conv_channels,
455
+ out_channels=config.conv_channels,
456
+ kernel_size=1,
457
+ bias=False,
458
+ separable=True,
459
+ )
460
+ self.convt3 = CausalConv2d(
461
+ in_channels=config.conv_channels,
462
+ out_channels=config.conv_channels,
463
+ kernel_size=config.conv_kernel_size_inner,
464
+ bias=False,
465
+ separable=True,
466
+ )
467
+ self.conv2p = CausalConv2d(
468
+ in_channels=config.conv_channels,
469
+ out_channels=config.conv_channels,
470
+ kernel_size=1,
471
+ bias=False,
472
+ separable=True,
473
+ )
474
+ self.convt2 = CausalConvTranspose2d(
475
+ in_channels=config.conv_channels,
476
+ out_channels=config.conv_channels,
477
+ fstride=2,
478
+ kernel_size=config.convt_kernel_size_inner,
479
+ bias=False,
480
+ separable=True,
481
+ )
482
+ self.conv1p = CausalConv2d(
483
+ in_channels=config.conv_channels,
484
+ out_channels=config.conv_channels,
485
+ kernel_size=1,
486
+ bias=False,
487
+ separable=True,
488
+ )
489
+ self.convt1 = CausalConvTranspose2d(
490
+ in_channels=config.conv_channels,
491
+ out_channels=config.conv_channels,
492
+ fstride=2,
493
+ kernel_size=config.convt_kernel_size_inner,
494
+ bias=False,
495
+ separable=True,
496
+ )
497
+ self.conv0p = CausalConv2d(
498
+ in_channels=config.conv_channels,
499
+ out_channels=config.conv_channels,
500
+ kernel_size=1,
501
+ bias=False,
502
+ separable=True,
503
+ )
504
+ self.conv0_out = CausalConv2d(
505
+ in_channels=config.conv_channels,
506
+ out_channels=1,
507
+ kernel_size=config.conv_kernel_size_inner,
508
+ activation_layer="sigmoid",
509
+ bias=False,
510
+ separable=True,
511
+ )
512
+
513
+ def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
514
+ # Estimates erb mask
515
+ b, _, t, f8 = e3.shape
516
+ emb, _ = self.emb_gru(emb)
517
+ emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8]
518
+ e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4]
519
+ e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2]
520
+ e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F]
521
+ m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F]
522
+ return m
523
+
524
+
525
+ class Mask(nn.Module):
526
+ def __init__(self, erb_inv_fb: torch.FloatTensor, post_filter: bool = False, eps: float = 1e-12):
527
+ super().__init__()
528
+ self.erb_inv_fb: torch.FloatTensor
529
+ self.register_buffer("erb_inv_fb", erb_inv_fb.float())
530
+ self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
531
+ self.post_filter = post_filter
532
+ self.eps = eps
533
+
534
+ def pf(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
535
+ """
536
+ Post-Filter
537
+
538
+ A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
539
+ https://arxiv.org/abs/2008.04259
540
+
541
+ :param mask: Real valued mask, typically of shape [B, C, T, F].
542
+ :param beta: Global gain factor.
543
+ :return:
544
+ """
545
+ mask_sin = mask * torch.sin(np.pi * mask / 2)
546
+ mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
547
+ return mask_pf
548
+
549
+ def forward(self, spec: torch.Tensor, mask: torch.Tensor, atten_lim: Optional[torch.Tensor] = None) -> torch.Tensor:
550
+ # spec (real) [B, 1, T, F, 2], F: freq_bins
551
+ # mask (real): [B, 1, T, Fe], Fe: erb_bins
552
+ # atten_lim: [B]
553
+ if not self.training and self.post_filter:
554
+ mask = self.pf(mask)
555
+ if atten_lim is not None:
556
+ # dB to amplitude
557
+ atten_lim = 10 ** (-atten_lim / 20)
558
+ # Greater equal (__ge__) not implemented for TorchVersion.
559
+ if self.clamp_tensor:
560
+ # Supported by torch >= 1.9
561
+ mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
562
+ else:
563
+ m_out = []
564
+ for i in range(atten_lim.shape[0]):
565
+ m_out.append(mask[i].clamp_min(atten_lim[i].item()))
566
+ mask = torch.stack(m_out, dim=0)
567
+ mask = mask.matmul(self.erb_inv_fb) # [B, 1, T, F]
568
+ if not spec.is_complex():
569
+ mask = mask.unsqueeze(4)
570
+ return spec * mask
571
+
572
+
573
+ class DfDecoder(nn.Module):
574
+ def __init__(self,
575
+ config: DfNetConfig,
576
+ ):
577
+ super().__init__()
578
+ layer_width = config.conv_channels
579
+
580
+ self.emb_in_dim = config.conv_channels * config.erb_bins // 4
581
+ self.emb_dim = config.df_hidden_dim
582
+
583
+ self.df_n_hidden = config.df_hidden_dim
584
+ self.df_n_layers = config.df_num_layers
585
+ self.df_order = config.df_order
586
+ self.df_bins = config.df_bins
587
+ self.df_out_ch = config.df_order * 2
588
+
589
+ self.df_convp = CausalConv2d(
590
+ layer_width,
591
+ self.df_out_ch,
592
+ fstride=1,
593
+ kernel_size=(config.df_pathway_kernel_size_t, 1),
594
+ separable=True,
595
+ bias=False,
596
+ )
597
+ self.df_gru = SqueezedGRU_S(
598
+ self.emb_in_dim,
599
+ self.emb_dim,
600
+ num_layers=self.df_n_layers,
601
+ batch_first=True,
602
+ skip_op="none",
603
+ activation_layer="relu",
604
+ )
605
+
606
+ if config.df_gru_skip == "none":
607
+ self.df_skip = None
608
+ elif config.df_gru_skip == "identity":
609
+ if config.emb_hidden_dim != config.df_hidden_dim:
610
+ raise AssertionError("Dimensions do not match")
611
+ self.df_skip = nn.Identity()
612
+ elif config.df_gru_skip == "grouped_linear":
613
+ self.df_skip = GroupedLinear(self.emb_in_dim, self.emb_dim, groups=config.df_decoder_linear_groups)
614
+ else:
615
+ raise NotImplementedError()
616
+
617
+ self.df_out: nn.Module
618
+ out_dim = self.df_bins * self.df_out_ch
619
+
620
+ self.df_out = nn.Sequential(
621
+ GroupedLinear(
622
+ input_size=self.df_n_hidden,
623
+ hidden_size=out_dim,
624
+ groups=config.df_decoder_linear_groups
625
+ ),
626
+ nn.Tanh()
627
+ )
628
+ self.df_fc_a = nn.Sequential(
629
+ nn.Linear(self.df_n_hidden, 1),
630
+ nn.Sigmoid()
631
+ )
632
+
633
+ def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
634
+ b, t, _ = emb.shape
635
+ c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden
636
+ if self.df_skip is not None:
637
+ c = c + self.df_skip(emb)
638
+ c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last
639
+ c = self.df_out(c) # [B, T, F*O*2], O: df_order
640
+ c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2]
641
+ return c
642
+
643
+
644
+ class DfOutputReshapeMF(nn.Module):
645
+ """Coefficients output reshape for multiframe/MultiFrameModule
646
+
647
+ Requires input of shape B, C, T, F, 2.
648
+ """
649
+
650
+ def __init__(self, df_order: int, df_bins: int):
651
+ super().__init__()
652
+ self.df_order = df_order
653
+ self.df_bins = df_bins
654
+
655
+ def forward(self, coefs: torch.Tensor) -> torch.Tensor:
656
+ # [B, T, F, O*2] -> [B, O, T, F, 2]
657
+ new_shape = list(coefs.shape)
658
+ new_shape[-1] = -1
659
+ new_shape.append(2)
660
+ coefs = coefs.view(new_shape)
661
+ coefs = coefs.permute(0, 3, 1, 2, 4)
662
+ return coefs
663
+
664
+
665
+ class DfNet(nn.Module):
666
+ """
667
+ DeepFilterNet: Perceptually Motivated Real-Time Speech Enhancement
668
+ https://arxiv.org/abs/2305.08227
669
+
670
671
+ """
672
+ def __init__(self,
673
+ config: DfNetConfig,
674
+ erb_fb: torch.FloatTensor,
675
+ erb_inv_fb: torch.FloatTensor,
676
+ run_df: bool = True,
677
+ train_mask: bool = True,
678
+ ):
679
+ """
680
+ :param erb_fb: erb filter bank.
681
+ """
682
+ super(DfNet, self).__init__()
683
+ if config.erb_bins % 8 != 0:
684
+ raise AssertionError("erb_bins should be divisible by 8")
685
+
686
+ self.df_lookahead = config.df_lookahead
687
+ self.df_bins = config.df_bins
688
+ self.freq_bins: int = config.fft_size // 2 + 1
689
+ self.emb_dim: int = config.conv_channels * config.erb_bins
690
+ self.erb_bins: int = config.erb_bins
691
+
692
+ if config.conv_lookahead > 0:
693
+ if config.conv_lookahead < config.df_lookahead:
694
+ raise AssertionError
695
+ # for last 2 dim, pad (left, right, top, bottom).
696
+ self.pad_feat = nn.ConstantPad2d((0, 0, -config.conv_lookahead, config.conv_lookahead), 0.0)
697
+ else:
698
+ self.pad_feat = nn.Identity()
699
+
700
+ if config.df_lookahead > 0:
701
+ # for last 3 dim, pad (left, right, top, bottom, front, back).
702
+ self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -config.df_lookahead, config.df_lookahead), 0.0)
703
+ else:
704
+ self.pad_spec = nn.Identity()
705
+
706
+ self.register_buffer("erb_fb", erb_fb)
707
+
708
+ self.enc = Encoder(config)
709
+ self.erb_dec = ErbDecoder(config)
710
+ self.mask = Mask(erb_inv_fb)
711
+
712
+ self.erb_inv_fb = erb_inv_fb
713
+ self.post_filter = config.mask_post_filter
714
+ self.post_filter_beta = config.post_filter_beta
715
+
716
+ self.df_order = config.df_order
717
+ self.df_op = MF.DF(num_freqs=config.df_bins, frame_size=config.df_order, lookahead=self.df_lookahead)
718
+ self.df_dec = DfDecoder(config)
719
+ self.df_out_transform = DfOutputReshapeMF(self.df_order, config.df_bins)
720
+
721
+ self.run_erb = config.df_bins + 1 < self.freq_bins
722
+ if not self.run_erb:
723
+ logger.warning("Running without ERB stage")
724
+ self.run_df = run_df
725
+ if not run_df:
726
+ logger.warning("Running without DF stage")
727
+ self.train_mask = train_mask
728
+ self.lsnr_dropout = config.lsnr_dropout
729
+ if config.df_n_iter != 1:
730
+ raise AssertionError
731
+
732
+ def forward1(
733
+ self,
734
+ spec: torch.Tensor,
735
+ feat_erb: torch.Tensor,
736
+ feat_spec: torch.Tensor, # Not used, take spec modified by mask instead
737
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
738
+ """Forward method of DeepFilterNet2.
739
+
740
+ Args:
741
+ spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
742
+ feat_erb (Tensor): ERB features of shape [B, 1, T, E]
743
+ feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2]
744
+
745
+ Returns:
746
+ spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
747
+ m (Tensor): ERB mask estimate of shape [B, 1, T, E]
748
+ lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
749
+ """
750
+ # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
751
+ feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
752
+ # feat_spec shape: [batch_size, 2, time_steps, freq_dim]
753
+
754
+ # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
755
+ # assert time_steps >= conv_lookahead.
756
+ feat_erb = self.pad_feat(feat_erb)
757
+ feat_spec = self.pad_feat(feat_spec)
758
+ e0, e1, e2, e3, emb, c0, lsnr, h = self.enc(feat_erb, feat_spec)
759
+
760
+ if self.lsnr_dropout:
761
+ idcs = lsnr.squeeze() > -10.0
762
+ b, t = (spec.shape[0], spec.shape[2])
763
+ m = torch.zeros((b, 1, t, self.erb_bins), device=spec.device)
764
+ df_coefs = torch.zeros((b, t, self.df_bins, self.df_order * 2), device=spec.device)
765
+ spec_m = spec.clone()
766
+ emb = emb[:, idcs]
767
+ e0 = e0[:, :, idcs]
768
+ e1 = e1[:, :, idcs]
769
+ e2 = e2[:, :, idcs]
770
+ e3 = e3[:, :, idcs]
771
+ c0 = c0[:, :, idcs]
772
+
773
+ if self.run_erb:
774
+ if self.lsnr_dropout:
775
+ m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0)
776
+ else:
777
+ m = self.erb_dec(emb, e3, e2, e1, e0)
778
+ spec_m = self.mask(spec, m)
779
+ else:
780
+ m = torch.zeros((), device=spec.device)
781
+ spec_m = torch.zeros_like(spec)
782
+
783
+ if self.run_df:
784
+ if self.lsnr_dropout:
785
+ df_coefs[:, idcs] = self.df_dec(emb, c0)
786
+ else:
787
+ df_coefs = self.df_dec(emb, c0)
788
+ df_coefs = self.df_out_transform(df_coefs)
789
+ spec_e = self.df_op(spec.clone(), df_coefs)
790
+ spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
791
+ else:
792
+ df_coefs = torch.zeros((), device=spec.device)
793
+ spec_e = spec_m
794
+
795
+ if self.post_filter:
796
+ beta = self.post_filter_beta
797
+ eps = 1e-12
798
+ mask = (utils.as_complex(spec_e).abs() / utils.as_complex(spec).abs().add(eps)).clamp(eps, 1)
799
+ mask_sin = mask * torch.sin(PI * mask / 2).clamp_min(eps)
800
+ pf = (1 + beta) / (1 + beta * mask.div(mask_sin).pow(2))
801
+ spec_e = spec_e * pf.unsqueeze(-1)
802
+
803
+ return spec_e, m, lsnr, df_coefs
804
+
805
+ def forward(
806
+ self,
807
+ spec: torch.Tensor,
808
+ feat_erb: torch.Tensor,
809
+ feat_spec: torch.Tensor, # Not used, take spec modified by mask instead
810
+ erb_encoder_h: torch.Tensor = None,
811
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
812
+ # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
813
+ feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
814
+ # feat_spec shape: [batch_size, 2, time_steps, freq_dim]
815
+
816
+ # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
817
+ # assert time_steps >= conv_lookahead.
818
+ feat_erb = self.pad_feat(feat_erb)
819
+ feat_spec = self.pad_feat(feat_spec)
820
+ e0, e1, e2, e3, emb, c0, lsnr, erb_encoder_h = self.enc(feat_erb, feat_spec, erb_encoder_h)
821
+
822
+ m = self.erb_dec(emb, e3, e2, e1, e0)
823
+ spec_m = self.mask(spec, m)
824
+ # spec_e = spec_m
825
+
826
+ df_coefs = self.df_dec(emb, c0)
827
+ df_coefs = self.df_out_transform(df_coefs)
828
+ spec_e = self.df_op(spec.clone(), df_coefs)
829
+ spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
830
+
831
+ return spec_e, m, lsnr, df_coefs, erb_encoder_h
832
+
833
+
834
+ if __name__ == "__main__":
835
+ pass
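
A sketch of the ERB masking stage in isolation (illustrative; the shapes follow the comments in Mask.forward and the inverse filter bank comes from features.py):

import torch
from toolbox.torchaudio.models.dfnet3.features import get_erb_widths, get_erb_filter_bank
from toolbox.torchaudio.models.dfnet3.modeling_dfnet3 import Mask

widths = get_erb_widths(sample_rate=16000, fft_size=320, erb_bins=32, min_freq_bins_for_erb=2)
erb_inv_fb = torch.tensor(get_erb_filter_bank(widths, 16000, inverse=True), dtype=torch.float32)  # [32, 161]

mask_module = Mask(erb_inv_fb)
spec = torch.randn(1, 1, 10, 161, 2)     # real-valued view of a complex spectrogram [B, 1, T, F, 2]
erb_mask = torch.rand(1, 1, 10, 32)      # ERB-domain mask [B, 1, T, E]
enhanced = mask_module(spec, erb_mask)   # mask is expanded back to F bins and applied per bin
print(enhanced.shape)                    # torch.Size([1, 1, 10, 161, 2])
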
toolbox/torchaudio/models/dfnet3/multiframes.py ADDED
@@ -0,0 +1,145 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ # From torchaudio
9
+ def _compute_mat_trace(input: torch.Tensor, dim1: int = -2, dim2: int = -1) -> torch.Tensor:
10
+ r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
11
+ Args:
12
+ input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
13
+ dim1 (int, optional): the first dimension of the diagonal matrix
14
+ (Default: -1)
15
+ dim2 (int, optional): the second dimension of the diagonal matrix
16
+ (Default: -2)
17
+ Returns:
18
+ Tensor: trace of the input Tensor
19
+ """
20
+ assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
21
+ assert (
22
+ input.shape[dim1] == input.shape[dim2]
23
+ ), "The size of ``dim1`` and ``dim2`` must be the same."
24
+ input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
25
+ return input.sum(dim=-1)
26
+
27
+
28
+ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
29
+ """Perform Tikhonov regularization (only modifying real part).
30
+ Args:
31
+ mat (torch.Tensor): input matrix (..., channel, channel)
32
+ reg (float, optional): regularization factor (Default: 1e-8)
33
+ eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``)
34
+ Returns:
35
+ Tensor: regularized matrix (..., channel, channel)
36
+ """
37
+ # Add eps
38
+ C = mat.size(-1)
39
+ eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
40
+ epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
41
+ # in case that correlation_matrix is all-zero
42
+ epsilon = epsilon + eps
43
+ mat = mat + epsilon * eye[..., :, :]
44
+ return mat
45
+
46
+
47
+ class MultiFrameModule(nn.Module):
48
+ """
49
+ Multi-frame speech enhancement modules.
50
+
51
+ Signal model and notation:
52
+ Noisy: `x = s + n`
53
+ Enhanced: `y = f(x)`
54
+ Objective: `min ||s - y||`
55
+
56
+ PSD: Power spectral density, notated e.g. as `Rxx` for the noisy PSD.
57
+ IFC: Inter-frame correlation vector: PSD*u, where u is a selection vector. Notated as `rxx`.
58
+ RTF: Relative transfer function, also called the steering vector.
59
+ """
60
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bool = False):
61
+ """
62
+ Multi-Frame filtering module.
63
+
64
+ :param num_freqs: int. Number of frequency bins used for filtering.
65
+ :param frame_size: int. Frame size in FD domain.
66
+ :param lookahead: int. Lookahead, may be used to select the output time step.
67
+ Note: This module does not add additional padding according to lookahead!
68
+ :param real:
69
+ """
70
+ super().__init__()
71
+ self.num_freqs = num_freqs
72
+ self.frame_size = frame_size
73
+ self.real = real
74
+ if real:
75
+ self.pad = nn.ConstantPad3d((0, 0, 0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
76
+ else:
77
+ self.pad = nn.ConstantPad2d((0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
78
+ self.need_unfold = frame_size > 1
79
+ self.lookahead = lookahead
80
+
81
+ def spec_unfold_real(self, spec: torch.Tensor):
82
+ if self.need_unfold:
83
+ spec = self.pad(spec).unfold(-3, self.frame_size, 1)
84
+ return spec.permute(0, 1, 5, 2, 3, 4)
85
+ # return as_windowed(self.pad(spec), self.frame_size, 1, dim=-3)
86
+ return spec.unsqueeze(-1)
87
+
88
+ def spec_unfold(self, spec: torch.Tensor):
89
+ """Pads and unfolds the spectrogram according to frame_size.
90
+
91
+ Args:
92
+ spec (complex Tensor): Spectrogram of shape [B, C, T, F]
93
+ Returns:
94
+ spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
95
+ """
96
+ if self.need_unfold:
97
+ return self.pad(spec).unfold(2, self.frame_size, 1)
98
+ return spec.unsqueeze(-1)
99
+
100
+ @staticmethod
101
+ def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> torch.Tensor:
102
+ return torch.einsum(
103
+ "...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss
104
+ ) # [T, F, N]
105
+
106
+ @staticmethod
107
+ def apply_coefs(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
108
+ # spec: [B, C, T, F, N]
109
+ # coefs: [B, C, T, F, N]
110
+ return torch.einsum("...n,...n->...", spec, coefs)
111
+
112
+
113
+ class DF(MultiFrameModule):
114
+ """Deep Filtering."""
115
+
116
+ def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
117
+ super().__init__(num_freqs, frame_size, lookahead)
118
+ self.conj: bool = conj
119
+
120
+ def forward(self, spec: torch.Tensor, coefs: torch.Tensor):
121
+ spec_u = self.spec_unfold(torch.view_as_complex(spec))
122
+ coefs = torch.view_as_complex(coefs)
123
+ spec_f = spec_u.narrow(-2, 0, self.num_freqs)
124
+ coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:])
125
+ if self.conj:
126
+ coefs = coefs.conj()
127
+ spec_f = self.df(spec_f, coefs)
128
+ if self.training:
129
+ spec = spec.clone()
130
+ spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f)
131
+ return spec
132
+
133
+ @staticmethod
134
+ def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
135
+ """
136
+ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
137
+ :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
138
+ :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
139
+ :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
140
+ """
141
+ return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
142
+
143
+
144
+ if __name__ == '__main__':
145
+ pass
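
A toy run of the DF module (illustrative; 96 deep-filtered bins and order 5 mirror common DeepFilterNet settings but are arbitrary here):

import torch
from toolbox.torchaudio.models.dfnet3.multiframes import DF

df_op = DF(num_freqs=96, frame_size=5, lookahead=0)
spec = torch.randn(1, 1, 10, 161, 2)   # real view of a complex spectrogram [B, C, T, F, 2]
coefs = torch.randn(1, 5, 10, 96, 2)   # filter coefficients [B, O, T, F', 2], O = frame_size
filtered = df_op(spec, coefs)          # only the first 96 frequency bins are deep-filtered
print(filtered.shape)                  # torch.Size([1, 1, 10, 161, 2])
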
toolbox/torchaudio/models/dfnet3/utils.py ADDED
@@ -0,0 +1,17 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+
6
+ def as_complex(x: torch.Tensor):
7
+ if torch.is_complex(x):
8
+ return x
9
+ if x.shape[-1] != 2:
10
+ raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
11
+ if x.stride(-1) != 1:
12
+ x = x.contiguous()
13
+ return torch.view_as_complex(x)
14
+
15
+
16
+ if __name__ == '__main__':
17
+ pass
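
A quick check of as_complex (illustrative): a trailing dimension of length 2 is interpreted as (real, imaginary):

import torch
from toolbox.torchaudio.models.dfnet3.utils import as_complex

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # shape [2, 2] -> two complex numbers
print(as_complex(x))                        # tensor([1.+2.j, 3.+4.j])
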
toolbox/torchaudio/models/ehnet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/ehnet/modeling_ehnet.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://arxiv.org/abs/1805.00579
5
+
6
+ https://github.com/haoxiangsnr/A-Convolutional-Recurrent-Neural-Network-for-Real-Time-Speech-Enhancement
7
+
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class CausalConvBlock(nn.Module):
14
+ def __init__(self, in_channels, out_channels):
15
+ super().__init__()
16
+ self.conv = nn.Conv2d(
17
+ in_channels=in_channels,
18
+ out_channels=out_channels,
19
+ kernel_size=(3, 2),
20
+ stride=(2, 1),
21
+ padding=(0, 1)
22
+ )
23
+ self.norm = nn.BatchNorm2d(num_features=out_channels)
24
+ self.activation = nn.ELU()
25
+
26
+ def forward(self, x):
27
+ """
28
+ 2D Causal convolution.
29
+ Args:
30
+ x: [B, C, F, T]
31
+
32
+ Returns:
33
+ [B, C, F, T]
34
+ """
35
+ x = self.conv(x)
36
+ x = x[:, :, :, :-1] # chomp size
37
+ x = self.norm(x)
38
+ x = self.activation(x)
39
+ return x
40
+
41
+
42
+ class CausalTransConvBlock(nn.Module):
43
+ def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)):
44
+ super().__init__()
45
+ self.conv = nn.ConvTranspose2d(
46
+ in_channels=in_channels,
47
+ out_channels=out_channels,
48
+ kernel_size=(3, 2),
49
+ stride=(2, 1),
50
+ output_padding=output_padding
51
+ )
52
+ self.norm = nn.BatchNorm2d(num_features=out_channels)
53
+ if is_last:
54
+ self.activation = nn.ReLU()
55
+ else:
56
+ self.activation = nn.ELU()
57
+
58
+ def forward(self, x):
59
+ """
60
+ 2D causal transposed convolution.
61
+ Args:
62
+ x: [B, C, F, T]
63
+
64
+ Returns:
65
+ [B, C, F, T]
66
+ """
67
+ x = self.conv(x)
68
+ x = x[:, :, :, :-1] # chomp size
69
+ x = self.norm(x)
70
+ x = self.activation(x)
71
+ return x
72
+
73
+
74
+
75
+ class CRN(nn.Module):
76
+ """
77
+ Input: [batch size, channels=1, T, n_fft]
78
+ Output: [batch size, T, n_fft]
79
+ """
80
+
81
+ def __init__(self):
82
+ super(CRN, self).__init__()
83
+ # Encoder
84
+ self.conv_block_1 = CausalConvBlock(1, 16)
85
+ self.conv_block_2 = CausalConvBlock(16, 32)
86
+ self.conv_block_3 = CausalConvBlock(32, 64)
87
+ self.conv_block_4 = CausalConvBlock(64, 128)
88
+ self.conv_block_5 = CausalConvBlock(128, 256)
89
+
90
+ # LSTM
91
+ self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True)
92
+
93
+ self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128)
94
+ self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64)
95
+ self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32)
96
+ self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0))
97
+ self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True)
98
+
99
+ def forward(self, x):
100
+ self.lstm_layer.flatten_parameters()
101
+
102
+ e_1 = self.conv_block_1(x)
103
+ e_2 = self.conv_block_2(e_1)
104
+ e_3 = self.conv_block_3(e_2)
105
+ e_4 = self.conv_block_4(e_3)
106
+ e_5 = self.conv_block_5(e_4) # [2, 256, 4, 200]
107
+
108
+ batch_size, n_channels, n_f_bins, n_frame_size = e_5.shape
109
+
110
+ # [2, 256, 4, 200] = [2, 1024, 200] => [2, 200, 1024]
111
+ lstm_in = e_5.reshape(batch_size, n_channels * n_f_bins, n_frame_size).permute(0, 2, 1)
112
+ lstm_out, _ = self.lstm_layer(lstm_in) # [2, 200, 1024]
113
+ lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size) # [2, 256, 4, 200]
114
+
115
+ d_1 = self.tran_conv_block_1(torch.cat((lstm_out, e_5), 1))
116
+ d_2 = self.tran_conv_block_2(torch.cat((d_1, e_4), 1))
117
+ d_3 = self.tran_conv_block_3(torch.cat((d_2, e_3), 1))
118
+ d_4 = self.tran_conv_block_4(torch.cat((d_3, e_2), 1))
119
+ d_5 = self.tran_conv_block_5(torch.cat((d_4, e_1), 1))
120
+
121
+ return d_5
122
+
123
+
124
+ def main():
125
+ layer = CRN()
126
+ a = torch.rand(2, 1, 161, 200)
127
+ print(layer(a).shape)
128
+ return
129
+
130
+
131
+ if __name__ == '__main__':
132
+ main()
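
A shape walk-through for CRN (illustrative; assumes 161 input frequency bins as in main above): five kernel-3 / stride-2 encoder blocks reduce the frequency axis 161 -> 80 -> 39 -> 19 -> 9 -> 4, which is why the LSTM input size is 256 * 4 = 1024.

freq = 161
for i in range(5):
    freq = (freq - 3) // 2 + 1  # Conv2d output size along the frequency axis (kernel 3, stride 2, no padding)
    print(f"after conv_block_{i + 1}: {freq} bins")
# after conv_block_5: 4 bins -> LSTM input_size = 256 * 4 = 1024
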
toolbox/torchaudio/models/percepnet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/percepnet/modeling_percetnet.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/jzi040941/PercepNet
5
+
6
+ https://arxiv.org/abs/2008.04259
7
+ """
8
+
9
+
10
+ if __name__ == '__main__':
11
+ pass