"""Utility helpers for UTMOSv2: data loading, dataset/model factories,
evaluation metrics, and prediction I/O."""

from __future__ import annotations

import glob
import json
import os
from collections.abc import Callable
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.stats
import torch
import torch.nn as nn

from utmosv2.dataset import (
    MultiSpecDataset,
    MultiSpecExtDataset,
    SSLDataset,
    SSLExtDataset,
    SSLLMultiSpecExtDataset,
)
from utmosv2.model import (
    MultiSpecExtModel,
    MultiSpecModelV2,
    SSLExtModel,
    SSLMultiSpecExtModelV1,
    SSLMultiSpecExtModelV2,
)
from utmosv2.preprocess import add_sys_mean, preprocess, preprocess_test

def get_data(cfg) -> pd.DataFrame:
    """Load the train/val/test MOS lists and attach the wav path for each utterance."""
    train_mos_list = pd.read_csv(cfg.input_dir / "sets/train_mos_list.txt", header=None)
    val_mos_list = pd.read_csv(cfg.input_dir / "sets/val_mos_list.txt", header=None)
    test_mos_list = pd.read_csv(cfg.input_dir / "sets/test_mos_list.txt", header=None)
    data = pd.concat([train_mos_list, val_mos_list, test_mos_list], axis=0)
    data.columns = ["utt_id", "mos"]
    data["file_path"] = data["utt_id"].apply(lambda x: cfg.input_dir / f"wav/{x}")
    return data

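# Illustrative sketch (not part of the library): `get_data` expects a BVCC-style
# layout like the one below, where each mos-list line is "<utt_id>,<mos>".
# The directory name "data/main" is a hypothetical example.
#
#   data/main/
#     sets/train_mos_list.txt   # e.g. "sys0001-utt0001.wav,3.5"
#     sets/val_mos_list.txt
#     sets/test_mos_list.txt
#     wav/sys0001-utt0001.wav
#
#   >>> from types import SimpleNamespace
#   >>> cfg = SimpleNamespace(input_dir=Path("data/main"))  # hypothetical config
#   >>> get_data(cfg).columns.tolist()
#   ['utt_id', 'mos', 'file_path']
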
def get_dataset(cfg, data: pd.DataFrame, phase: str) -> torch.utils.data.Dataset:
    """Instantiate the dataset selected by `cfg.dataset.name`."""
    if cfg.print_config:
        print(f"Using dataset: {cfg.dataset.name}")
    if cfg.dataset.name == "multi_spec":
        res = MultiSpecDataset(cfg, data, phase, cfg.transform)
    elif cfg.dataset.name == "ssl":
        res = SSLDataset(cfg, data, phase)
    elif cfg.dataset.name == "sslext":
        res = SSLExtDataset(cfg, data, phase)
    elif cfg.dataset.name == "ssl_multispec_ext":
        res = SSLLMultiSpecExtDataset(cfg, data, phase, cfg.transform)
    elif cfg.dataset.name == "multi_spec_ext":
        res = MultiSpecExtDataset(cfg, data, phase, cfg.transform)
    else:
        raise NotImplementedError(f"Unknown dataset: {cfg.dataset.name}")
    return res

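# Illustrative sketch: the factory pairs with a standard DataLoader. The config
# values shown (`cfg.batch_size` and the pre-built `train_data` frame) are
# hypothetical.
#
#   >>> train_dataset = get_dataset(cfg, train_data, "train")
#   >>> loader = torch.utils.data.DataLoader(
#   ...     train_dataset, batch_size=cfg.batch_size, shuffle=True
#   ... )
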
def get_model(cfg, device: torch.device) -> nn.Module:
    """Instantiate the model selected by `cfg.model.name` and optionally load weights."""
    if cfg.print_config:
        print(f"Using model: {cfg.model.name}")
    if cfg.model.name == "multi_specv2":
        model = MultiSpecModelV2(cfg)
    elif cfg.model.name == "sslext":
        model = SSLExtModel(cfg)
    elif cfg.model.name == "multi_spec_ext":
        model = MultiSpecExtModel(cfg)
    elif cfg.model.name == "ssl_multispec_ext":
        model = SSLMultiSpecExtModelV1(cfg)
    elif cfg.model.name == "ssl_multispec_ext_v2":
        model = SSLMultiSpecExtModelV2(cfg)
    else:
        raise NotImplementedError(f"Unknown model: {cfg.model.name}")
    model = model.to(device)
    if cfg.weight is not None:
        # map_location keeps CUDA-saved checkpoints loadable on CPU-only hosts.
        model.load_state_dict(torch.load(cfg.weight, map_location=device))
    return model

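# Illustrative sketch: device selection plus weight loading. The checkpoint
# path is a hypothetical example.
#
#   >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   >>> cfg.weight = "models/fold0_best_model.pth"  # hypothetical path
#   >>> model = get_model(cfg, device)
#   >>> model.eval()  # switch to inference mode before predicting
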
def get_metrics() -> dict[str, Callable[[np.ndarray, np.ndarray], float]]:
    """Return named metric functions: MSE, Pearson LCC, Spearman SRCC, Kendall tau."""
    return {
        "mse": lambda x, y: np.mean((x - y) ** 2),
        "lcc": lambda x, y: np.corrcoef(x, y)[0][1],
        "srcc": lambda x, y: scipy.stats.spearmanr(x, y)[0],
        "ktau": lambda x, y: scipy.stats.kendalltau(x, y)[0],
    }

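# Illustrative sketch: each metric maps two aligned score arrays to a float.
#
#   >>> metrics = get_metrics()
#   >>> y_true = np.array([3.0, 4.0, 2.5, 5.0])
#   >>> y_pred = np.array([2.8, 4.2, 2.9, 4.6])
#   >>> round(float(metrics["mse"](y_true, y_pred)), 3)
#   0.1
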
def calc_metrics(data: pd.DataFrame, preds: np.ndarray) -> dict[str, float]:
    """Compute utterance-level and system-level metrics for the predictions."""
    data = data.copy()
    data["preds"] = preds
    # System-level scores average MOS and predictions over each system first.
    data_sys = data.groupby("sys_id", as_index=False)[["mos", "preds"]].mean()
    res = {}
    for name, d in {"utt": data, "sys": data_sys}.items():
        res[f"{name}_mse"] = np.mean((d["mos"].values - d["preds"].values) ** 2)
        res[f"{name}_lcc"] = np.corrcoef(d["mos"].values, d["preds"].values)[0][1]
        res[f"{name}_srcc"] = scipy.stats.spearmanr(d["mos"].values, d["preds"].values)[0]
        res[f"{name}_ktau"] = scipy.stats.kendalltau(d["mos"].values, d["preds"].values)[0]
    return res

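# Illustrative sketch: given a frame with "sys_id" and "mos" columns, the
# result holds eight keys ("utt_mse", ..., "sys_ktau"). Data values are made up.
#
#   >>> df = pd.DataFrame(
#   ...     {"sys_id": ["s1", "s1", "s2", "s2"], "mos": [3.0, 4.0, 2.0, 4.0]}
#   ... )
#   >>> scores = calc_metrics(df, np.array([3.1, 3.9, 2.2, 3.8]))
#   >>> sorted(scores)[:2]
#   ['sys_ktau', 'sys_lcc']
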
def configure_defaults(cfg):
    """Fill in default config values that were left unset."""
    if cfg.id_name is None:
        cfg.id_name = "utt_id"

def _get_testdata(cfg, data: pd.DataFrame) -> pd.DataFrame:
    """Restrict `data` to the utterances listed in the validation list file."""
    with open(cfg.inference.val_list_path, "r") as f:
        val_lists = [s.replace("\n", "") + ".wav" for s in f.readlines()]
    test_data = data[data["utt_id"].isin(set(val_lists))]
    return test_data

def get_inference_data(cfg) -> pd.DataFrame:
    """Build the inference frame, reproducing the paper split or scanning raw wav files."""
    if cfg.reproduce:
        data = get_data(cfg)
        data = preprocess_test(cfg, data)
        data = _get_testdata(cfg, data)
    else:
        if cfg.input_dir:
            files = sorted(glob.glob(str(cfg.input_dir / "*.wav")))
            data = pd.DataFrame({"file_path": files})
        else:
            data = pd.DataFrame({"file_path": [cfg.input_path.as_posix()]})
        data["utt_id"] = data["file_path"].apply(
            lambda x: x.split("/")[-1].replace(".wav", "")
        )
        # The system id is the part of the utterance id before the first "-".
        data["sys_id"] = data["utt_id"].apply(lambda x: x.split("-")[0])
        if cfg.inference.val_list_path:
            with open(cfg.inference.val_list_path, "r") as f:
                val_lists = [s.replace(".wav", "") for s in f.read().splitlines()]
            data = data[data["utt_id"].isin(set(val_lists))]
        data["dataset"] = cfg.predict_dataset
        # Dummy MOS so the inference frame has the same columns as training data.
        data["mos"] = 0
    return data

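# Illustrative sketch: directory mode scans `cfg.input_dir` for wav files;
# single-file mode uses `cfg.input_path`. The directory name is hypothetical.
#
#   >>> cfg.reproduce = False
#   >>> cfg.input_dir = Path("wavs_to_score")
#   >>> df = get_inference_data(cfg)
#   >>> df.columns.tolist()
#   ['file_path', 'utt_id', 'sys_id', 'dataset', 'mos']
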
def get_train_data(cfg) -> pd.DataFrame:
    """Build the training frame, reproducing the paper setup or reading a data-config JSON."""
    if cfg.reproduce:
        data = get_data(cfg)
        data = preprocess(cfg, data)
    else:
        with open(cfg.data_config, "r") as f:
            datasets = json.load(f)
        data = []
        for dt in datasets["data"]:
            files = sorted(glob.glob(str(Path(dt["dir"]) / "*.wav")))
            d = pd.DataFrame({"file_path": files})
            d["dataset"] = dt["name"]
            d["utt_id"] = d["file_path"].apply(
                lambda x: x.split("/")[-1].replace(".wav", "")
            )
            mos_list = pd.read_csv(dt["mos_list"], header=None)
            mos_list.columns = ["utt_id", "mos"]
            mos_list["utt_id"] = mos_list["utt_id"].apply(
                lambda x: x.replace(".wav", "")
            )
            # Keep only utterances present in both the wav dir and the MOS list.
            d = d.merge(mos_list, on="utt_id", how="inner")
            d["sys_id"] = d["utt_id"].apply(lambda x: x.split("-")[0])
            add_sys_mean(d)
            data.append(d)
        data = pd.concat(data, axis=0)
    return data

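# Illustrative sketch of the JSON consumed above; only the keys this function
# actually reads are shown ("data", and per entry "name", "dir", "mos_list").
# The names and paths are hypothetical.
#
#   {
#     "data": [
#       {
#         "name": "my_dataset",
#         "dir": "data/my_dataset/wav",
#         "mos_list": "data/my_dataset/mos_list.txt"
#       }
#     ]
#   }
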
def show_inference_data(data: pd.DataFrame):
    print(
        data[[c for c in data.columns if c != "mos"]]
        .rename(columns={"dataset": "predict_dataset"})
        .head()
    )

def _get_test_save_name(cfg) -> str:
    return f"{cfg.config_name}_[fold{cfg.inference.fold}_tta{cfg.inference.num_tta}_s{cfg.split.seed}]"

def save_test_preds(
    cfg, data: pd.DataFrame, test_preds: np.ndarray, test_metrics: dict[str, float]
):
    """Save per-utterance test predictions as CSV and validation metrics as JSON."""
    test_df = pd.DataFrame({cfg.id_name: data[cfg.id_name], "test_preds": test_preds})
    save_path = (
        cfg.inference.save_path
        / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_test_preds{'_final' if cfg.final else ''}.csv"
    )
    test_df.to_csv(save_path, index=False)
    print(f"Test predictions are saved to {save_path}")
    save_path = (
        cfg.inference.save_path
        / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})_val_score{'_final' if cfg.final else ''}.json"
    )
    with open(save_path, "w") as f:
        json.dump(test_metrics, f)
    print(f"Validation scores are saved to {save_path}")

def make_submission_file(cfg, data: pd.DataFrame, test_preds: np.ndarray):
    """Write predictions in the challenge submission format (answer.txt, no header)."""
    submit = pd.DataFrame({cfg.id_name: data[cfg.id_name], "prediction": test_preds})
    save_dir = (
        cfg.inference.submit_save_path
        / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})"
    )
    os.makedirs(save_dir, exist_ok=True)
    sub_file = save_dir / "answer.txt"
    submit.to_csv(sub_file, index=False, header=False)
    print(f"Submission file is saved to {sub_file}")

def save_preds(cfg, data: pd.DataFrame, test_preds: np.ndarray):
    """Print predictions, or save them to `cfg.out_path` if it is set."""
    pred = pd.DataFrame({cfg.id_name: data[cfg.id_name], "mos": test_preds})
    if cfg.out_path is None:
        print("Predictions:")
        print(pred)
    else:
        pred.to_csv(cfg.out_path, index=False)
        print(f"Predictions are saved to {cfg.out_path}")