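"""Training utilities for PyTorch models: a fit loop with optional validation
and early stopping, plus standalone validation and prediction helpers."""
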
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader

# from utils import training_progress
from .early_stopping import EarlyStopping


def train_model(model: nn.Module,
                criterion: Union[nn.BCELoss, nn.MSELoss],
                train_loader: DataLoader,
                val_loader: Optional[DataLoader] = None,
                device: Optional[torch.device] = None,
                early_stopping: Optional[EarlyStopping] = None,
                epochs: int = 100,
                lr: float = 0.001,
                verbose: int = 0) -> None:
    """Train ``model`` with Adam, optionally validating after each epoch and
    stopping early once the validation loss stops improving."""
    if val_loader is None and early_stopping is not None:
        raise ValueError('Cannot use early_stopping without a validation set.')

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            # The model is assumed to already sit on `device`; only the
            # batches are moved here, mirroring validate_model below.
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        if verbose:
            print(f'Epoch {epoch + 1}/{epochs} - train_loss: {train_loss:.4f}')

        # Without a validation set there is nothing left to do this epoch.
        if val_loader is None:
            continue

        val_loss = validate_model(model, val_loader, criterion, device)
        if verbose:
            print(f'Epoch {epoch + 1}/{epochs} - val_loss: {val_loss:.4f}')

        if early_stopping is None:
            continue

        # Feed the latest validation loss to the tracker; it flags
        # `early_stop` once the loss has stopped improving.
        early_stopping(val_loss)
        if early_stopping.early_stop:
            return


def validate_model(model: nn.Module,
                   val_loader: DataLoader,
                   criterion: Union[nn.BCELoss, nn.MSELoss],
                   device: Optional[torch.device] = None) -> float:
    """Return the mean ``criterion`` loss of ``model`` over ``val_loader``."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            if device is not None:
                inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    return val_loss / len(val_loader)


def predict(model: nn.Module,
            data_loader: DataLoader,
            device: Optional[torch.device] = None
            ) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``data_loader`` and return ``(labels, predictions)``
    as numpy arrays."""
    model.eval()
    predictions = []
    labels_list = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            if device is not None:
                inputs = inputs.to(device)
            outputs = model(inputs)
            # Accumulate per-batch arrays and concatenate once at the end;
            # this is cheaper than calling np.concatenate inside the loop.
            predictions.append(outputs.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
    predictions = np.concatenate(predictions)
    labels_array = np.concatenate(labels_list)
    return labels_array, predictions
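

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): fits a tiny binary
# classifier on random data. The data, model, and hyperparameters below are
# illustrative assumptions only. An EarlyStopping instance could be passed
# via `early_stopping=`, but its constructor signature lives in the sibling
# `early_stopping` module, so it is omitted here. Because of the relative
# import above, run this as `python -m <package>.<module>`, not directly.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    X = torch.randn(256, 20)
    y = (X.sum(dim=1, keepdim=True) > 0).float()
    train_loader = DataLoader(TensorDataset(X[:200], y[:200]),
                              batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

    net = nn.Sequential(nn.Linear(20, 16), nn.ReLU(),
                        nn.Linear(16, 1), nn.Sigmoid())
    train_model(net, nn.BCELoss(), train_loader, val_loader,
                epochs=20, verbose=1)

    labels, preds = predict(net, val_loader)
    print('val accuracy:', ((preds > 0.5) == labels).mean())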