Spaces:
Sleeping
Sleeping
added utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
|
8 |
+
def save_model(model, optimizer, epoch, loss, directory, model_name='model', **kwargs):
    """
    Save a PyTorch model checkpoint.

    The checkpoint file is named
    ``{model_name}_epoch={epoch}_loss={loss:.4f}[_{metric}={value:.4f}...].pth``
    and contains the epoch, the model and optimizer state dicts, the loss and
    every extra metric passed through ``kwargs``.

    Args:
        model: Trained model; must expose ``state_dict()``.
        optimizer: Optimizer used for training; must expose ``state_dict()``.
        epoch: The last epoch the model was trained on.
        loss: The last loss recorded during training (formatted with ``.4f``).
        directory: The directory where to save the model. It is created,
            including missing parents, if it does not exist.
        model_name: Base name for the model file, defaults to 'model'.
        kwargs: Additional keyword arguments representing metrics. Each one is
            included in the filename (formatted with ``.4f``) and stored in the
            checkpoint under its own key.

    Returns:
        pathlib.Path: The path of the checkpoint file that was written.

    Example:
        >>> save_model(model, optimizer, epoch, loss, './model_dir', f1_score=val_f1score)
    """
    # Create the directory if it does not exist
    target_dir = Path(directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Build the filename piece by piece so that an empty metrics set does not
    # leave a dangling underscore before the extension.
    name_parts = [f'{model_name}_epoch={epoch}', f'loss={loss:.4f}']
    name_parts.extend(f'{key}={value:.4f}' for key, value in kwargs.items())
    checkpoint_path = target_dir / ('_'.join(name_parts) + '.pth')

    # Save the model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        **kwargs
    }, checkpoint_path)

    return checkpoint_path
|
38 |
+
|
39 |
+
|
40 |
+
def get_device() -> torch.device:
    """
    Retrieve the appropriate Torch device for running computations.

    The first available backend wins, in this order: CUDA, then MPS, then CPU.

    Returns:
        torch.device: The Torch device to be used for computations.

    Examples:
        >>> device = get_device()
        >>> device.type in ("cuda", "mps", "cpu")
        True
    """
    if torch.cuda.is_available():
        return torch.device("cuda")  # NVIDIA GPU

    # PyTorch builds without the MPS backend module have no
    # `torch.backends.mps`; look it up defensively so such builds fall through
    # to the CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")  # Apple GPU

    # Defaults to CPU if NVIDIA GPU/Apple GPU aren't available
    return torch.device("cpu")
|
64 |
+
|
65 |
+
|
66 |
+
def load_checkpoint(model, optimizer, filename, map_location=None):
    """
    Load a PyTorch model checkpoint.

    Args:
        model: Model to load the weights into.
        optimizer: Optimizer to load the state into. Pass ``None`` to restore
            only the model weights (e.g. for inference).
        filename: The path of the checkpoint file.
        map_location: Forwarded to ``torch.load``; use it (e.g. ``'cpu'`` or a
            ``torch.device``) to load a checkpoint on a different device than
            the one it was saved from. Defaults to ``None``, which keeps
            ``torch.load``'s default behaviour.

    Returns:
        A tuple ``(epoch, loss, metrics)``: the epoch stored in the
        checkpoint, the loss stored in the checkpoint, and a dict with every
        other entry of the checkpoint that is not one of the standard keys.

    Raises:
        KeyError: If the checkpoint lacks 'model_state_dict', 'epoch' or
            'loss', or lacks 'optimizer_state_dict' while an optimizer is given.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    # Extract additional metrics: everything that is not a standard entry
    standard_keys = {'epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'}
    metrics = {key: value for key, value in checkpoint.items()
               if key not in standard_keys}

    return epoch, loss, metrics
|
89 |
+
|
90 |
+
# To use the function, you would do something like this:
|
91 |
+
# epoch, loss, metrics = load_checkpoint(model, optimizer, 'model_checkpoint.pth')
|