"""Factory that builds the model, optimizer and loss described by a config."""

import torch
from torch import nn
# from tsai.models.TST import TST
from sklearn.neighbors import KNeighborsRegressor

from config import get_model_config
from loss.weightedmseloss import WeightedMSELoss
from loss.weightedmultioutputsloss import WeightedMultiOutputLoss
from loss.weightedrmseloss import WeightedRMSELoss
from model.Hernandez2021cnnlstm import Hernandez2021CNNLSTM
from model.bilstmmodel import BiLSTMModel
from model.cnnlstm import CNNLSTM
from model.dorschky2020cnn import Dorschky2020CNN
from model.gholami2020cnn import Gholami2020CNN
from model.lstmlstm import Seq2Seq
from model.lstmlstmattention import Seq2SeqAtt
from model.lstmlstmrec import Seq2SeqRec
from model.lstmmodel import LSTMModel
from model.tcnmodel import TCNModel
from model.transformer import Transformer
from model.transformer_seq2seq import Seq2SeqTransformer
from model.transformer_tsai import TransformerTSAI
from model.zrenner2018cnn import Zrenner2018CNN
from utils.update_config import update_model_config


# Architectures whose constructor takes the resolved model config as its only
# argument. Keys are the exact ``config['model_name']`` strings (case matters).
_MODELS_FROM_MODEL_CONFIG = {
    'lstm': LSTMModel,
    'bilstm': BiLSTMModel,
    'cnnlstm': CNNLSTM,
    'hernandez2021cnnlstm': Hernandez2021CNNLSTM,
    'seq2seqatt': Seq2SeqAtt,
    'Gholami2020CNN': Gholami2020CNN,
    'Dorschky2020CNN': Dorschky2020CNN,
    'Zrenner2018CNN': Zrenner2018CNN,
    'tcn': TCNModel,
}

# Weighted loss classes, keyed by ``config['loss']``. Each is constructed with
# a single weight tensor.
_WEIGHTED_LOSSES = {
    'RMSE': WeightedRMSELoss,
    'MSE': WeightedMSELoss,
    'MSE-CE': WeightedMultiOutputLoss,
}


class _RMSELoss(nn.Module):
    """Unweighted root-mean-square error: ``sqrt(nn.MSELoss()(pred, target))``."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, prediction, target):
        """Return the square root of the mean squared error."""
        return torch.sqrt(self.mse(prediction, target))


class ModelBuilder:
    """Build a model, its optimizer and its loss criterion from a config dict.

    The config is read for the keys ``selected_sensors``,
    ``selected_opensim_labels``, ``model_name``, ``optimizer_name``,
    ``learning_rate``, ``l2_weight_decay_status``, ``l2_weight_decay``,
    ``loss``, ``loss_weight`` and ``device``; some architectures additionally
    read ``target_padding_length``.
    """

    def __init__(self, config):
        """Store the config and derive the values the builders need.

        Args:
            config: Mapping holding the keys listed in the class docstring.
        """
        self.config = config
        # Six input channels per selected sensor
        # (presumably 3-axis accelerometer + 3-axis gyroscope -- TODO confirm).
        self.n_input_channel = len(self.config['selected_sensors']) * 6
        self.n_output = len(self.config['selected_opensim_labels'])
        self.model_name = self.config['model_name']
        # Load the per-architecture config, then let update_model_config
        # combine it with the run-level config.
        self.model_config = get_model_config(f'config_{self.model_name}')
        self.model_config = update_model_config(self.config, self.model_config)
        self.optimizer_name = self.config['optimizer_name']
        self.learning_rate = self.config['learning_rate']
        self.l2_weight_decay_status = self.config['l2_weight_decay_status']
        self.l2_weight_decay = self.config['l2_weight_decay']
        self.loss = self.config['loss']
        self.weight = self.config['loss_weight']
        self.device = self.config['device']
        # Loss weights are only usable when there is exactly one per output;
        # otherwise fall back to an unweighted loss. A missing (None) weight
        # already means "unweighted" and must not be passed to len().
        if self.weight is not None and len(self.weight) != self.n_output:
            self.weight = None

    def run_model_builder(self):
        """Build everything in dependency order.

        Returns:
            Tuple ``(model, optimizer, criterion)``.
        """
        # The optimizer needs self.model, so the model is built first.
        model = self.get_model_architecture()
        criterion = self.get_criterion(self.weight)
        optimizer = self.get_optimizer()
        return model, optimizer, criterion

    def get_model_architecture(self):
        """Instantiate the architecture selected by ``self.model_name``.

        The instance is also stored on ``self.model``.

        Returns:
            The constructed model.

        Raises:
            ValueError: If ``self.model_name`` is not a known architecture.
        """
        name = self.model_name
        if name in _MODELS_FROM_MODEL_CONFIG:
            self.model = _MODELS_FROM_MODEL_CONFIG[name](self.model_config)
        elif name == 'seq2seq':
            # Seq2Seq is handed the full run config, not the model config.
            self.model = Seq2Seq(self.config)
        elif name == 'seq2seqrec':
            self.model = Seq2SeqRec(self.n_input_channel, self.n_output)
        elif name == 'transformer':
            self.model = Transformer(
                d_input=self.n_input_channel,
                d_model=12,
                d_output=self.n_output,
                d_len=self.config['target_padding_length'],
                h=8,
                N=1,
                attention_size=None,
                dropout=0.5,
                chunk_mode=None,
                pe='original',
                multihead=True,
            )
        elif name == 'seq2seqtransformer':
            self.model = Seq2SeqTransformer(
                d_input=self.n_input_channel,
                d_model=24,
                d_output=self.n_output,
                h=8,
                N=4,
                attention_size=None,
                dropout=0.1,
                chunk_mode=None,
                pe='original',
            )
        elif name == 'transformertsai':
            padded_length = self.config['target_padding_length']
            tsai_config = self.model_config
            self.model = TransformerTSAI(
                self.n_input_channel,  # c_in: channels / features / variables
                self.n_output,         # c_out
                padded_length,         # seq_len
                max_seq_len=padded_length,
                d_model=tsai_config['tsai_d_model'],
                n_heads=tsai_config['tsai_n_heads'],
                d_k=None,  # per the original note: None -> d_model // n_heads
                d_v=None,
                d_ff=tsai_config['tsai_d_ff'],
                res_dropout=tsai_config['tsai_res_dropout_p'],
                act="gelu",
                n_layers=tsai_config['tsai_n_layers'],
                fc_dropout=tsai_config['tsai_fc_dropout_p'],
                classification=tsai_config['classification'],
            )
        elif name == 'knn':
            # scikit-learn estimator, not a torch module: it has no
            # .parameters(), so get_optimizer() cannot be used with it.
            self.model = KNeighborsRegressor()
        else:
            raise ValueError(f'Unknown model_name: {name!r}')
        return self.model

    def get_optimizer(self):
        """Create the optimizer for ``self.model``.

        ``get_model_architecture`` must have been called first so that
        ``self.model`` exists. Weight decay is applied only when
        ``l2_weight_decay_status`` is truthy. The instance is also stored on
        ``self.optimizer``.

        Returns:
            A ``torch.optim.Adam`` instance.

        Raises:
            ValueError: If ``self.optimizer_name`` is not ``'Adam'``.
        """
        if self.optimizer_name != 'Adam':
            raise ValueError(
                f'Unsupported optimizer_name: {self.optimizer_name!r} '
                "(only 'Adam' is implemented)"
            )
        adam_kwargs = {'lr': self.learning_rate}
        if self.l2_weight_decay_status:
            adam_kwargs['weight_decay'] = self.l2_weight_decay
        self.optimizer = torch.optim.Adam(self.model.parameters(), **adam_kwargs)
        return self.optimizer

    def get_criterion(self, weight=None):
        """Create the loss selected by ``self.loss``.

        With a weight, ``'RMSE'``, ``'MSE'`` and ``'MSE-CE'`` map to their
        weighted loss classes, each built from a tensor of the weights moved
        to ``self.device``. Without a weight, ``'RMSE'`` yields an unweighted
        RMSE loss. Every other combination falls back to ``nn.MSELoss()``.
        The instance is also stored on ``self.criterion``.

        Args:
            weight: Optional sequence of per-output loss weights.

        Returns:
            The constructed loss module.
        """
        if weight is not None and self.loss in _WEIGHTED_LOSSES:
            weight_tensor = torch.tensor(weight).to(self.device)
            self.criterion = _WEIGHTED_LOSSES[self.loss](weight_tensor)
        elif self.loss == 'RMSE':
            # torch.sqrt() needs a tensor, so the root is taken of the
            # computed MSE inside forward(), not of the nn.MSELoss module.
            self.criterion = _RMSELoss()
        else:
            self.criterion = nn.MSELoss()
        return self.criterion