Spaces:
Running
Running
update
Browse files
examples/mpnet_aishell/run.sh
CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
-
sh run.sh --stage
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
|
16 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
|
examples/mpnet_aishell/step_2_train_model.py
CHANGED
@@ -44,8 +44,6 @@ def get_args():
|
|
44 |
|
45 |
parser.add_argument("--max_epochs", default=100, type=int)
|
46 |
|
47 |
-
parser.add_argument("--batch_size", default=64, type=int)
|
48 |
-
parser.add_argument("--learning_rate", default=1e-4, type=float)
|
49 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
50 |
parser.add_argument("--patience", default=5, type=int)
|
51 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
@@ -142,7 +140,7 @@ def main():
|
|
142 |
)
|
143 |
train_data_loader = DataLoader(
|
144 |
dataset=train_dataset,
|
145 |
-
batch_size=
|
146 |
shuffle=True,
|
147 |
sampler=None,
|
148 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
@@ -153,7 +151,7 @@ def main():
|
|
153 |
)
|
154 |
valid_data_loader = DataLoader(
|
155 |
dataset=valid_dataset,
|
156 |
-
batch_size=
|
157 |
shuffle=True,
|
158 |
sampler=None,
|
159 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
|
|
44 |
|
45 |
parser.add_argument("--max_epochs", default=100, type=int)
|
46 |
|
|
|
|
|
47 |
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
48 |
parser.add_argument("--patience", default=5, type=int)
|
49 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
|
|
140 |
)
|
141 |
train_data_loader = DataLoader(
|
142 |
dataset=train_dataset,
|
143 |
+
batch_size=config.batch_size,
|
144 |
shuffle=True,
|
145 |
sampler=None,
|
146 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
|
|
151 |
)
|
152 |
valid_data_loader = DataLoader(
|
153 |
dataset=valid_dataset,
|
154 |
+
batch_size=config.batch_size,
|
155 |
shuffle=True,
|
156 |
sampler=None,
|
157 |
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
requirements-python-3-9-9.txt
CHANGED
@@ -8,7 +8,7 @@ openpyxl==3.1.5
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
-
torch-pesq
|
12 |
-
torchmetrics
|
13 |
-
torchmetrics[audio]
|
14 |
-
einops
|
|
|
8 |
torch==2.5.1
|
9 |
torchaudio==2.5.1
|
10 |
overrides==7.7.0
|
11 |
+
torch-pesq==0.1.2
|
12 |
+
torchmetrics==1.6.1
|
13 |
+
torchmetrics[audio]==1.6.1
|
14 |
+
einops==0.8.1
|