# suchkow's picture
# Upload 48 files
# 25e7dcb verified
# raw
# history blame
# 2.64 kB
from typing import Optional, Union
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
# from utils import training_progress
from .early_stopping import EarlyStopping
def train_model(model: nn.Module,
                criterion: Union[nn.BCELoss, nn.MSELoss],
                train_loader: DataLoader,
                val_loader: Optional[DataLoader] = None,
                device: Optional[torch.device] = None,
                early_stopping: Optional["EarlyStopping"] = None,
                epochs: int = 100,
                lr: float = 0.001,
                verbose: int = 0) -> None:
    """Train ``model`` in place with the Adam optimizer.

    Each epoch runs one pass over ``train_loader``. When ``val_loader`` is
    given, the model is evaluated with ``validate_model`` after every epoch,
    and the resulting loss is handed to ``early_stopping`` (if provided).

    Args:
        model: Network to optimise; its parameters are updated in place.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        train_loader: Iterable of ``(inputs, labels)`` batches that also
            supports ``len()``.
        val_loader: Optional iterable of validation batches.
        device: When not ``None``, every training batch is moved to this
            device before the forward pass (the same handling
            ``validate_model`` applies to validation batches). The model
            itself is not moved here.
        early_stopping: Optional callable invoked with the validation loss
            after each epoch; training returns as soon as its
            ``early_stop`` attribute is truthy.
        epochs: Maximum number of passes over ``train_loader``.
        lr: Learning rate passed to ``optim.Adam``.
        verbose: When greater than 0, print the per-epoch losses.

    Raises:
        ValueError: If ``early_stopping`` is given without ``val_loader``.
    """
    if val_loader is None and early_stopping is not None:
        raise ValueError('Cannot use early_stopping without validation set.')

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # validate_model() switches to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)

        val_loss = None
        if val_loader is not None:
            val_loss = validate_model(model, val_loader, criterion, device)

        if verbose > 0:
            message = f'Epoch {epoch + 1}/{epochs} - train_loss: {train_loss:.6f}'
            if val_loss is not None:
                message += f' - val_loss: {val_loss:.6f}'
            print(message)

        # The guard above guarantees val_loss is set whenever
        # early_stopping is provided.
        if early_stopping is not None:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                return
def validate_model(model, val_loader, criterion, device=None):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and left in it; gradients are
    disabled while the batches are processed. When ``device`` is not
    ``None`` each batch is moved to it before the forward pass. The sum of
    the batch losses is divided by ``len(val_loader)``.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            if device is not None:
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)
            predictions = model(batch_inputs)
            batch_losses.append(criterion(predictions, batch_labels).item())
    # Start from 0.0 so the accumulator is always a float.
    return sum(batch_losses, 0.0) / len(val_loader)
def predict(model, data_loader, device=None):
    """Run ``model`` over ``data_loader`` and collect labels and outputs.

    The model is switched to evaluation mode and gradients are disabled.

    Args:
        model: Callable network applied to each batch of inputs.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        device: When not ``None``, inputs are moved to this device before
            the forward pass (as ``validate_model`` does). The model itself
            is not moved here.

    Returns:
        A ``(labels_array, predictions)`` tuple of NumPy arrays, each the
        concatenation of the per-batch arrays in iteration order.
    """
    model.eval()
    predictions = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            if device is not None:
                # Labels are only copied to host memory below, so only the
                # inputs need to be moved to the device.
                inputs = inputs.to(device)
            outputs = model(inputs)
            # TODO check if it is the efficient way to accumulate predictions
            predictions.append(outputs.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
    predictions = np.concatenate(predictions)
    labels_array = np.concatenate(labels_list)
    return labels_array, predictions