# SLT-Task2-ngram-baseline / hyper_optim.py
# Repository page header (not part of the program):
#   Taejin's picture -- "Uploading images and scripts"
#   commit fc17b57 (verified)
import optuna
import os
import tempfile
import time
import json
import subprocess
import logging
from beam_search_utils import (
write_seglst_jsons,
run_mp_beam_search_decoding,
convert_nemo_json_to_seglst,
)
from hydra.core.config_store import ConfigStore
def _run_meeteval_cpwer(hyp_seglst_json, ref_seglst_json):
    """Run ``meeteval-wer cpwer`` on one hypothesis/reference SegLST pair.

    A non-zero exit status is logged as a warning instead of being raised so
    that the caller's existing "output JSON not found" error path stays the
    single failure signal, while the root cause is still visible in the log.

    Returns:
        int: The exit status of the ``meeteval-wer`` process.
    """
    cmd = [
        "meeteval-wer",
        "cpwer",
        "-h", hyp_seglst_json,
        "-r", ref_seglst_json,
    ]
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        logging.warning(
            f"meeteval-wer exited with status {completed.returncode} for hypothesis file: {hyp_seglst_json}"
        )
    return completed.returncode


def _load_cpwer_json(cpwer_json_path):
    """Load a cpWER result JSON file and return its parsed content.

    Raises:
        FileNotFoundError: If ``cpwer_json_path`` does not exist. The message
            names the missing path and the original error is chained.
    """
    try:
        with open(cpwer_json_path, "r") as file:
            return json.load(file)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Output JSON: {cpwer_json_path}\nfile not found.") from err


def evaluate(cfg, temp_out_dir, workspace_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
    """Write SegLST files, score them with ``meeteval-wer cpwer`` and return the hypothesis cpWER.

    The hypothesis, reference and source sessions are written into
    ``temp_out_dir`` as ``<asrdiar_file_name>.{hyp,ref,src}.seglst.json``.
    ``meeteval-wer cpwer`` is then run for hypothesis-vs-reference and
    source-vs-reference, and the ``error_rate`` field is read back from
    ``<asrdiar_file_name>.{hyp,src}.seglst_cpwer.json`` in the same directory
    (the files this function expects the command to have produced).

    Parameters:
        cfg: Configuration object; ``input_error_src_list_path`` and
            ``groundtruth_ref_list_path`` are read from it.
        temp_out_dir (str): Directory that receives all intermediate files.
        workspace_dir: Unused here; kept so existing callers keep working.
        asrdiar_file_name (str): Base name of the SegLST / cpWER files.
        source_info_dict: Source sessions, written with extension ``src``.
        hypothesis_sessions_dict: Hypothesis sessions, written with extension ``hyp``.
        reference_info_dict: Reference sessions, written with extension ``ref``.

    Returns:
        The ``error_rate`` value of the hypothesis cpWER result.

    Raises:
        FileNotFoundError: If either expected cpWER result file is missing
            after the scoring commands have run.
    """
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')

    # Construct the SegLST input paths.
    src_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst.json")
    hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
    ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")

    # Construct the cpWER result paths that are read back below.
    output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
    output_cpwer_src_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer.json")

    # Score both pairs first, then read the results (same order as before).
    _run_meeteval_cpwer(hyp_seglst_json, ref_seglst_json)
    _run_meeteval_cpwer(src_seglst_json, ref_seglst_json)

    data_h = _load_cpwer_json(output_cpwer_hyp_json_file)
    print("Hypothesis cpWER:", data_h["error_rate"])
    cpwer = data_h["error_rate"]
    logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")

    data_s = _load_cpwer_json(output_cpwer_src_json_file)
    print("Source cpWER:", data_s["error_rate"])
    source_cpwer = data_s["error_rate"]
    logging.info(f"-> SOURCE cpWER={source_cpwer:.4f}")

    return cpwer
def optuna_suggest_params(cfg, trial):
    """Fill ``cfg`` with the hyper-parameters sampled for ``trial``.

    Each entry of the plan below is applied in order: sampled entries call
    the named ``trial`` method with ``(name, low, high)``, while ``use_ngram``
    is a fixed value that is always set to ``True``.

    Parameters:
        cfg: Configuration object; the listed attributes are overwritten in place.
        trial: Object providing ``suggest_float`` and ``suggest_int``.

    Returns:
        The same ``cfg`` object that was passed in.
    """
    # (attribute name, trial method name or None for a constant, bounds or constant value)
    plan = (
        ("alpha", "suggest_float", (0.01, 5.0)),
        ("beta", "suggest_float", (0.001, 2.0)),
        ("beam_width", "suggest_int", (4, 64)),
        ("word_window", "suggest_int", (16, 64)),
        ("use_ngram", None, True),
        ("parallel_chunk_word_len", "suggest_int", (50, 300)),
        ("peak_prob", "suggest_float", (0.9, 1.0)),
    )
    for field, sampler_name, spec in plan:
        if sampler_name is None:
            value = spec
        else:
            low, high = spec
            value = getattr(trial, sampler_name)(field, low, high)
        setattr(cfg, field, value)
    return cfg
def beamsearch_objective(
    trial,
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """Optuna objective: decode with the trial's parameters and return the hypothesis cpWER.

    A temporary directory (prefix ``GenSEC_``) is created under
    ``cfg.temp_out_dir`` for the intermediate evaluation files and removed
    when scoring is finished.

    Parameters:
        trial: Optuna trial used to sample parameters; ``trial.number`` is logged.
        cfg: Configuration object. It is updated in place by
            ``optuna_suggest_params``; ``temp_out_dir``, ``port``,
            ``workspace_dir`` and ``asrdiar_file_name`` are also read from it.
        speaker_beam_search_decoder: Decoder passed through to
            ``run_mp_beam_search_decoding``.
        loaded_kenlm_model: Language model passed through to the decoder call.
        div_trans_info_dict: Passed through to the decoder call.
        org_trans_info_dict: Passed through to the decoder call.
        source_info_dict: Source sessions passed to ``evaluate``.
        reference_info_dict: Reference sessions passed to ``evaluate``.

    Returns:
        The cpWER value returned by ``evaluate`` for this trial.
    """
    with tempfile.TemporaryDirectory(dir=cfg.temp_out_dir, prefix="GenSEC_") as local_temp_out_dir:
        start_time = time.time()
        cfg = optuna_suggest_params(cfg, trial)
        # NOTE(review): alpha, beta, beam_width and peak_prob are sampled into
        # cfg but are not passed to this call -- they only influence decoding
        # if the decoder reads them from cfg elsewhere. Confirm.
        trans_info_dict = run_mp_beam_search_decoding(
            speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            div_mp=True,
            win_len=cfg.parallel_chunk_word_len,
            word_window=cfg.word_window,
            port=cfg.port,
            use_ngram=cfg.use_ngram,
        )
        hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
        cpwer = evaluate(
            cfg,
            local_temp_out_dir,
            cfg.workspace_dir,
            cfg.asrdiar_file_name,
            source_info_dict,
            hypothesis_sessions_dict,
            reference_info_dict,
        )
    # Log the trial number, not the Trial object itself (whose repr is not useful).
    logging.info(f"Beam Search time taken for trial {trial.number}: {(time.time() - start_time)/60:.2f} mins")
    logging.info(f"Trial: {trial.number}")
    logging.info(f"[ cpWER={cpwer:.4f} ]")
    logging.info("-----------------------------------------------")
    return cpwer
def optuna_hyper_optim(
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """
    Run an Optuna study that minimizes the value returned by ``beamsearch_objective``.

    The study is created (or loaded, if it already exists) from
    ``cfg.optuna_study_name`` and ``cfg.storage``. The root logger is set to
    INFO, gets a stream handler and, when ``cfg.output_log_file`` is not None,
    a file handler in append mode; Optuna's log propagation is enabled before
    ``cfg.optuna_n_trials`` trials are run.

    Parameters:
        cfg: Configuration object; reads ``optuna_study_name``, ``storage``,
            ``output_log_file`` and ``optuna_n_trials`` and is forwarded to the objective.
        speaker_beam_search_decoder: Forwarded to ``beamsearch_objective``.
        loaded_kenlm_model: Forwarded to ``beamsearch_objective``.
        div_trans_info_dict: Forwarded to ``beamsearch_objective``.
        org_trans_info_dict: Forwarded to ``beamsearch_objective``.
        source_info_dict: Forwarded to ``beamsearch_objective``.
        reference_info_dict: Forwarded to ``beamsearch_objective``.
    """

    def worker_function(trial):
        # Bind everything except the trial so Optuna can call this with one argument.
        return beamsearch_objective(
            trial=trial,
            cfg=cfg,
            speaker_beam_search_decoder=speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            source_info_dict=source_info_dict,
            reference_info_dict=reference_info_dict,
        )

    study = optuna.create_study(
        direction="minimize",
        study_name=cfg.optuna_study_name,
        storage=cfg.storage,
        load_if_exists=True,
    )

    # Configure the root logger: optional file output first, then the stream handler.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    log_handlers = []
    if cfg.output_log_file is not None:
        log_handlers.append(logging.FileHandler(cfg.output_log_file, mode="a"))
    log_handlers.append(logging.StreamHandler())
    for handler in log_handlers:
        root_logger.addHandler(handler)

    # Forward Optuna's own log records to the root logger configured above.
    optuna.logging.enable_propagation()
    study.optimize(worker_function, n_trials=cfg.optuna_n_trials)