Update PLCMOS/plc_mos.py
Browse files- PLCMOS/plc_mos.py +0 -93
PLCMOS/plc_mos.py
CHANGED
@@ -7,9 +7,6 @@ import onnxruntime as ort
|
|
7 |
from numpy.fft import rfft
|
8 |
from numpy.lib.stride_tricks import as_strided
|
9 |
|
10 |
-
from utils.utils import LSD
|
11 |
-
|
12 |
-
|
13 |
class PLCMOSEstimator():
|
14 |
def __init__(self, model_version=1):
|
15 |
"""
|
@@ -155,93 +152,3 @@ class PLCMOSEstimator():
|
|
155 |
mos_2 = float(session.run(None, onnx_inputs)[0])
|
156 |
mos = [mos, mos_2]
|
157 |
return mos
|
158 |
-
|
159 |
-
|
160 |
-
def run_with_defaults(degraded, clean, allow_set_size_difference=False, progress=False, model_ver=1):
|
161 |
-
import soundfile as sf
|
162 |
-
import glob
|
163 |
-
import tqdm
|
164 |
-
import pandas as pd
|
165 |
-
|
166 |
-
if os.path.isfile(degraded):
|
167 |
-
degraded = [degraded]
|
168 |
-
else:
|
169 |
-
degraded = list(glob.glob(os.path.join(degraded, "*.wav")))
|
170 |
-
|
171 |
-
if os.path.isfile(clean):
|
172 |
-
clean = [clean] * len(degraded)
|
173 |
-
else:
|
174 |
-
clean = list(glob.glob(os.path.join(clean, "*.wav")))
|
175 |
-
|
176 |
-
degraded = list(sorted(degraded))
|
177 |
-
clean = list(sorted(clean))
|
178 |
-
|
179 |
-
if not allow_set_size_difference:
|
180 |
-
assert len(degraded) == len(clean)
|
181 |
-
|
182 |
-
clean_dict = {os.path.basename(x): x for x in clean}
|
183 |
-
clean = []
|
184 |
-
for degraded_name in degraded:
|
185 |
-
clean.append(clean_dict[os.path.basename(degraded_name)])
|
186 |
-
assert len(degraded) == len(clean)
|
187 |
-
|
188 |
-
iter = zip(degraded, clean)
|
189 |
-
if progress:
|
190 |
-
iter = tqdm.tqdm(iter, total=len(degraded))
|
191 |
-
results = []
|
192 |
-
|
193 |
-
estimator = PLCMOSEstimator(model_version=model_ver)
|
194 |
-
intr = []
|
195 |
-
nonintr = []
|
196 |
-
lsds = []
|
197 |
-
sisdrs = []
|
198 |
-
for degraded_name, clean_name in iter:
|
199 |
-
audio_degraded, sr_degraded = sf.read(degraded_name)
|
200 |
-
audio_clean, sr_clean = sf.read(clean_name)
|
201 |
-
lsd = LSD(audio_clean, audio_degraded)
|
202 |
-
audio_degraded = librosa.resample(audio_degraded, 48000, 16000, res_type='kaiser_fast')
|
203 |
-
audio_clean = librosa.resample(audio_clean, 48000, 16000, res_type='kaiser_fast')
|
204 |
-
|
205 |
-
score = estimator.run(audio_degraded, audio_clean)
|
206 |
-
results.append(
|
207 |
-
{
|
208 |
-
"filename_degraded": degraded_name,
|
209 |
-
"filename_clean": clean_name,
|
210 |
-
"intrusive" + str(model_ver): score[0],
|
211 |
-
"non-intrusive" + str(model_ver): score[1],
|
212 |
-
|
213 |
-
}
|
214 |
-
)
|
215 |
-
lsds.append(lsd)
|
216 |
-
intr.append(score[0])
|
217 |
-
nonintr.append(score[1])
|
218 |
-
iter.set_description("Intru {}, Non-Intr {}, LSD {}, SISDR {}".format(sum(intr) / len(intr),
|
219 |
-
sum(nonintr) / len(nonintr),
|
220 |
-
sum(lsds) / len(lsds),
|
221 |
-
sum(sisdrs) / len(sisdrs)))
|
222 |
-
|
223 |
-
return pd.DataFrame(results)
|
224 |
-
|
225 |
-
|
226 |
-
if __name__ == "__main__":
|
227 |
-
import argparse
|
228 |
-
|
229 |
-
parser = argparse.ArgumentParser()
|
230 |
-
parser.add_argument("--degraded", type=str, required=True, help="Path to folder with degraded audio files")
|
231 |
-
parser.add_argument("--clean", type=str, required=True, help="Path to folder with clean audio files")
|
232 |
-
parser.add_argument("--model-ver", type=int, default=1, help="Model version to use")
|
233 |
-
parser.add_argument("--out-csv", type=str, default=None, help="Path to output CSV file, if CSV output is desired")
|
234 |
-
parser.add_argument("--allow-set-size-difference", type=bool, default=True,
|
235 |
-
help="Set to true to allow the number of degraded and clean audio files to be different")
|
236 |
-
args = parser.parse_args()
|
237 |
-
|
238 |
-
results = run_with_defaults(args.degraded, args.clean, args.allow_set_size_difference, True, args.model_ver)
|
239 |
-
|
240 |
-
if args.out_csv is not None:
|
241 |
-
results.to_csv(args.out_csv)
|
242 |
-
else:
|
243 |
-
import pandas as pd
|
244 |
-
|
245 |
-
pd.set_option("display.max_rows", None)
|
246 |
-
# print(results)
|
247 |
-
print("")
|
|
|
7 |
from numpy.fft import rfft
|
8 |
from numpy.lib.stride_tricks import as_strided
|
9 |
|
|
|
|
|
|
|
10 |
class PLCMOSEstimator():
|
11 |
def __init__(self, model_version=1):
|
12 |
"""
|
|
|
152 |
mos_2 = float(session.run(None, onnx_inputs)[0])
|
153 |
mos = [mos, mos_2]
|
154 |
return mos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|