import os
from pathlib import Path

import pytest
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict

from swim.eval import evaluate
from swim.train import train


@pytest.mark.slow
def test_train_eval(
    tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig
) -> None:
    """Train for one epoch with ``train`` and then evaluate with ``evaluate``.

    The test accuracy reported at the end of training must agree, to within
    0.001, with the accuracy that ``evaluate`` reports after reloading
    ``checkpoints/last.ckpt`` from the same run directory.

    :param tmp_path: The temporary logging path.
    :param cfg_train: A DictConfig containing a valid training configuration.
    :param cfg_eval: A DictConfig containing a valid evaluation configuration.
    """
    # The fixtures are expected to point both configs' output_dir at tmp_path.
    assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir

    # open_dict temporarily lifts OmegaConf struct mode so keys can be set.
    with open_dict(cfg_train):
        cfg_train.trainer.max_epochs = 1
        cfg_train.test = True

    # Register the config with Hydra's global state before calling train().
    HydraConfig().set_config(cfg_train)
    fit_metrics, _ = train(cfg_train)

    ckpt_dir = tmp_path / "checkpoints"
    assert "last.ckpt" in os.listdir(ckpt_dir)

    # Point evaluation at the checkpoint produced by the training run above.
    with open_dict(cfg_eval):
        cfg_eval.ckpt_path = str(ckpt_dir / "last.ckpt")

    HydraConfig().set_config(cfg_eval)
    eval_metrics, _ = evaluate(cfg_eval)

    assert eval_metrics["test/acc"] > 0.0

    # Metric values are read via .item(); presumably scalar tensors — confirm.
    acc_gap = abs(fit_metrics["test/acc"].item() - eval_metrics["test/acc"].item())
    assert acc_gap < 0.001