HoneyTian commited on
Commit
e2f2829
·
1 Parent(s): 7f9c54f
examples/mpnet_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -44,8 +44,6 @@ def get_args():
44
 
45
  parser.add_argument("--max_epochs", default=100, type=int)
46
 
47
- parser.add_argument("--batch_size", default=64, type=int)
48
- parser.add_argument("--learning_rate", default=1e-4, type=float)
49
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
50
  parser.add_argument("--patience", default=5, type=int)
51
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
@@ -142,7 +140,7 @@ def main():
142
  )
143
  train_data_loader = DataLoader(
144
  dataset=train_dataset,
145
- batch_size=args.batch_size,
146
  shuffle=True,
147
  sampler=None,
148
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
@@ -153,7 +151,7 @@ def main():
153
  )
154
  valid_data_loader = DataLoader(
155
  dataset=valid_dataset,
156
- batch_size=args.batch_size,
157
  shuffle=True,
158
  sampler=None,
159
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
 
44
 
45
  parser.add_argument("--max_epochs", default=100, type=int)
46
 
 
 
47
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
48
  parser.add_argument("--patience", default=5, type=int)
49
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
140
  )
141
  train_data_loader = DataLoader(
142
  dataset=train_dataset,
143
+ batch_size=config.batch_size,
144
  shuffle=True,
145
  sampler=None,
146
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
 
151
  )
152
  valid_data_loader = DataLoader(
153
  dataset=valid_dataset,
154
+ batch_size=config.batch_size,
155
  shuffle=True,
156
  sampler=None,
157
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
requirements-python-3-9-9.txt CHANGED
@@ -8,7 +8,7 @@ openpyxl==3.1.5
8
  torch==2.5.1
9
  torchaudio==2.5.1
10
  overrides==7.7.0
11
- torch-pesq
12
- torchmetrics
13
- torchmetrics[audio]
14
- einops
 
8
  torch==2.5.1
9
  torchaudio==2.5.1
10
  overrides==7.7.0
11
+ torch-pesq==0.1.2
12
+ torchmetrics==1.6.1
13
+ torchmetrics[audio]==1.6.1
14
+ einops==0.8.1