UTMOSv2 / utmosv2 /utils /task_dependents.py
kAIto47802
Resolved conflict in README.md
b55d767
raw
history blame
7.91 kB
from __future__ import annotations
import glob
import json
import os
from collections.abc import Callable
from pathlib import Path
import numpy as np
import pandas as pd
import scipy.stats
import torch
import torch.nn as nn
from utmosv2.dataset import (
MultiSpecDataset,
MultiSpecExtDataset,
SSLDataset,
SSLExtDataset,
SSLLMultiSpecExtDataset,
)
from utmosv2.model import (
MultiSpecExtModel,
MultiSpecModelV2,
SSLExtModel,
SSLMultiSpecExtModelV1,
SSLMultiSpecExtModelV2,
)
from utmosv2.preprocess import add_sys_mean, preprocess, preprocess_test
def get_data(cfg) -> pd.DataFrame:
    """Load the train/val/test MOS lists under ``cfg.input_dir`` into one table.

    Each ``sets/<split>_mos_list.txt`` file is read as a header-less CSV whose
    two columns become ``utt_id`` and ``mos``. The three splits are stacked in
    train, val, test order; their original row indices are kept, so the
    resulting index can contain duplicates. A ``file_path`` column is added,
    pointing at ``cfg.input_dir / "wav/<utt_id>"``.
    """
    split_frames = [
        pd.read_csv(cfg.input_dir / f"sets/{split}_mos_list.txt", header=None)
        for split in ("train", "val", "test")
    ]
    combined = pd.concat(split_frames, axis=0)
    combined.columns = ["utt_id", "mos"]
    combined["file_path"] = [
        cfg.input_dir / f"wav/{utt}" for utt in combined["utt_id"]
    ]
    return combined
def get_dataset(cfg, data: pd.DataFrame, phase: str) -> torch.utils.data.Dataset:
    """Construct the dataset class selected by ``cfg.dataset.name``.

    Args:
        cfg: Configuration object. Reads ``dataset.name`` and ``print_config``;
            the spectrogram-based datasets additionally receive ``cfg.transform``.
        data: Utterance table forwarded to the dataset constructor.
        phase: Phase string forwarded to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: if ``cfg.dataset.name`` is not one of the names
            handled below.
    """
    name = cfg.dataset.name
    if cfg.print_config:
        print(f"Using dataset: {name}")
    if name == "multi_spec":
        return MultiSpecDataset(cfg, data, phase, cfg.transform)
    if name == "ssl":
        return SSLDataset(cfg, data, phase)
    if name == "sslext":
        return SSLExtDataset(cfg, data, phase)
    if name == "ssl_multispec_ext":
        return SSLLMultiSpecExtDataset(cfg, data, phase, cfg.transform)
    if name == "multi_spec_ext":
        return MultiSpecExtDataset(cfg, data, phase, cfg.transform)
    # Name the offending value so a config typo is easy to spot.
    raise NotImplementedError(f"Unsupported dataset: {name!r}")
def get_model(cfg, device: torch.device) -> nn.Module:
    """Instantiate the model selected by ``cfg.model.name`` on ``device``.

    If ``cfg.weight`` is not ``None`` it is treated as a path to a state dict,
    which is loaded into the freshly built model.

    Args:
        cfg: Configuration object. Reads ``model.name``, ``print_config`` and
            ``weight``; the whole ``cfg`` is passed to the model constructor.
        device: Device the model (and any loaded weights) are placed on.

    Returns:
        The model, already moved to ``device``.

    Raises:
        NotImplementedError: if ``cfg.model.name`` is not one of the names
            handled below.
    """
    name = cfg.model.name
    if cfg.print_config:
        print(f"Using model: {name}")
    if name == "multi_specv2":
        model = MultiSpecModelV2(cfg)
    elif name == "sslext":
        model = SSLExtModel(cfg)
    elif name == "multi_spec_ext":
        model = MultiSpecExtModel(cfg)
    elif name == "ssl_multispec_ext":
        model = SSLMultiSpecExtModelV1(cfg)
    elif name == "ssl_multispec_ext_v2":
        model = SSLMultiSpecExtModelV2(cfg)
    else:
        raise NotImplementedError(f"Unsupported model: {name!r}")
    model = model.to(device)
    if cfg.weight is not None:
        # map_location remaps stored tensors onto `device`, so a checkpoint
        # saved on a GPU can still be loaded on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (consider weights_only=True where the torch version allows).
        model.load_state_dict(torch.load(cfg.weight, map_location=device))
    return model
def get_metrics() -> dict[str, Callable[[np.ndarray, np.ndarray], float]]:
    """Return the evaluation metrics, keyed by short name.

    Every metric takes ``(x, y)`` arrays and returns a scalar:

    * ``mse``  -- mean squared difference,
    * ``lcc``  -- Pearson correlation coefficient,
    * ``srcc`` -- Spearman rank correlation,
    * ``ktau`` -- Kendall's tau.
    """

    def mse(x, y):
        return np.mean((x - y) ** 2)

    def lcc(x, y):
        return np.corrcoef(x, y)[0][1]

    def srcc(x, y):
        return scipy.stats.spearmanr(x, y)[0]

    def ktau(x, y):
        return scipy.stats.kendalltau(x, y)[0]

    return {"mse": mse, "lcc": lcc, "srcc": srcc, "ktau": ktau}
def calc_metrics(data: pd.DataFrame, preds: np.ndarray) -> dict[str, float]:
    """Score predictions against ``data["mos"]`` at utterance and system level.

    ``data`` must have ``mos`` and ``sys_id`` columns; it is not modified.
    System-level scores are computed on the per-``sys_id`` means of both the
    ground-truth MOS and the predictions.

    Returns:
        A dict with keys ``{utt,sys}_{mse,lcc,srcc,ktau}`` (utterance-level
        keys first).
    """
    scored = data.assign(preds=preds)
    per_system = scored.groupby("sys_id", as_index=False)[["mos", "preds"]].mean()
    results = {}
    for level, frame in (("utt", scored), ("sys", per_system)):
        truth = frame["mos"].values
        guess = frame["preds"].values
        results[f"{level}_mse"] = np.mean((truth - guess) ** 2)
        results[f"{level}_lcc"] = np.corrcoef(truth, guess)[0][1]
        results[f"{level}_srcc"] = scipy.stats.spearmanr(truth, guess)[0]
        results[f"{level}_ktau"] = scipy.stats.kendalltau(truth, guess)[0]
    return results
def configure_defaults(cfg):
    """Fill in configuration fields left unset: ``id_name`` defaults to ``"utt_id"``."""
    if cfg.id_name is not None:
        return
    cfg.id_name = "utt_id"
def _get_testdata(cfg, data: pd.DataFrame) -> pd.DataFrame:
with open(cfg.inference.val_list_path, "r") as f:
val_lists = [s.replace("\n", "") + ".wav" for s in f.readlines()]
test_data = data[data["utt_id"].isin(set(val_lists))]
return test_data
def get_inference_data(cfg) -> pd.DataFrame:
    """Build the table of utterances to run inference on.

    In reproduce mode (``cfg.reproduce``) the rows come from ``get_data``,
    ``preprocess_test`` and ``_get_testdata`` and are returned as produced.

    Otherwise the inputs are every ``*.wav`` in ``cfg.input_dir`` (sorted), or,
    when ``cfg.input_dir`` is falsy, the single file ``cfg.input_path``. For
    these:

    * ``utt_id`` is the file name with ``".wav"`` removed,
    * ``sys_id`` is the part of ``utt_id`` before the first ``"-"``,
    * if ``cfg.inference.val_list_path`` is set, only utterances listed in that
      file (one per line, ``.wav`` optional) are kept,
    * ``dataset`` is set to ``cfg.predict_dataset`` and ``mos`` to a 0
      placeholder.
    """
    if cfg.reproduce:
        data = get_data(cfg)
        data = preprocess_test(cfg, data)
        data = _get_testdata(cfg, data)
    else:
        if cfg.input_dir:
            files = sorted(glob.glob(str(cfg.input_dir / "*.wav")))
            data = pd.DataFrame({"file_path": files})
        else:
            data = pd.DataFrame({"file_path": [cfg.input_path.as_posix()]})
        # os.path.basename uses the platform separator; glob returns
        # backslash-separated paths on Windows, where split("/") fails.
        data["utt_id"] = data["file_path"].apply(
            lambda x: os.path.basename(x).replace(".wav", "")
        )
        data["sys_id"] = data["utt_id"].apply(lambda x: x.split("-")[0])
        if cfg.inference.val_list_path:
            with open(cfg.inference.val_list_path, "r", encoding="utf-8") as f:
                val_ids = {s.replace(".wav", "") for s in f.read().splitlines()}
            # .copy() so the column assignments below act on an independent
            # frame instead of a filtered slice (avoids chained-assignment issues).
            data = data[data["utt_id"].isin(val_ids)].copy()
        data["dataset"] = cfg.predict_dataset
        data["mos"] = 0
    return data
def get_train_data(cfg) -> pd.DataFrame:
    """Assemble the training table.

    In reproduce mode (``cfg.reproduce``) the data is ``get_data(cfg)`` passed
    through ``preprocess``.

    Otherwise ``cfg.data_config`` is read as JSON. Each entry of its ``"data"``
    list provides:

    * ``"name"`` -- stored in the ``dataset`` column,
    * ``"dir"`` -- a directory whose ``*.wav`` files are collected,
    * ``"mos_list"`` -- a header-less CSV of ``utt_id, mos``.

    Files and MOS entries are inner-joined on ``utt_id`` (file name without
    ``".wav"``). ``sys_id`` is taken as the part of ``utt_id`` before the first
    ``"-"``, and ``add_sys_mean`` is applied to each dataset before all of them
    are concatenated.

    Raises:
        ValueError: if the JSON config lists no datasets.
    """
    if cfg.reproduce:
        data = get_data(cfg)
        data = preprocess(cfg, data)
    else:
        with open(cfg.data_config, "r", encoding="utf-8") as f:
            datasets = json.load(f)
        frames = []
        for dt in datasets["data"]:
            files = sorted(glob.glob(str(Path(dt["dir"]) / "*.wav")))
            d = pd.DataFrame({"file_path": files})
            d["dataset"] = dt["name"]
            # os.path.basename uses the platform separator; glob returns
            # backslash-separated paths on Windows, where split("/") fails.
            d["utt_id"] = d["file_path"].apply(
                lambda x: os.path.basename(x).replace(".wav", "")
            )
            mos_list = pd.read_csv(dt["mos_list"], header=None)
            mos_list.columns = ["utt_id", "mos"]
            mos_list["utt_id"] = mos_list["utt_id"].apply(
                lambda x: x.replace(".wav", "")
            )
            d = d.merge(mos_list, on="utt_id", how="inner")
            d["sys_id"] = d["utt_id"].apply(lambda x: x.split("-")[0])
            add_sys_mean(d)
            frames.append(d)
        if not frames:
            # pd.concat([]) would fail with the opaque "No objects to concatenate".
            raise ValueError(f"No datasets are listed in {cfg.data_config}")
        data = pd.concat(frames, axis=0)
    return data
def show_inference_data(data: pd.DataFrame):
    """Print the first rows of the inference table.

    The ``mos`` column is hidden, and ``dataset`` is displayed under the name
    ``predict_dataset``.
    """
    shown_columns = [name for name in data.columns if name != "mos"]
    preview = data[shown_columns].rename(columns={"dataset": "predict_dataset"})
    print(preview.head())
def _get_test_save_name(cfg) -> str:
return f"{cfg.config_name}_[fold{cfg.inference.fold}_tta{cfg.inference.num_tta}_s{cfg.split.seed}]"
def save_test_preds(
    cfg, data: pd.DataFrame, test_preds: np.ndarray, test_metrics: dict[str, float]
):
    """Write test predictions (CSV) and their metrics (JSON) to ``cfg.inference.save_path``.

    File names are ``<run tag>_(<predict_dataset>)_test_preds[_final].csv`` and
    ``<run tag>_(<predict_dataset>)_val_score[_final].json``, where the run tag
    comes from ``_get_test_save_name(cfg)`` and ``_final`` is added when
    ``cfg.final`` is truthy. The CSV holds ``cfg.id_name`` and ``test_preds``
    columns. The output directory is created if it does not exist.
    """
    test_df = pd.DataFrame({cfg.id_name: data[cfg.id_name], "test_preds": test_preds})
    stem = f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})"
    suffix = "_final" if cfg.final else ""
    os.makedirs(cfg.inference.save_path, exist_ok=True)
    # Plain Path expressions: a trailing comma inside the parentheses would
    # turn these into one-element tuples, which to_csv/open reject.
    preds_path = cfg.inference.save_path / f"{stem}_test_preds{suffix}.csv"
    test_df.to_csv(preds_path, index=False)
    score_path = cfg.inference.save_path / f"{stem}_val_score{suffix}.json"
    with open(score_path, "w") as f:
        json.dump(test_metrics, f)
    print(f"Test predictions are saved to {preds_path}")
    print(f"Test scores are saved to {score_path}")
def make_submission_file(cfg, data: pd.DataFrame, test_preds: np.ndarray):
    """Write ``answer.txt``, a header-less CSV of ``<id>,<prediction>`` rows.

    The file is placed in ``cfg.inference.submit_save_path /
    "<run tag>_(<predict_dataset>)"``, where the run tag comes from
    ``_get_test_save_name(cfg)``; that directory is created if missing.
    """
    table = pd.DataFrame({cfg.id_name: data[cfg.id_name], "prediction": test_preds})
    run_dir = (
        cfg.inference.submit_save_path
        / f"{_get_test_save_name(cfg)}_({cfg.predict_dataset})"
    )
    os.makedirs(run_dir, exist_ok=True)
    sub_file = run_dir / "answer.txt"
    table.to_csv(sub_file, index=False, header=False)
    print(f"Submission file is saved to {sub_file}")
def save_preds(cfg, data: pd.DataFrame, test_preds: np.ndarray):
    """Emit predicted MOS values next to their ids.

    The table has columns ``cfg.id_name`` and ``mos``. It is written as CSV
    (without index) to ``cfg.out_path`` when that is set; otherwise it is
    printed to stdout.
    """
    table = pd.DataFrame({cfg.id_name: data[cfg.id_name], "mos": test_preds})
    if cfg.out_path is not None:
        table.to_csv(cfg.out_path, index=False)
        print(f"Predictions are saved to {cfg.out_path}")
        return
    print("Predictions:")
    print(table)