|
import optuna |
|
import os |
|
import tempfile |
|
import time |
|
import json |
|
import subprocess |
|
import logging |
|
from beam_search_utils import ( |
|
write_seglst_jsons, |
|
run_mp_beam_search_decoding, |
|
convert_nemo_json_to_seglst, |
|
) |
|
from hydra.core.config_store import ConfigStore |
|
|
|
|
|
def _run_meeteval_cpwer(hyp_seglst_json, ref_seglst_json):
    """Invoke ``meeteval-wer cpwer`` for one hypothesis/reference SegLST pair.

    A non-zero exit status is logged rather than raised; the caller detects
    the failure through the missing output JSON file.

    Returns:
        The process return code.
    """
    cmd = ["meeteval-wer", "cpwer", "-h", hyp_seglst_json, "-r", ref_seglst_json]
    # List form (no shell), so the file paths are never shell-interpreted.
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        logging.error(f"meeteval-wer exited with code {completed.returncode}: {' '.join(cmd)}")
    return completed.returncode


def _load_error_rate(json_path):
    """Return the ``error_rate`` field of a meeteval cpWER output JSON file.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist (chained to the
            original error so the traceback keeps the root cause).
        KeyError: if the JSON has no ``error_rate`` field.
    """
    try:
        with open(json_path, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Output JSON: {json_path}\nfile not found.") from err
    return data["error_rate"]


def evaluate(cfg, temp_out_dir, workspace_dir, asrdiar_file_name, source_info_dict, hypothesis_sessions_dict, reference_info_dict):
    """Score hypothesis and source transcripts against the reference with cpWER.

    Writes hypothesis, reference and source SegLST JSON files into
    ``temp_out_dir`` via ``write_seglst_jsons``. It then runs
    ``meeteval-wer cpwer`` for hyp-vs-ref and src-vs-ref and reads back the
    resulting error rates.

    Args:
        cfg: Config object; ``input_error_src_list_path`` and
            ``groundtruth_ref_list_path`` are read.
        temp_out_dir: Directory receiving all intermediate JSON files.
        workspace_dir: Currently unused; kept for interface compatibility.
        asrdiar_file_name: Base name used to build the ``*.seglst.json`` paths.
        source_info_dict: Source sessions, written with ``ext_str='src'``.
        hypothesis_sessions_dict: Hypothesis sessions, written with
            ``ext_str='hyp'``.
        reference_info_dict: Reference sessions, written with
            ``ext_str='ref'``.

    Returns:
        The hypothesis cpWER (``error_rate`` from meeteval's output JSON).
        The source cpWER is only printed and logged.

    Raises:
        FileNotFoundError: if either meeteval output JSON is missing, e.g.
            because ``meeteval-wer`` failed.
    """
    write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=temp_out_dir, ext_str='hyp')
    write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='ref')
    # NOTE(review): src is written using groundtruth_ref_list_path, whereas hyp
    # uses input_error_src_list_path — confirm this is intended.
    write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=temp_out_dir, ext_str='src')

    # Paths are assumed to match the names write_seglst_jsons produces.
    src_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst.json")
    hyp_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst.json")
    ref_seglst_json = os.path.join(temp_out_dir, f"{asrdiar_file_name}.ref.seglst.json")

    # Where meeteval-wer is expected to write its averaged result
    # (assumed: meeteval's default "<input minus .json>_cpwer.json" naming — confirm).
    output_cpwer_hyp_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.hyp.seglst_cpwer.json")
    output_cpwer_src_json_file = os.path.join(temp_out_dir, f"{asrdiar_file_name}.src.seglst_cpwer.json")

    _run_meeteval_cpwer(hyp_seglst_json, ref_seglst_json)
    _run_meeteval_cpwer(src_seglst_json, ref_seglst_json)

    cpwer = _load_error_rate(output_cpwer_hyp_json_file)
    print("Hypothesis cpWER:", cpwer)
    logging.info(f"-> HYPOTHESIS cpWER={cpwer:.4f}")

    source_cpwer = _load_error_rate(output_cpwer_src_json_file)
    print("Source cpWER:", source_cpwer)
    logging.info(f"-> SOURCE cpWER={source_cpwer:.4f}")

    return cpwer
|
|
|
|
|
def optuna_suggest_params(cfg, trial):
    """Draw one beam-search hyper-parameter set from an Optuna trial.

    The sampled values are written onto ``cfg`` in place, and the same ``cfg``
    object is handed back.

    Args:
        cfg: Attribute-style config object. Its ``alpha``, ``beta``,
            ``beam_width``, ``word_window``, ``use_ngram``,
            ``parallel_chunk_word_len`` and ``peak_prob`` attributes are
            overwritten.
        trial: Optuna trial providing ``suggest_float`` / ``suggest_int``.

    Returns:
        The same ``cfg`` object, now carrying the sampled values.
    """
    # Ordered (attribute, producer) pairs. Producers run strictly in this order,
    # so the sequence of suggest_* calls made on the trial never changes.
    assignments = (
        ("alpha", lambda: trial.suggest_float("alpha", 0.01, 5.0)),
        ("beta", lambda: trial.suggest_float("beta", 0.001, 2.0)),
        ("beam_width", lambda: trial.suggest_int("beam_width", 4, 64)),
        ("word_window", lambda: trial.suggest_int("word_window", 16, 64)),
        # Pinned rather than searched.
        ("use_ngram", lambda: True),
        ("parallel_chunk_word_len", lambda: trial.suggest_int("parallel_chunk_word_len", 50, 300)),
        ("peak_prob", lambda: trial.suggest_float("peak_prob", 0.9, 1.0)),
    )
    for attr_name, produce in assignments:
        setattr(cfg, attr_name, produce())
    return cfg
|
|
|
def beamsearch_objective(
    trial,
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """Optuna objective: decode with one sampled hyper-parameter set and score it.

    Args:
        trial: Optuna trial used to sample hyper-parameters.
        cfg: Config object. It is mutated in place by ``optuna_suggest_params``
            and read for ``temp_out_dir``, ``port``, ``workspace_dir`` and
            ``asrdiar_file_name``.
        speaker_beam_search_decoder: Decoder forwarded to
            ``run_mp_beam_search_decoding``.
        loaded_kenlm_model: Forwarded to ``run_mp_beam_search_decoding``.
        div_trans_info_dict: Forwarded to ``run_mp_beam_search_decoding``.
        org_trans_info_dict: Forwarded to ``run_mp_beam_search_decoding``.
        source_info_dict: Forwarded to ``evaluate``.
        reference_info_dict: Forwarded to ``evaluate``.

    Returns:
        The hypothesis cpWER returned by ``evaluate`` (an error rate; lower is
        better).
    """
    # Each trial gets its own scratch directory under cfg.temp_out_dir, which is
    # deleted when the `with` block exits.
    with tempfile.TemporaryDirectory(dir=cfg.temp_out_dir, prefix="GenSEC_") as local_temp_out_dir:
        start_time = time.time()
        cfg = optuna_suggest_params(cfg, trial)
        # NOTE(review): alpha, beta, beam_width and peak_prob are only set on cfg
        # and are not passed to the call below; presumably the decoder reads
        # them from this same cfg object — confirm.
        trans_info_dict = run_mp_beam_search_decoding(
            speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            div_mp=True,
            win_len=cfg.parallel_chunk_word_len,
            word_window=cfg.word_window,
            port=cfg.port,
            use_ngram=cfg.use_ngram,
        )
        hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)
        cpwer = evaluate(
            cfg,
            local_temp_out_dir,
            cfg.workspace_dir,
            cfg.asrdiar_file_name,
            source_info_dict,
            hypothesis_sessions_dict,
            reference_info_dict,
        )
        # Log the trial number (not the Trial object, whose repr is just an address).
        logging.info(f"Beam Search time taken for trial {trial.number}: {(time.time() - start_time) / 60:.2f} mins")
        logging.info(f"Trial: {trial.number}")
        logging.info(f"[ cpWER={cpwer:.4f} ]")
        logging.info("-----------------------------------------------")

    return cpwer
|
|
|
|
|
def optuna_hyper_optim(
    cfg,
    speaker_beam_search_decoder,
    loaded_kenlm_model,
    div_trans_info_dict,
    org_trans_info_dict,
    source_info_dict,
    reference_info_dict,
):
    """Run an Optuna search over the beam-search hyper-parameters.

    Each trial is scored by ``beamsearch_objective``, and the study minimizes
    the returned cpWER.

    Args:
        cfg: Attribute-style config object (not a plain dict). This function
            reads ``optuna_study_name``, ``storage``, ``output_log_file``
            (``None`` disables file logging) and ``optuna_n_trials``. The
            objective reads further fields.
        speaker_beam_search_decoder: Forwarded unchanged to every trial.
        loaded_kenlm_model: Forwarded unchanged to every trial.
        div_trans_info_dict: Forwarded unchanged to every trial.
        org_trans_info_dict: Forwarded unchanged to every trial.
        source_info_dict: Forwarded unchanged to every trial.
        reference_info_dict: Forwarded unchanged to every trial.

    Returns:
        The optimized ``optuna.Study`` (callers may ignore it).
    """
    # Configure logging before the study is created, so that anything Optuna
    # logs while creating/loading the study also reaches the log file.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if cfg.output_log_file is not None:
        logger.addHandler(logging.FileHandler(cfg.output_log_file, mode="a"))
    logger.addHandler(logging.StreamHandler())
    # Route Optuna's messages through the root handlers above, and drop Optuna's
    # own stderr handler; otherwise every Optuna message is printed twice.
    optuna.logging.enable_propagation()
    optuna.logging.disable_default_handler()

    def objective(trial):
        # Bind the fixed decoding inputs so Optuna only has to supply `trial`.
        return beamsearch_objective(
            trial=trial,
            cfg=cfg,
            speaker_beam_search_decoder=speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=org_trans_info_dict,
            source_info_dict=source_info_dict,
            reference_info_dict=reference_info_dict,
        )

    study = optuna.create_study(
        direction="minimize",
        study_name=cfg.optuna_study_name,
        storage=cfg.storage,
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=cfg.optuna_n_trials)
    return study