|
from typing import Optional, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.optim as optim
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
from .early_stopping import EarlyStopping
|
|
|
|
|
|
def train_model(model: nn.Module,
                criterion: Union[nn.BCELoss, nn.MSELoss],
                train_loader: DataLoader,
                val_loader: Optional[DataLoader] = None,
                device: Optional[torch.device] = None,
                early_stopping: Optional[EarlyStopping] = None,
                epochs: int = 100,
                lr: float = 0.001,
                verbose: int = 0) -> None:
    """Train *model* in place with Adam, optionally validating each epoch.

    Args:
        model: Network to train; its parameters are updated in place.
        criterion: Loss applied to ``(outputs, labels)`` for every batch.
        train_loader: Batches of ``(inputs, labels)`` used for training.
        val_loader: Optional validation batches; required for early stopping.
        device: If given, each batch is moved there before the forward pass.
        early_stopping: Optional ``EarlyStopping`` tracker fed the validation
            loss after every epoch; training stops when it trips.
        epochs: Maximum number of passes over ``train_loader``.
        lr: Adam learning rate.
        verbose: If nonzero, print per-epoch loss summaries.

    Raises:
        ValueError: If ``early_stopping`` is given without ``val_loader``.
    """
    if val_loader is None and early_stopping is not None:
        raise ValueError('Cannot use early_stopping without validation set.')

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            # BUG FIX: training batches were never moved to `device`, even
            # though `device` is forwarded to validate_model below -- training
            # and validation could end up on different devices.
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Mean per-batch training loss for this epoch.
        train_loss = running_loss / len(train_loader)

        if val_loader is None:
            # BUG FIX: `verbose` was accepted but silently ignored.
            if verbose:
                print(f'Epoch {epoch + 1}/{epochs} - train loss: {train_loss:.6f}')
            continue

        val_loss = validate_model(model, val_loader, criterion, device)
        if verbose:
            print(f'Epoch {epoch + 1}/{epochs} - train loss: {train_loss:.6f} '
                  f'- val loss: {val_loss:.6f}')

        if early_stopping is None:
            continue

        early_stopping(val_loss)
        if early_stopping.early_stop:
            return
|
|
|
|
|
|
def validate_model(model, val_loader, criterion, device=None):
    """Return the mean per-batch loss of *model* over *val_loader*.

    The model is switched to eval mode and the loop runs without gradient
    tracking. If *device* is given, each batch is moved there before the
    forward pass.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            if device is not None:
                batch_inputs = batch_inputs.to(device)
                batch_targets = batch_targets.to(device)
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(val_loader)
|
|
|
|
|
|
def predict(model, data_loader, device=None):
    """Run *model* over *data_loader* and collect outputs and labels.

    Args:
        model: Trained network; switched to eval mode for inference.
        data_loader: Batches of ``(inputs, labels)``.
        device: If given, inputs are moved there before the forward pass.

    Returns:
        Tuple ``(labels, predictions)`` as concatenated NumPy arrays, in
        the order the loader yields them.
    """
    model.eval()
    predictions = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            # BUG FIX: `device` was accepted but never used, so inference
            # crashed for a model living on a non-CPU device.
            if device is not None:
                inputs = inputs.to(device)
            outputs = model(inputs)

            # .cpu() makes the tensors NumPy-convertible regardless of device.
            predictions.append(outputs.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
    return np.concatenate(labels_list), np.concatenate(predictions)
|
|
|