HoneyTian commited on
Commit
19b9289
·
1 Parent(s): 5458faa
Files changed (4) hide show
  1. examples/mpnet_aishell/run.sh +2 -1
  2. log.py +229 -0
  3. main.py +24 -2
  4. project_settings.py +2 -2
examples/mpnet_aishell/run.sh CHANGED
@@ -19,7 +19,8 @@ sh run.sh --stage 5 --stop_stage 5 --system_version centos --file_folder_name fi
19
 
20
  sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
21
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
22
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
 
23
 
24
 
25
  END
 
19
 
20
  sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
21
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
22
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
23
+ --max_epochs 1
24
 
25
 
26
  END
log.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from logging.handlers import RotatingFileHandler, TimedRotatingFileHandler
5
+ import os
6
+
7
+
8
+ def setup_size_rotating(log_directory: str):
9
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
10
+
11
+ stream_handler = logging.StreamHandler()
12
+ stream_handler.setLevel(logging.INFO)
13
+ stream_handler.setFormatter(logging.Formatter(fmt))
14
+
15
+ # main
16
+ main_logger = logging.getLogger("main")
17
+ main_logger.addHandler(stream_handler)
18
+ main_info_file_handler = RotatingFileHandler(
19
+ filename=os.path.join(log_directory, "main.log"),
20
+ maxBytes=100*1024*1024, # 100MB
21
+ encoding="utf-8",
22
+ backupCount=2,
23
+ )
24
+ main_info_file_handler.setLevel(logging.INFO)
25
+ main_info_file_handler.setFormatter(logging.Formatter(fmt))
26
+ main_logger.addHandler(main_info_file_handler)
27
+
28
+ # http
29
+ http_logger = logging.getLogger("http")
30
+ http_file_handler = RotatingFileHandler(
31
+ filename=os.path.join(log_directory, "http.log"),
32
+ maxBytes=100*1024*1024, # 100MB
33
+ encoding="utf-8",
34
+ backupCount=2,
35
+ )
36
+ http_file_handler.setLevel(logging.DEBUG)
37
+ http_file_handler.setFormatter(logging.Formatter(fmt))
38
+ http_logger.addHandler(http_file_handler)
39
+
40
+ # api
41
+ api_logger = logging.getLogger("api")
42
+ api_file_handler = RotatingFileHandler(
43
+ filename=os.path.join(log_directory, "api.log"),
44
+ maxBytes=10*1024*1024, # 10MB
45
+ encoding="utf-8",
46
+ backupCount=2,
47
+ )
48
+ api_file_handler.setLevel(logging.DEBUG)
49
+ api_file_handler.setFormatter(logging.Formatter(fmt))
50
+ api_logger.addHandler(api_file_handler)
51
+
52
+ # toolbox
53
+ toolbox_logger = logging.getLogger("toolbox")
54
+ toolbox_logger.addHandler(stream_handler)
55
+ toolbox_file_handler = RotatingFileHandler(
56
+ filename=os.path.join(log_directory, "toolbox.log"),
57
+ maxBytes=10*1024*1024, # 10MB
58
+ encoding="utf-8",
59
+ backupCount=2,
60
+ )
61
+ toolbox_file_handler.setLevel(logging.DEBUG)
62
+ toolbox_file_handler.setFormatter(logging.Formatter(fmt))
63
+ toolbox_logger.addHandler(toolbox_file_handler)
64
+
65
+ # alarm
66
+ alarm_logger = logging.getLogger("alarm")
67
+ alarm_file_handler = RotatingFileHandler(
68
+ filename=os.path.join(log_directory, "alarm.log"),
69
+ maxBytes=1*1024*1024, # 1MB
70
+ encoding="utf-8",
71
+ backupCount=2,
72
+ )
73
+ alarm_file_handler.setLevel(logging.DEBUG)
74
+ alarm_file_handler.setFormatter(logging.Formatter(fmt))
75
+ alarm_logger.addHandler(alarm_file_handler)
76
+
77
+ debug_file_handler = RotatingFileHandler(
78
+ filename=os.path.join(log_directory, "debug.log"),
79
+ maxBytes=1*1024*1024, # 1MB
80
+ encoding="utf-8",
81
+ backupCount=2,
82
+ )
83
+ debug_file_handler.setLevel(logging.DEBUG)
84
+ debug_file_handler.setFormatter(logging.Formatter(fmt))
85
+
86
+ info_file_handler = RotatingFileHandler(
87
+ filename=os.path.join(log_directory, "info.log"),
88
+ maxBytes=1*1024*1024, # 1MB
89
+ encoding="utf-8",
90
+ backupCount=2,
91
+ )
92
+ info_file_handler.setLevel(logging.INFO)
93
+ info_file_handler.setFormatter(logging.Formatter(fmt))
94
+
95
+ error_file_handler = RotatingFileHandler(
96
+ filename=os.path.join(log_directory, "error.log"),
97
+ maxBytes=1*1024*1024, # 1MB
98
+ encoding="utf-8",
99
+ backupCount=2,
100
+ )
101
+ error_file_handler.setLevel(logging.ERROR)
102
+ error_file_handler.setFormatter(logging.Formatter(fmt))
103
+
104
+ logging.basicConfig(
105
+ level=logging.DEBUG,
106
+ datefmt="%a, %d %b %Y %H:%M:%S",
107
+ handlers=[
108
+ debug_file_handler,
109
+ info_file_handler,
110
+ error_file_handler,
111
+ ]
112
+ )
113
+
114
+
115
+ def setup_time_rotating(log_directory: str):
116
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
117
+
118
+ stream_handler = logging.StreamHandler()
119
+ stream_handler.setLevel(logging.INFO)
120
+ stream_handler.setFormatter(logging.Formatter(fmt))
121
+
122
+ # main
123
+ main_logger = logging.getLogger("main")
124
+ main_logger.addHandler(stream_handler)
125
+ main_info_file_handler = TimedRotatingFileHandler(
126
+ filename=os.path.join(log_directory, "main.log"),
127
+ encoding="utf-8",
128
+ when="midnight",
129
+ interval=1,
130
+ backupCount=7
131
+ )
132
+ main_info_file_handler.setLevel(logging.INFO)
133
+ main_info_file_handler.setFormatter(logging.Formatter(fmt))
134
+ main_logger.addHandler(main_info_file_handler)
135
+
136
+ # http
137
+ http_logger = logging.getLogger("http")
138
+ http_file_handler = TimedRotatingFileHandler(
139
+ filename=os.path.join(log_directory, "http.log"),
140
+ encoding='utf-8',
141
+ when="midnight",
142
+ interval=1,
143
+ backupCount=7
144
+ )
145
+ http_file_handler.setLevel(logging.DEBUG)
146
+ http_file_handler.setFormatter(logging.Formatter(fmt))
147
+ http_logger.addHandler(http_file_handler)
148
+
149
+ # api
150
+ api_logger = logging.getLogger("api")
151
+ api_file_handler = TimedRotatingFileHandler(
152
+ filename=os.path.join(log_directory, "api.log"),
153
+ encoding='utf-8',
154
+ when="midnight",
155
+ interval=1,
156
+ backupCount=7
157
+ )
158
+ api_file_handler.setLevel(logging.DEBUG)
159
+ api_file_handler.setFormatter(logging.Formatter(fmt))
160
+ api_logger.addHandler(api_file_handler)
161
+
162
+ # toolbox
163
+ toolbox_logger = logging.getLogger("toolbox")
164
+ toolbox_file_handler = RotatingFileHandler(
165
+ filename=os.path.join(log_directory, "toolbox.log"),
166
+ maxBytes=10*1024*1024, # 10MB
167
+ encoding="utf-8",
168
+ backupCount=2,
169
+ )
170
+ toolbox_file_handler.setLevel(logging.DEBUG)
171
+ toolbox_file_handler.setFormatter(logging.Formatter(fmt))
172
+ toolbox_logger.addHandler(toolbox_file_handler)
173
+
174
+ # alarm
175
+ alarm_logger = logging.getLogger("alarm")
176
+ alarm_file_handler = TimedRotatingFileHandler(
177
+ filename=os.path.join(log_directory, "alarm.log"),
178
+ encoding="utf-8",
179
+ when="midnight",
180
+ interval=1,
181
+ backupCount=7
182
+ )
183
+ alarm_file_handler.setLevel(logging.DEBUG)
184
+ alarm_file_handler.setFormatter(logging.Formatter(fmt))
185
+ alarm_logger.addHandler(alarm_file_handler)
186
+
187
+ debug_file_handler = TimedRotatingFileHandler(
188
+ filename=os.path.join(log_directory, "debug.log"),
189
+ encoding="utf-8",
190
+ when="D",
191
+ interval=1,
192
+ backupCount=7
193
+ )
194
+ debug_file_handler.setLevel(logging.DEBUG)
195
+ debug_file_handler.setFormatter(logging.Formatter(fmt))
196
+
197
+ info_file_handler = TimedRotatingFileHandler(
198
+ filename=os.path.join(log_directory, "info.log"),
199
+ encoding="utf-8",
200
+ when="D",
201
+ interval=1,
202
+ backupCount=7
203
+ )
204
+ info_file_handler.setLevel(logging.INFO)
205
+ info_file_handler.setFormatter(logging.Formatter(fmt))
206
+
207
+ error_file_handler = TimedRotatingFileHandler(
208
+ filename=os.path.join(log_directory, "error.log"),
209
+ encoding="utf-8",
210
+ when="D",
211
+ interval=1,
212
+ backupCount=7
213
+ )
214
+ error_file_handler.setLevel(logging.ERROR)
215
+ error_file_handler.setFormatter(logging.Formatter(fmt))
216
+
217
+ logging.basicConfig(
218
+ level=logging.DEBUG,
219
+ datefmt="%a, %d %b %Y %H:%M:%S",
220
+ handlers=[
221
+ debug_file_handler,
222
+ info_file_handler,
223
+ error_file_handler,
224
+ ]
225
+ )
226
+
227
+
228
+ if __name__ == "__main__":
229
+ pass
main.py CHANGED
@@ -1,6 +1,7 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
 
4
  from pathlib import Path
5
  import platform
6
  import shutil
@@ -9,11 +10,16 @@ import zipfile
9
  import gradio as gr
10
  from huggingface_hub import snapshot_download
11
  import numpy as np
12
- import torch
13
 
14
- from project_settings import environment, project_path
 
 
15
  from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
16
 
 
 
 
 
17
 
18
  def get_args():
19
  parser = argparse.ArgumentParser()
@@ -48,11 +54,16 @@ def get_args():
48
  return args
49
 
50
 
 
 
 
 
51
  denoise_engines = dict()
52
 
53
 
54
  def when_click_denoise_button(noisy_audio_t, engine: str):
55
  sample_rate, signal = noisy_audio_t
 
56
 
57
  noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
58
 
@@ -143,6 +154,17 @@ def main():
143
  cache_mode="lazy",
144
  )
145
 
 
 
 
 
 
 
 
 
 
 
 
146
  # http://127.0.0.1:7864/
147
  blocks.queue().launch(
148
  share=False if platform.system() == "Windows" else False,
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
+ import logging
5
  from pathlib import Path
6
  import platform
7
  import shutil
 
10
  import gradio as gr
11
  from huggingface_hub import snapshot_download
12
  import numpy as np
 
13
 
14
+ import log
15
+ from project_settings import environment, project_path, log_directory
16
+ from toolbox.os.command import Command
17
  from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
18
 
19
+ log.setup_size_rotating(log_directory=log_directory)
20
+
21
+ logger = logging.getLogger("main")
22
+
23
 
24
  def get_args():
25
  parser = argparse.ArgumentParser()
 
54
  return args
55
 
56
 
57
+ def shell(cmd: str):
58
+ return Command.popen(cmd)
59
+
60
+
61
  denoise_engines = dict()
62
 
63
 
64
  def when_click_denoise_button(noisy_audio_t, engine: str):
65
  sample_rate, signal = noisy_audio_t
66
+ logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
67
 
68
  noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
69
 
 
154
  cache_mode="lazy",
155
  )
156
 
157
+ with gr.TabItem("shell"):
158
+ shell_text = gr.Textbox(label="cmd")
159
+ shell_button = gr.Button("run")
160
+ shell_output = gr.Textbox(label="output")
161
+
162
+ shell_button.click(
163
+ shell,
164
+ inputs=[shell_text,],
165
+ outputs=[shell_output],
166
+ )
167
+
168
  # http://127.0.0.1:7864/
169
  blocks.queue().launch(
170
  share=False if platform.system() == "Windows" else False,
project_settings.py CHANGED
@@ -12,8 +12,8 @@ project_path = Path(project_path)
12
  log_directory = project_path / "logs"
13
  log_directory.mkdir(parents=True, exist_ok=True)
14
 
15
- temp_directory = project_path / "temp"
16
- temp_directory.mkdir(parents=True, exist_ok=True)
17
 
18
  environment = EnvironmentManager(
19
  path=os.path.join(project_path, "dotenv"),
 
12
  log_directory = project_path / "logs"
13
  log_directory.mkdir(parents=True, exist_ok=True)
14
 
15
+ # temp_directory = project_path / "temp"
16
+ # temp_directory.mkdir(parents=True, exist_ok=True)
17
 
18
  environment = EnvironmentManager(
19
  path=os.path.join(project_path, "dotenv"),