suchkow's picture
Upload 48 files
25e7dcb verified
raw
history blame
3.04 kB
from typing import Type
import numpy as np
import optuna
import torch
from torch import nn
from utils import tts_last_n
from .early_stopping import EarlyStopping
from .loaders import get_dataloader
from .train_eval import train_model, validate_model, predict
from sklearn.metrics import roc_auc_score
import math
# Number of trailing samples held out for validation. Presumably ~one trading
# year of daily observations -- TODO confirm against the data source.
_VALIDATION_SIZE = 252


def _fit_trial_model(trial: optuna.Trial,
                     model_class: Type[nn.Module],
                     inputs: np.ndarray,
                     targets: np.ndarray,
                     criterion: nn.Module,
                     test_size: int = _VALIDATION_SIZE,
                     ) -> tuple:
    """Sample hyperparameters from *trial*, then build and train one model.

    Shared by both objectives in this module so their search space and
    training protocol stay identical.

    Args:
        trial: Optuna trial used to sample ``hidden_size``, ``num_layers``,
            ``dropout`` and ``learning_rate``.
        model_class: Model constructor accepting ``input_size``,
            ``hidden_size``, ``num_layers``, ``dropout`` and ``output_size``.
        inputs: Feature array; ``inputs.shape[2]`` is used as the model's
            input size, so it must have at least three dimensions (assumed
            (samples, seq_len, features) -- TODO confirm).
        targets: Target array split alongside *inputs*.
        criterion: Loss function used for training.
        test_size: Number of trailing samples passed to ``tts_last_n`` as the
            hold-out split.

    Returns:
        ``(model, val_loader, device)``: the trained model (already on
        *device*), the data loader over the hold-out split, and the device.

    NOTE(review): the hold-out split drives early stopping here and is also
    what the callers score, so the returned metric is optimistically biased.
    Consider a separate final test split.
    """
    hidden_size = trial.suggest_categorical('hidden_size', [4, 8, 16, 32, 64, 128])
    num_layers = trial.suggest_int('num_layers', 1, 4)
    dropout = trial.suggest_float('dropout', 0.05, 0.5, log=True)
    learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # get_dataloader is handed `device` (presumably placing batches there), so
    # the parameters must live on the same device. `.to()` returns the module
    # itself and is a no-op if train_model also moves it.
    model = model_class(
        input_size=inputs.shape[2],
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout,
        output_size=1,
    ).to(device)

    train_seq, test_seq, train_tar, test_tar = tts_last_n(inputs, targets, test_size)
    train_loader = get_dataloader(train_seq, train_tar, device)
    val_loader = get_dataloader(test_seq, test_tar, device)

    train_model(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        val_loader=val_loader,
        early_stopping=EarlyStopping(5, 1e-6),
        epochs=100,
        lr=learning_rate,
        verbose=0,
    )
    return model, val_loader, device


def objective_tune_reg_a(trial: optuna.Trial,
                         model_class: Type[nn.Module],
                         inputs: np.ndarray,
                         targets: np.ndarray,
                         *,
                         test_size: int = _VALIDATION_SIZE,
                         ) -> float:
    """Optuna objective for regression models: hold-out RMSE (lower is better).

    Args:
        trial: Optuna trial supplying the hyperparameters.
        model_class: Model constructor (see ``_fit_trial_model``).
        inputs: Feature array with at least three dimensions.
        targets: Regression targets.
        test_size: Number of trailing samples held out for validation.

    Returns:
        Square root of the value ``validate_model`` reports with an MSE
        criterion. Assuming that value is the mean MSE, this is the RMSE.
    """
    criterion = nn.MSELoss()
    model, val_loader, device = _fit_trial_model(
        trial, model_class, inputs, targets, criterion, test_size
    )
    return float(math.sqrt(validate_model(model, val_loader, criterion, device)))


def objective_tune_clas_a(trial: optuna.Trial,
                          model_class: Type[nn.Module],
                          inputs: np.ndarray,
                          targets: np.ndarray,
                          *,
                          test_size: int = _VALIDATION_SIZE,
                          ) -> float:
    """Optuna objective for binary classifiers: hold-out ROC AUC (higher is better).

    Args:
        trial: Optuna trial supplying the hyperparameters.
        model_class: Model constructor (see ``_fit_trial_model``). Trained
            with ``nn.BCELoss``, so its output is assumed to already be a
            probability (e.g. ends in a sigmoid) -- TODO confirm.
        inputs: Feature array with at least three dimensions.
        targets: Binary labels.
        test_size: Number of trailing samples held out for validation.

    Returns:
        ``roc_auc_score`` of the labels and predictions returned by
        ``predict`` on the hold-out split.

    Raises:
        ValueError: From ``roc_auc_score`` when the hold-out labels contain
            only one class.
    """
    criterion = nn.BCELoss()
    model, val_loader, device = _fit_trial_model(
        trial, model_class, inputs, targets, criterion, test_size
    )
    labels, preds = predict(model, val_loader, device)
    return float(roc_auc_score(labels, preds))