# Training and evaluation code (commit 2da5c78, author: BenjiELCA)
import copy
import numpy as np
import time
import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from modules.eval import main_evaluation
from torch.optim import SGD, AdamW
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from tqdm import tqdm
from modules.utils import write_results
def get_arrow_model(num_classes, num_keypoints=2):
    """
    Build a Keypoint R-CNN (ResNet-50 + FPN) customized for a given number of
    object classes and keypoints per detection.

    Parameters:
    - num_classes (int): Number of classes to detect, excluding the background class.
    - num_keypoints (int): Number of keypoints predicted per detected object.

    Returns:
    - model (torch.nn.Module): The customized Keypoint R-CNN model.
    """
    # Instantiate the torchvision architecture with randomly initialized weights.
    model = keypointrcnn_resnet50_fpn(weights=None)

    # The replacement box predictor must consume the same feature width as the
    # stock one, so read that width off the existing classification layer.
    feature_width = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)

    # Swap the keypoint head for one that emits the requested keypoint count.
    # 512 matches the channel width fed into the default keypoint predictor.
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(512, num_keypoints)

    return model
def get_faster_rcnn_model(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 + FPN) detector for a custom class count.

    Parameters:
    - num_classes (int): Number of classes to detect, including the background class.

    Returns:
    - model (torch.nn.Module): The customized Faster R-CNN model.
    """
    # Instantiate the torchvision architecture without pre-trained weights.
    model = fasterrcnn_resnet50_fpn(weights=None)

    # Size the new box predictor from the feature width of the stock classifier
    # so it slots into the ROI heads unchanged.
    feature_width = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)

    return model
def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_type='object'):
    """
    Prepare the model and optimizer for training.

    Parameters:
    - dict (dict): Dictionary of classes (only its length is used).
      NOTE: the parameter name shadows the `dict` builtin; kept for
      backward compatibility with existing callers.
    - opti (str): Optimizer type ('SGD' or 'Adam').
    - learning_rate (float): Learning rate for the optimizer.
    - model_to_load (str, optional): Name of a saved model ('.pth' is appended).
    - model_type (str): Type of model to prepare ('object' or 'arrow').

    Returns:
    - model (torch.nn.Module): The prepared model, moved to `device`.
    - optimizer (torch.optim.Optimizer): The configured optimizer.
    - device (torch.device): The device (CPU or CUDA) on which to perform training.

    Raises:
    - ValueError: If `model_type` or `opti` is not recognized.
    """
    # Adjusted to pass the class_dict directly
    if model_type == 'object':
        model = get_faster_rcnn_model(len(dict))
    elif model_type == 'arrow':
        model = get_arrow_model(len(dict), 2)
    else:
        # Previously fell through and crashed later with NameError; fail fast instead.
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'object' or 'arrow'")

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load the model weights if a checkpoint name was given.
    if model_to_load:
        model.load_state_dict(torch.load(model_to_load + '.pth', map_location=device))
        print(f"Model '{model_to_load}' loaded")

    model.to(device)

    if opti == 'SGD':
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    elif opti == 'Adam':
        # NOTE: 'Adam' actually builds AdamW (decoupled weight decay).
        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001, eps=1e-08, betas=(0.9, 0.999))
    else:
        # Previously printed 'Optimizer not found' and returned an unbound
        # `optimizer` (NameError); raise a clear error instead.
        raise ValueError(f"Optimizer not found: {opti!r}; expected 'SGD' or 'Adam'")

    return model, optimizer, device
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
    """
    Evaluate the loss of the model on a validation dataset.

    Parameters:
    - model (torch.nn.Module): The model to evaluate.
    - data_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
    - device (torch.device): Device to perform evaluation on.
    - loss_config (dict, optional): Maps loss names to booleans; only enabled
      losses contribute to the total. None means all losses are summed.
    - print_losses (bool): Whether to print individual loss components.

    Returns:
    - float: Average total loss per batch over the validation dataset.
    """
    # Keep the model in train mode on purpose: torchvision detection models
    # only return the loss dict when training. Gradients stay disabled below.
    model.train()
    total_loss = 0

    # Per-batch values of each individual loss component (0 when the model
    # does not produce that component, e.g. no keypoint head).
    component_keys = ('loss_classifier', 'loss_box_reg', 'loss_objectness',
                      'loss_rpn_box_reg', 'loss_keypoint')
    component_lists = {key: [] for key in component_keys}

    with torch.no_grad():  # Disable gradient computation
        for images, targets_im in tqdm(data_loader, desc="Evaluating"):
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
            loss_dict = model(images, targets)

            # Total loss for the current batch, filtered by loss_config if given.
            if loss_config is not None:
                losses = sum(loss for key, loss in loss_dict.items() if loss_config.get(key, False))
            else:
                losses = sum(loss_dict.values())
            total_loss += losses.item()

            # Collect individual losses (deduplicated from five if/else stanzas).
            for key in component_keys:
                value = loss_dict.get(key)
                component_lists[key].append(value.item() if value is not None else 0)

    # Average loss per batch; raises ZeroDivisionError on an empty loader,
    # matching the previous behavior.
    avg_loss = total_loss / len(data_loader)
    averages = {key: np.mean(values) for key, values in component_lists.items()}

    if print_losses:
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Average Classifier Loss: {averages['loss_classifier']:.4f}")
        print(f"Average Box Regression Loss: {averages['loss_box_reg']:.4f}")
        print(f"Average Objectness Loss: {averages['loss_objectness']:.4f}")
        print(f"Average RPN Box Regression Loss: {averages['loss_rpn_box_reg']:.4f}")
        print(f"Average Keypoints Loss: {averages['loss_keypoint']:.4f}")

    return avg_loss
def training_model(num_epochs, model, data_loader, subset_test_loader,
                   optimizer, model_to_load=None, change_learning_rate=100, start_key=100, save_every=5,
                   parameters=None, blur_prob=0.02,
                   score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
                   information_training='training', start_epoch=0, loss_config=None, model_type='object',
                   eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
    """
    Train a detection model, track per-epoch losses and evaluation metrics,
    keep the best-scoring weights, and periodically save checkpoints.

    Parameters:
    - num_epochs (int): Number of epochs to run.
    - model (torch.nn.Module): Model to train (Faster R-CNN or Keypoint R-CNN).
    - data_loader: Training DataLoader yielding (images, targets) batches.
    - subset_test_loader: DataLoader used for per-epoch evaluation.
    - optimizer (torch.optim.Optimizer): Initial optimizer; replaced by a fresh
      AdamW whenever the learning rate is decayed mid-training.
    - model_to_load (str, optional): Accepted but never read in this function.
    - change_learning_rate (int): Decay the LR every this many epochs.
    - start_key (int): Epoch at which keypoint training is switched on.
    - save_every (int): Save a checkpoint every N absolute epochs.
    - parameters (dict, optional): Augmentation settings; unpacked positionally
      via .values(), so key order must be: batch_size, crop_prob,
      rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba,
      keep_ratio.
    - blur_prob (float): Only used to build the checkpoint filename.
    - score_threshold, iou_threshold (float): Forwarded to main_evaluation.
    - early_stop_f1_score (float): Stop once an epoch's F1 exceeds this.
    - information_training (str): Free-text tag appended to checkpoint names.
    - start_epoch (int): Offset added to epoch numbers for logging/filenames.
    - loss_config (dict, optional): Which loss components to optimize; None = all.
    - model_type (str): 'object' or 'arrow' (forwarded to evaluation).
    - eval_metric (str): 'f1_score', 'precision', 'recall', or anything else
      for negative test loss - selects how the best epoch is chosen.
    - device (torch.device): Device to train on.

    Returns:
    - torch.nn.Module: The model loaded with the best-performing weights.

    NOTE(review): several latent hazards in this function, flagged inline below:
    crash if `parameters is None` when the checkpoint name is built; crash at
    `start_key` if `loss_config is None`; `name_model`/`metrics_list` are
    undefined at the final save when `num_epochs == 0`; `./models/` must exist.
    """
    # Set the model to training mode
    model.train()

    if loss_config is None:
        print('No loss config found, all losses will be used.')
    else:
        # Print the list of the losses that will be used
        print('The following losses will be used: ', end='')
        for key, value in loss_config.items():
            if value:
                print(key, end=", ")
        print()

    # Initialize lists to store epoch-wise average losses and other metrics
    epoch_avg_losses = []
    epoch_avg_loss_classifier = []
    epoch_avg_loss_box_reg = []
    epoch_avg_loss_objectness = []
    epoch_avg_loss_rpn_box_reg = []
    epoch_avg_loss_keypoints = []
    epoch_precision = []
    epoch_recall = []
    epoch_f1_score = []
    epoch_test_loss = []

    start_tot = time.time()
    best_metric_value = -1000  # sentinel low enough for the -avg_test_loss metric too
    best_epoch = 0
    best_model_state = None
    epochs_with_high_f1 = 0
    learning_rate = optimizer.param_groups[0]['lr']
    bad_test_loss_epochs = 0
    previous_test_loss = 1000  # sentinel so the first epoch never counts as "worse"

    if parameters is not None:
        # Positional unpack relies on dict insertion order matching the order below.
        batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()

    print(f"Let's go training {model_type} model with {num_epochs} epochs!")
    if parameters is not None:
        print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")

    for epoch in range(num_epochs):
        # Decay the LR on a fixed schedule, or after 2 epochs of worsening test loss.
        if (epoch > 0 and epoch % change_learning_rate == 0) or bad_test_loss_epochs >= 2:
            learning_rate *= 0.7
            # NOTE(review): the optimizer is rebuilt from scratch (momentum state
            # lost) and weight_decay is set equal to the new learning rate -
            # confirm that coupling is intentional.
            optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
            if best_model_state is not None:
                # Roll back to the best weights seen so far before continuing.
                model.load_state_dict(best_model_state)
            print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
            bad_test_loss_epochs = 0

        # From `start_key` onward, also optimize the keypoint head.
        if epoch > 0 and epoch == start_key:
            print("Now it's training Keypoints also")
            # NOTE(review): raises TypeError here if loss_config is None.
            loss_config['loss_keypoint'] = True
            for name, param in model.named_parameters():
                if 'keypoint' in name:
                    param.requires_grad = True

        model.train()
        start = time.time()
        total_loss = 0

        # Initialize lists to keep track of individual losses
        loss_classifier_list = []
        loss_box_reg_list = []
        loss_objectness_list = []
        loss_rpn_box_reg_list = []
        loss_keypoints_list = []

        # Create a tqdm progress bar
        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1+start_epoch}')

        for images, targets_im in progress_bar:
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            optimizer.zero_grad()
            loss_dict = model(images, targets)

            # Inside the training loop where losses are calculated:
            losses = 0
            if loss_config is not None:
                for key, loss in loss_dict.items():
                    if loss_config.get(key, False):
                        if key == 'loss_classifier':
                            # Classifier loss is up-weighted x3 relative to the others.
                            loss *= 3
                        losses += loss
            else:
                losses = sum(loss for key, loss in loss_dict.items())

            # Collect individual losses (tensor(0) default when a head is absent)
            loss_classifier_list.append(loss_dict.get('loss_classifier', torch.tensor(0)).item())
            loss_box_reg_list.append(loss_dict.get('loss_box_reg', torch.tensor(0)).item())
            loss_objectness_list.append(loss_dict.get('loss_objectness', torch.tensor(0)).item())
            loss_rpn_box_reg_list.append(loss_dict.get('loss_rpn_box_reg', torch.tensor(0)).item())
            loss_keypoints_list.append(loss_dict.get('loss_keypoint', torch.tensor(0)).item())

            losses.backward()
            optimizer.step()

            total_loss += losses.item()
            # Update the description with the current loss
            progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')

        # Calculate average loss
        avg_loss = total_loss / len(data_loader)
        epoch_avg_losses.append(avg_loss)
        epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
        epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
        epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
        epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
        epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))

        # Evaluate the model on the test set
        if eval_metric == 'loss':
            # Metric-based scores are zeroed; only the test loss is computed.
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
        else:
            # NOTE(review): this initial 0 is dead - avg_test_loss is
            # unconditionally overwritten by evaluate_loss() below.
            avg_test_loss = 0
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_threshold=10, key_correction=False, model_type=model_type)
            print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")

        print(f"Time: {time.time() - start:.2f} [s]")

        # Select the quantity that decides which epoch is "best".
        if eval_metric == 'f1_score':
            metric_used = f1_score
        elif eval_metric == 'precision':
            metric_used = precision
        elif eval_metric == 'recall':
            metric_used = recall
        else:
            metric_used = -avg_test_loss  # lower loss = higher (better) metric

        # Check if this epoch's model has the best evaluation metric
        if metric_used > best_metric_value:
            best_metric_value = metric_used
            best_epoch = epoch + 1 + start_epoch
            best_model_state = copy.deepcopy(model.state_dict())

        if epoch > 0 and f1_score > early_stop_f1_score:
            epochs_with_high_f1 += 1

        epoch_precision.append(precision)
        epoch_recall.append(recall)
        epoch_f1_score.append(f1_score)
        epoch_test_loss.append(avg_test_loss)

        # NOTE(review): raises NameError if parameters is None - batch_size,
        # crop_prob, h_flip_prob and rotate_proba are only bound when
        # parameters was provided.
        name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
        metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg, epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints, epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]

        # Early stop: one epoch above the F1 threshold is enough.
        if epochs_with_high_f1 >= 1:
            torch.save(best_model_state, './models/' + name_model + '.pth')
            write_results(name_model, metrics_list, start_epoch)
            break

        # Periodic checkpoint; also resets the live model to the best weights.
        if (epoch + 1 + start_epoch) % save_every == 0:
            torch.save(best_model_state, './models/' + name_model + '.pth')
            model.load_state_dict(best_model_state)
            write_results(name_model, metrics_list, start_epoch)

        # NOTE(review): bad_test_loss_epochs is never reset when the test loss
        # improves (only after an LR change), so two worsening epochs at any
        # point in the schedule trigger a decay.
        if avg_test_loss > previous_test_loss:
            bad_test_loss_epochs += 1
        previous_test_loss = avg_test_loss

    print(f"\nTotal time: {(time.time() - start_tot) / 60:.2f} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metric_value:.4f}")

    # Final save of the best weights; name_model/metrics_list come from the
    # last loop iteration (undefined when num_epochs == 0, see docstring note).
    if best_model_state:
        torch.save(best_model_state, './models/' + name_model + '.pth')
        model.load_state_dict(best_model_state)
        write_results(name_model, metrics_list, start_epoch)
        print(f"Name of the best model: {name_model}")

    return model