Vedmani committed on
Commit
00d5f74
1 Parent(s): 69a2c1e

added utils.py

Browse files
Files changed (1) hide show
  1. utils.py +91 -0
utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from pathlib import Path
6
+
7
+
8
def save_model(model, optimizer, epoch, loss, directory, model_name='model', **kwargs):
    """
    Save a PyTorch model checkpoint.

    The checkpoint is written to
    ``<directory>/<model_name>_epoch=<epoch>_loss=<loss>[_<metric>=<value>...].pth``
    where the loss and every metric are formatted to four decimal places.

    Args:
        model: Trained model; its ``state_dict()`` is stored.
        optimizer: Optimizer used for training; its ``state_dict()`` is stored.
        epoch: The last epoch the model was trained on.
        loss: The last loss recorded during training.
        directory: The directory where to save the model (created if missing).
        model_name: Base name for the model file, defaults to 'model'.
        kwargs: Additional keyword arguments representing metrics. Each one is
            appended to the filename and stored in the checkpoint under its
            own key, so every value must support the ``.4f`` format spec.

    Example:
        >>> save_model(model, optimizer, epoch, loss, './model_dir', f1_score=val_f1score)
    """
    # Create the directory if it does not exist
    target_dir = Path(directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Assemble the filename from its parts so that no dangling "_" is left in
    # front of the extension when no extra metrics are supplied.
    name_parts = [f'{model_name}', f'epoch={epoch}', f'loss={loss:.4f}']
    name_parts.extend(f'{key}={value:.4f}' for key, value in kwargs.items())
    filename = target_dir / ('_'.join(name_parts) + '.pth')

    # Save the model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        **kwargs
    }, filename)
38
+
39
+
40
def get_device() -> torch.device:
    """
    Pick the Torch device on which computations should run.

    CUDA is preferred when ``torch.cuda.is_available()`` reports it, then
    Apple's MPS backend when ``torch.backends.mps.is_available()`` reports
    it, and the CPU is used otherwise.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``, in that order of preference.

    Examples:
        >>> device = get_device()
        >>> print(device)
        cuda

    """
    # NVIDIA GPU
    if torch.cuda.is_available():
        return torch.device("cuda")

    # Apple GPU
    if torch.backends.mps.is_available():
        return torch.device("mps")

    # Neither accelerator is available: fall back to the CPU
    return torch.device("cpu")
64
+
65
+
66
def load_checkpoint(model, optimizer, filename, map_location=None):
    """
    Load a PyTorch model checkpoint.

    The checkpoint is expected to be a dict with the keys 'epoch',
    'model_state_dict', 'optimizer_state_dict' and 'loss'; any other keys
    are treated as additional metrics.

    Args:
        model: Model to load the weights into (updated in place).
        optimizer: Optimizer to load the state into (updated in place).
        filename: The path of the checkpoint file.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` or a
            ``torch.device`` to restore a checkpoint on a machine that lacks
            the device it was saved from. Defaults to None, which keeps
            ``torch.load``'s default placement.

    Returns:
        The epoch at which training was stopped, the last loss recorded, and
        a dict of any additional metrics stored in the checkpoint.

    Raises:
        KeyError: If one of the required checkpoint keys is missing.
    """
    # NOTE(security): torch.load unpickles the file's contents; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    # Everything that is not one of the reserved keys is an extra metric
    reserved_keys = {'epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'}
    metrics = {key: value for key, value in checkpoint.items() if key not in reserved_keys}

    return epoch, loss, metrics
89
+
90
+ # To use the function, you would do something like this:
91
+ # epoch, loss, metrics = load_checkpoint(model, optimizer, 'model_checkpoint.pth')