from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

# from utils import training_progress
from .early_stopping import EarlyStopping


def _to_device(inputs, labels, device: Optional[torch.device]):
    """Return the (inputs, labels) batch moved to *device*.

    The batch is returned unchanged when *device* is None.
    """
    if device is None:
        return inputs, labels
    return inputs.to(device), labels.to(device)


def train_model(model: nn.Module,
                criterion: Union[nn.BCELoss, nn.MSELoss],
                train_loader: DataLoader,
                val_loader: Optional[DataLoader] = None,
                device: Optional[torch.device] = None,
                early_stopping: Optional[EarlyStopping] = None,
                epochs: int = 100,
                lr: float = 0.001,
                verbose: int = 0) -> None:
    """Train *model* in place with the Adam optimizer.

    Args:
        model: Network to optimise. It is not moved by this function, so it
            must already live on ``device`` when one is given.
        criterion: Loss called as ``criterion(outputs, labels)``.
        train_loader: Yields ``(inputs, labels)`` training batches.
        val_loader: Optional loader of ``(inputs, labels)`` validation
            batches, evaluated once after every epoch.
        device: If given, every batch is moved to this device before the
            forward pass.
        early_stopping: Optional callable invoked with the validation loss
            after each epoch; training returns as soon as its
            ``early_stop`` attribute is truthy.
        epochs: Maximum number of passes over ``train_loader``.
        lr: Learning rate passed to Adam.
        verbose: When greater than 0, print the losses after each epoch.

    Raises:
        ValueError: If ``early_stopping`` is given without ``val_loader``,
            or if a loader yields no batches.
    """
    if val_loader is None and early_stopping is not None:
        raise ValueError('Cannot use early_stopping without validation set.')

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # validate_model() switches to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = _to_device(inputs, labels, device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Batches are counted instead of calling len(train_loader) so that
        # loaders without __len__ work and an empty loader gives a clear
        # error rather than a ZeroDivisionError.
        if num_batches == 0:
            raise ValueError('train_loader yielded no batches.')
        train_loss = running_loss / num_batches

        val_loss = None
        if val_loader is not None:
            val_loss = validate_model(model, val_loader, criterion, device)

        if verbose > 0:
            message = f'Epoch {epoch + 1}/{epochs} - train_loss: {train_loss:.6f}'
            if val_loss is not None:
                message += f' - val_loss: {val_loss:.6f}'
            print(message)

        if early_stopping is None:
            continue

        # The check at the top guarantees val_loss was computed here.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            return


def validate_model(model: nn.Module,
                   val_loader: DataLoader,
                   criterion: Union[nn.BCELoss, nn.MSELoss],
                   device: Optional[torch.device] = None) -> float:
    """Return the mean per-batch loss of *model* over *val_loader*.

    The model is put in eval mode and left there; gradients are disabled
    during the pass.

    Args:
        model: Network to evaluate; must already live on ``device`` when
            one is given.
        val_loader: Yields ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: If given, every batch is moved to this device.

    Returns:
        Sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = _to_device(inputs, labels, device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('val_loader yielded no batches.')
    return val_loss / num_batches


def predict(model: nn.Module,
            data_loader: DataLoader,
            device: Optional[torch.device] = None
            ) -> Tuple[np.ndarray, np.ndarray]:
    """Run *model* over *data_loader* and collect labels and outputs.

    The model is put in eval mode and left there; gradients are disabled
    during the pass.

    Args:
        model: Network to evaluate; must already live on ``device`` when
            one is given.
        data_loader: Yields ``(inputs, labels)`` batches.
        device: If given, inputs are moved to this device before the
            forward pass.

    Returns:
        ``(labels_array, predictions)`` -- note that labels come first.
        Each is the concatenation of the per-batch arrays along the first
        axis.

    Raises:
        ValueError: Raised by ``np.concatenate`` when ``data_loader``
            yields no batches.
    """
    model.eval()
    predictions = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            if device is not None:
                inputs = inputs.to(device)
            outputs = model(inputs)
            # Each batch is copied to host memory right away and the
            # pieces are joined once after the loop.
            predictions.append(outputs.cpu().numpy())
            labels_list.append(labels.cpu().numpy())

    predictions = np.concatenate(predictions)
    labels_array = np.concatenate(labels_list)
    return labels_array, predictions