LINC-BIT's picture
Upload 1912 files
b84549f verified
from typing import Dict, Any
from torch import nn
from data.datasets.ab_dataset import ABDataset
from abc import ABC, abstractmethod
from utils.common.log import logger
import json
import os
from utils.common.others import backup_key_codes
from .model import BaseModel
from data import Scenario
from schema import Schema
from utils.common.data_record import write_json
class BaseAlg(ABC):
    """Abstract base class for algorithms that are run against a ``Scenario``.

    Subclasses declare the models they need (``get_required_models_schema``)
    and the hyperparameters they accept (``get_required_hyp_schema``), and
    implement ``run``. The base ``__init__`` / ``run`` validate those inputs
    and record the run setup (hyperparameters, scenario, code backup) under
    ``res_save_dir``.
    """

    def __init__(self, models: Dict[str, BaseModel], res_save_dir):
        """Store and validate ``models`` and create the result directory.

        Args:
            models: mapping from a name to a model; validated against
                ``self.get_required_models_schema()``.
            res_save_dir: directory for this algorithm's outputs. It is created
                with ``os.makedirs`` (no ``exist_ok``), so it must not exist yet.

        Raises:
            FileExistsError: if ``res_save_dir`` already exists.
            Whatever ``Schema.validate`` raises when ``models`` does not match
            the required schema (presumably ``schema.SchemaError``).
        """
        self.models = models
        self.res_save_dir = res_save_dir
        # Validate before touching the filesystem so an invalid model dict
        # does not leave an empty result directory behind.
        self.get_required_models_schema().validate(models)

        # No exist_ok: presumably deliberate, so a new run cannot silently
        # overwrite the results of a previous one — confirm before changing.
        os.makedirs(res_save_dir)
        logger.info(f'[alg] init alg: {self.__class__.__name__}, res saved in {res_save_dir}')

    @abstractmethod
    def get_required_models_schema(self) -> Schema:
        """Return the ``Schema`` that the ``models`` dict passed to ``__init__`` must satisfy."""
        raise NotImplementedError

    @abstractmethod
    def get_required_hyp_schema(self) -> Schema:
        """Return the ``Schema`` that the ``hyps`` dict passed to ``run`` must satisfy."""
        raise NotImplementedError

    @abstractmethod
    def run(self,
            scenario: Scenario,
            hyps: Dict) -> Dict[str, Any]:
        """Validate ``hyps`` and record the run setup in ``res_save_dir``.

        Overriding implementations are expected to return the run's metrics.
        This base body only performs the bookkeeping below and returns
        ``None``, so subclasses presumably call ``super().run(scenario, hyps)``
        first and then return their own metrics.

        Side effects (all under ``res_save_dir``):
          * ``hyps.json`` — ``hyps`` via ``write_json``; if that raises, ``str(hyps)``
            is written to ``hyps.txt`` instead;
          * ``scenario.json`` — ``scenario.to_json()``;
          * ``backup_codes`` — produced by ``backup_key_codes``.

        Args:
            scenario: the scenario to run on; must provide ``to_json()``.
            hyps: hyperparameters, validated against ``self.get_required_hyp_schema()``.

        Raises:
            Whatever ``Schema.validate`` raises when ``hyps`` does not match the
            required schema (presumably ``schema.SchemaError``).
        """
        self.get_required_hyp_schema().validate(hyps)

        try:
            write_json(os.path.join(self.res_save_dir, 'hyps.json'), hyps, ensure_obj_serializable=True)
        except Exception as e:
            # Best-effort fallback when hyps can't be dumped as JSON. Catch
            # Exception (not a bare except) so KeyboardInterrupt / SystemExit
            # still propagate, and log why the fallback was taken.
            logger.warning(f'[alg] failed to save hyps as JSON ({e!r}), saving str(hyps) to hyps.txt instead')
            with open(os.path.join(self.res_save_dir, 'hyps.txt'), 'w', encoding='utf-8') as f:
                f.write(str(hyps))

        write_json(os.path.join(self.res_save_dir, 'scenario.json'), scenario.to_json())
        logger.info(f'[alg] alg {self.__class__.__name__} start running')
        backup_key_codes(os.path.join(self.res_save_dir, 'backup_codes'))