import torch
from .helper_functions import define_optimizer, predict, display_train, eval_test
from tqdm import tqdm
import matplotlib.pyplot as plt


def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
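    """Checkpoint the model and optimizer state together with the best
    validation loss seen so far and the index of the next epoch."""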
    torch.save({'valid_loss': valid_loss,
                'model_state_dict': model.state_dict(),
                'epoch': epoch + 1,
                'optimizer': optimizer.state_dict()
                }, path)
    tqdm.write(f'Model saved to ==> {path}')


def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
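    """Persist the train/validation loss curves and their global-step axis
    so they can be re-plotted later (see `plot_losses`)."""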
    torch.save({'train_loss_list': train_loss_list,
                'valid_loss_list': valid_loss_list,
                'global_steps_list': global_steps_list,
                }, path)

def plot_losses(metrics_save_name='metrics', save_dir='./'):
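    """Load a metrics checkpoint written by `save_metrics` and plot the
    train/validation loss curves against global steps.

    With the trainer's naming scheme the file is
    `{save_dir}metrics_{model_save_name}.pt`, so pass the trainer's
    `model_save_name` here as `metrics_save_name`.
    """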
    path = f'{save_dir}metrics_{metrics_save_name}.pt'
    state = torch.load(path)

    train_loss_list = state['train_loss_list']
    valid_loss_list = state['valid_loss_list']
    global_steps_list = state['global_steps_list']

    plt.plot(global_steps_list, train_loss_list, label='Train')
    plt.plot(global_steps_list, valid_loss_list, label='Valid')
    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def trainer(model, train_loader, test_loader, valid_loader, num_epochs=10,
            lr=0.01, alpha=0.99, eval_interval=10, model_save_name='',
            save_dir='./'):
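    """Train `model`, periodically evaluating on `valid_loader` and
    checkpointing whenever the validation loss improves; after training,
    reload the best checkpoint and evaluate it on `test_loader`.

    Note that `save_dir` is concatenated directly into the file names,
    so it should end with a path separator (e.g. './').

    Returns (train_accs, valid_accs, test_acc).
    """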
    # Use GPU if available, else fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    # Ensure the model's parameters live on the same device as the inputs
    model = model.to(device)

    # History of per-epoch accuracies and per-evaluation losses
    train_accs = []
    valid_accs = []
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []
    best_valid_loss = float("inf")
    
    # Define optimizer
    optimizer = define_optimizer(model, lr, alpha)
    
    # Training loop
    for epoch in range(num_epochs):
        # Go through all samples in the train dataset
        model.train()
        running_loss = 0
        correct = 0
        total = 0
        for i, (inputs, labels, notes) in enumerate(train_loader):
            # Fetch a batch and move it to the device; dims 1 and 2 of the
            # inputs are swapped to match the layout the model expects
            inputs = inputs.transpose(1, 2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Forward pass
            outputs, predicted = predict(model, inputs, notes, device)

            # Count a prediction as correct when the top-scoring class is
            # one of the sample's positive labels (multi-label targets)
            # TODO: change acc criteria
            total += labels.size(0)
            _, indices = torch.max(outputs, dim=1)
            correct += (labels.gather(1, indices.unsqueeze(1)) == 1).sum().item()

            # Compute loss on the raw logits;
            # binary_cross_entropy_with_logits applies the sigmoid internally
            loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
            running_loss += loss.item() * len(labels)
            global_step += len(inputs)  # counts samples seen, not optimizer steps

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Periodically report training stats and evaluate on the validation set
            if (i + 1) % eval_interval == 0:
                train_accuracy, valid_accuracy, valid_loss = display_train(
                    epoch, num_epochs, i, model, correct, total, loss,
                    train_loader, valid_loader, device)

                average_train_loss = running_loss / total
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(valid_loss)
                global_steps_list.append(global_step)

                # Checkpoint whenever the validation loss improves
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    save_model(model, optimizer, best_valid_loss, epoch,
                               path=f'{save_dir}model_{model_save_name}.pt')
                    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                                 path=f'{save_dir}metrics_{model_save_name}.pt')

        # If the batch count isn't a multiple of eval_interval, run one
        # final evaluation so the tail of the epoch is recorded too
        if len(train_loader) % eval_interval != 0:
            train_accuracy, valid_accuracy, valid_loss = display_train(
                epoch, num_epochs, i, model, correct, total, loss,
                train_loader, valid_loader, device)

            average_train_loss = running_loss / total
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(valid_loss)
            global_steps_list.append(global_step)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                save_model(model, optimizer, best_valid_loss, epoch,
                           path=f'{save_dir}model_{model_save_name}.pt')
                save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                             path=f'{save_dir}metrics_{model_save_name}.pt')
        # Record accuracies from the most recent evaluation at the end of each epoch
        train_accs.append(train_accuracy)
        valid_accs.append(valid_accuracy)
    # Save the final metrics, then reload the best checkpoint
    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=f'{save_dir}metrics_{model_save_name}.pt')
    checkpoint = torch.load(f'{save_dir}model_{model_save_name}.pt')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluate on the test set after training has completed
    test_acc = eval_test(model, test_loader, device)
    return train_accs, valid_accs, test_acc
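

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): `MyModel` and the three DataLoaders are
# hypothetical placeholders for objects built elsewhere in the project; this
# block just documents how `trainer` and `plot_losses` fit together.
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
#     model = MyModel()  # hypothetical model class
#     train_accs, valid_accs, test_acc = trainer(
#         model, train_loader, test_loader, valid_loader,
#         num_epochs=10, lr=0.01, alpha=0.99, eval_interval=10,
#         model_save_name='baseline', save_dir='./')
#     plot_losses(metrics_save_name='baseline', save_dir='./')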