File size: 2,641 Bytes
25e7dcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from typing import Optional, Union
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
# from utils import training_progress
from .early_stopping import EarlyStopping
def train_model(model: nn.Module,
                criterion: Union[nn.BCELoss, nn.MSELoss],
                train_loader: DataLoader,
                val_loader: Optional[DataLoader] = None,
                device: Optional[torch.device] = None,
                early_stopping: Optional['EarlyStopping'] = None,
                epochs: int = 100,
                lr: float = 0.001,
                verbose: int = 0) -> None:
    """Train ``model`` in place with the Adam optimizer.

    Args:
        model: Network to optimise. It is expected to already live on
            ``device``; only the batches are moved by this function.
        criterion: Loss called as ``criterion(outputs, labels)``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        val_loader: Optional iterable of validation batches. When given,
            the validation loss is computed after every epoch.
        device: When not ``None``, every training batch is moved to this
            device before the forward pass (the same handling that
            ``validate_model`` applies to validation batches).
        early_stopping: Optional callable that is invoked with the epoch's
            validation loss and exposes a boolean ``early_stop`` attribute.
        epochs: Maximum number of passes over ``train_loader``.
        lr: Learning rate handed to ``optim.Adam``.
        verbose: ``0`` is silent; any positive value prints one summary
            line per epoch.

    Raises:
        ValueError: If ``early_stopping`` is supplied without ``val_loader``.
    """
    if val_loader is None and early_stopping is not None:
        raise ValueError('Cannot use early_stopping without validation set.')

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            if device is not None:
                # Keep the batch on the same device as the model.
                inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Count batches instead of calling len() so that an empty loader
        # (or one without __len__) does not abort the run.
        train_loss = running_loss / num_batches if num_batches else float('nan')

        val_loss = None
        if val_loader is not None:
            val_loss = validate_model(model, val_loader, criterion, device)

        if verbose > 0:
            summary = f'Epoch {epoch + 1}/{epochs} - train_loss: {train_loss:.6f}'
            if val_loss is not None:
                summary += f' - val_loss: {val_loss:.6f}'
            print(summary)

        if early_stopping is None:
            continue

        # val_loss is always set here: early stopping requires val_loader.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            return
def validate_model(model, val_loader, criterion, device=None):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and gradients are disabled
    for the whole pass. When ``device`` is not ``None`` each batch is
    moved to it before the forward pass. The summed batch losses are
    divided by ``len(val_loader)``.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            features, targets = batch
            if device is not None:
                features = features.to(device)
                targets = targets.to(device)
            batch_loss = criterion(model(features), targets)
            total_loss += batch_loss.item()
    return total_loss / len(val_loader)
def predict(model, data_loader, device=None):
    """Run ``model`` over ``data_loader`` and collect labels and outputs.

    Args:
        model: Network to evaluate; it is put into evaluation mode.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: When not ``None``, each input batch is moved to this
            device before the forward pass. The model is expected to
            already be on that device.

    Returns:
        A ``(labels, predictions)`` tuple of NumPy arrays, each the
        concatenation of the per-batch arrays along the first axis.

    Raises:
        ValueError: Raised by ``np.concatenate`` when ``data_loader``
            yields no batches.
    """
    model.eval()
    predictions = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            if device is not None:
                # Labels are only copied back to the host, so just the
                # inputs need to follow the model onto the device.
                inputs = inputs.to(device)
            outputs = model(inputs)
            # Move each batch to host memory right away so device memory
            # is not held for the whole pass; concatenate once at the end.
            predictions.append(outputs.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
    predictions = np.concatenate(predictions)
    labels_array = np.concatenate(labels_list)
    return labels_array, predictions
|