Spaces:
Running
Running
import copy | |
import numpy as np | |
import time | |
import torch | |
import torchvision.transforms.functional as F | |
import matplotlib.pyplot as plt | |
from modules.eval import main_evaluation | |
from torch.optim import SGD, AdamW | |
from torchvision.models.detection import keypointrcnn_resnet50_fpn | |
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | |
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor | |
from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
from tqdm import tqdm | |
from modules.utils import write_results | |
def get_arrow_model(num_classes, num_keypoints=2):
    """
    Build a Keypoint R-CNN (ResNet-50 + FPN) whose box and keypoint heads are
    replaced to match a custom number of classes and keypoints.

    The network is created with ``weights=None``, so no detection checkpoint is
    requested from torchvision here; the backbone argument is left at
    torchvision's default.

    Parameters:
    - num_classes (int): Number of output classes of the box predictor. The value is
      passed unchanged to FastRCNNPredictor, which counts the background as a class.
    - num_keypoints (int): Number of keypoints to predict for each detected object.

    Returns:
    - model (torch.nn.Module): The modified Keypoint R-CNN model.
    """
    model = keypointrcnn_resnet50_fpn(weights=None)

    # Size the new box head from the existing one instead of hard-coding it.
    box_in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Same idea for the keypoint head: read the channel count feeding the current
    # predictor (512 for the stock model) rather than repeating the magic number.
    keypoint_in_channels = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(keypoint_in_channels, num_keypoints)

    return model
def get_faster_rcnn_model(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 + FPN) with a box head sized for a custom
    number of classes.

    Parameters:
    - num_classes (int): Number of classes for the model to detect, including the background class.

    Returns:
    - model (torch.nn.Module): The modified Faster R-CNN model.
    """
    # weights=None: no detection checkpoint is requested from torchvision here.
    detector = fasterrcnn_resnet50_fpn(weights=None)
    roi_heads = detector.roi_heads

    # Keep the input width of the stock classifier and swap in a head with the
    # requested number of outputs.
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)

    return detector
def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_type='object'):
    """
    Prepares the model and optimizer for training.

    Parameters:
    - dict (dict): Dictionary of classes; only its length is used, as the number of
      classes of the model. (The name shadows the builtin but is kept so that existing
      keyword callers keep working.)
    - opti (str): Optimizer type: 'SGD' builds torch.optim.SGD, 'Adam' builds torch.optim.AdamW.
    - learning_rate (float): Learning rate for the optimizer.
    - model_to_load (str, optional): Path of the checkpoint to load, without the '.pth' suffix.
    - model_type (str): Type of model to prepare ('object' or 'arrow').

    Returns:
    - model (torch.nn.Module): The prepared model, moved to `device`.
    - optimizer (torch.optim.Optimizer): The configured optimizer.
    - device (torch.device): The device (CPU or CUDA) on which to perform training.

    Raises:
    - ValueError: If `model_type` or `opti` is not one of the supported values.
    """
    # Validate before building anything: previously an unknown value only surfaced
    # later as an UnboundLocalError on `model` / `optimizer`.
    if model_type not in ('object', 'arrow'):
        raise ValueError(f"Unknown model_type {model_type!r}; expected 'object' or 'arrow'")
    if opti not in ('SGD', 'Adam'):
        raise ValueError(f"Optimizer not found: {opti!r}; expected 'SGD' or 'Adam'")

    num_classes = len(dict)
    if model_type == 'object':
        model = get_faster_rcnn_model(num_classes)
    else:
        model = get_arrow_model(num_classes, 2)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load the model weights
    if model_to_load:
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        model.load_state_dict(torch.load(model_to_load + '.pth', map_location=device))
        print(f"Model '{model_to_load}' loaded")

    model.to(device)

    if opti == 'SGD':
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    else:
        optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.001, eps=1e-08, betas=(0.9, 0.999))

    return model, optimizer, device
def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=False):
    """
    Evaluate the loss of the model on a validation dataset.

    The model is switched to training mode (and left in it) because it is called
    as ``model(images, targets)`` and expected to return its dictionary of losses;
    gradients are disabled for the whole pass, so no parameter is updated.

    Parameters:
    - model (torch.nn.Module): The model to evaluate.
    - data_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset,
      yielding (images, targets) batches.
    - device (torch.device): Device to perform evaluation on.
    - loss_config (dict, optional): Maps loss names to booleans; only losses whose entry
      is truthy are summed into the total. If None, every returned loss is summed.
    - print_losses (bool): Whether to print individual loss components.

    Returns:
    - float: Total loss averaged over the number of batches in `data_loader`.
    """
    # Training mode is intentional here: it is what makes the model return losses.
    model.train()

    tracked_losses = (
        'loss_classifier',
        'loss_box_reg',
        'loss_objectness',
        'loss_rpn_box_reg',
        'loss_keypoint',
    )
    history = {key: [] for key in tracked_losses}
    total_loss = 0

    with torch.no_grad():  # Disable gradient computation
        for images, targets_im in tqdm(data_loader, desc="Evaluating"):
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            loss_dict = model(images, targets)

            # Total loss for the current batch, restricted to the enabled losses
            if loss_config is not None:
                batch_loss = sum(loss for key, loss in loss_dict.items() if loss_config.get(key, False))
            else:
                batch_loss = sum(loss_dict.values())

            # When nothing is enabled the sum is the plain int 0, which has no .item()
            total_loss += batch_loss.item() if torch.is_tensor(batch_loss) else float(batch_loss)

            # Record every component; a loss the model did not return counts as 0
            for key in tracked_losses:
                value = loss_dict.get(key)
                history[key].append(value.item() if value is not None else 0)

    # Calculate average loss
    avg_loss = total_loss / len(data_loader)
    avg_loss_classifier = np.mean(history['loss_classifier'])
    avg_loss_box_reg = np.mean(history['loss_box_reg'])
    avg_loss_objectness = np.mean(history['loss_objectness'])
    avg_loss_rpn_box_reg = np.mean(history['loss_rpn_box_reg'])
    avg_loss_keypoints = np.mean(history['loss_keypoint'])

    if print_losses:
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Average Classifier Loss: {avg_loss_classifier:.4f}")
        print(f"Average Box Regression Loss: {avg_loss_box_reg:.4f}")
        print(f"Average Objectness Loss: {avg_loss_objectness:.4f}")
        print(f"Average RPN Box Regression Loss: {avg_loss_rpn_box_reg:.4f}")
        print(f"Average Keypoints Loss: {avg_loss_keypoints:.4f}")

    return avg_loss
def _save_best_state(best_model_state, name_model):
    """Write the best state dict to ./models/<name_model>.pth; do nothing if there is none yet."""
    import os

    if best_model_state is None:
        return
    os.makedirs('./models', exist_ok=True)
    torch.save(best_model_state, './models/' + name_model + '.pth')


def training_model(num_epochs, model, data_loader, subset_test_loader,
                   optimizer, model_to_load=None, change_learning_rate=100, start_key=100, save_every=5,
                   parameters=None, blur_prob=0.02,
                   score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
                   information_training='training', start_epoch=0, loss_config=None, model_type='object',
                   eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
    """
    Train `model`, evaluate it after every epoch, keep the best state seen so far
    and periodically write it (plus the metric history) to disk.

    Parameters:
    - num_epochs (int): Number of epochs to run.
    - model (torch.nn.Module): Model returning a dict of losses when called as model(images, targets).
    - data_loader: Training loader yielding (images, targets) batches.
    - subset_test_loader: Loader used for the per-epoch evaluation.
    - optimizer (torch.optim.Optimizer): Initial optimizer; its first param group gives the starting learning rate.
    - model_to_load: Unused by this function.
    - change_learning_rate (int): The learning rate is multiplied by 0.7 every this many epochs.
    - start_key (int): Epoch index at which the keypoint loss is enabled and parameters whose
      name contains 'keypoint' are made trainable.
    - save_every (int): The best state is saved whenever the absolute epoch number is a multiple of this.
    - parameters (dict, optional): Augmentation settings. Its values are unpacked positionally as
      batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba,
      keep_ratio, so insertion order matters. Only used for logging and for the checkpoint name.
    - blur_prob (float): Only used in the checkpoint name.
    - score_threshold, iou_threshold (float): Forwarded to main_evaluation.
    - early_stop_f1_score (float): Training stops once the F1 score exceeds this value (never on the first epoch).
    - information_training (str): Suffix of the checkpoint name.
    - start_epoch (int): Offset added to the epoch index in logs, names and the save schedule.
    - loss_config (dict, optional): Maps loss names to booleans selecting the losses that are optimised.
      If None, all losses are used. The dict is modified in place when `start_key` is reached.
    - model_type (str): Forwarded to main_evaluation and used in logs.
    - eval_metric (str): 'f1_score', 'precision' or 'recall'; any other value selects the
      negative test loss (with 'loss' additionally skipping main_evaluation).
    - device (torch.device): Device the batches are moved to.

    Returns:
    - model (torch.nn.Module): The model, with the best recorded state loaded when one exists.

    Raises:
    - ValueError: If `loss_config` enables none of the losses returned by the model.
    """
    # Set the model to training mode
    model.train()

    if loss_config is None:
        print('No loss config found, all losses will be used.')
    else:
        # Print the list of the losses that will be used
        print('The following losses will be used: ', end='')
        for key, value in loss_config.items():
            if value:
                print(key, end=", ")
        print()

    # Epoch-wise history of training losses and evaluation metrics
    epoch_avg_losses = []
    epoch_avg_loss_classifier = []
    epoch_avg_loss_box_reg = []
    epoch_avg_loss_objectness = []
    epoch_avg_loss_rpn_box_reg = []
    epoch_avg_loss_keypoints = []
    epoch_precision = []
    epoch_recall = []
    epoch_f1_score = []
    epoch_test_loss = []
    # The lists are only ever appended to, so this bundle can be built once.
    metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg,
                    epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints,
                    epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]

    start_tot = time.time()
    best_metric_value = -1000
    best_epoch = 0
    best_model_state = None
    epochs_with_high_f1 = 0
    learning_rate = optimizer.param_groups[0]['lr']
    bad_test_loss_epochs = 0
    previous_test_loss = 1000

    if parameters is not None:
        # Positional unpacking: relies on the insertion order of `parameters`.
        batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
        name_details = f"{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}"
    else:
        # Without `parameters` the augmentation settings are unknown, so they are
        # left out of the checkpoint name instead of raising a NameError.
        name_details = f"trainval_blur0{int(blur_prob*10)}"

    print(f"Let's go training {model_type} model with {num_epochs} epochs!")
    if parameters is not None:
        print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")

    for epoch in range(num_epochs):
        if (epoch > 0 and epoch % change_learning_rate == 0) or bad_test_loss_epochs >= 2:
            learning_rate *= 0.7
            # NOTE(review): the optimizer is rebuilt as AdamW whatever type was passed in,
            # with weight_decay tied to the learning rate; kept as-is, confirm it is intended.
            optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
            if best_model_state is not None:
                # Restart from the best weights seen so far
                model.load_state_dict(best_model_state)
            print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
            bad_test_loss_epochs = 0

        if epoch > 0 and epoch == start_key:
            print("Now it's training Keypoints also")
            # With no loss_config every loss is already optimised, so there is nothing to switch on.
            if loss_config is not None:
                loss_config['loss_keypoint'] = True
            for name, param in model.named_parameters():
                if 'keypoint' in name:
                    param.requires_grad = True

        # The per-epoch evaluation may have changed the mode
        model.train()
        start = time.time()
        total_loss = 0

        # Per-batch values of the individual losses for this epoch
        loss_classifier_list = []
        loss_box_reg_list = []
        loss_objectness_list = []
        loss_rpn_box_reg_list = []
        loss_keypoints_list = []

        # Create a tqdm progress bar
        progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1+start_epoch}')

        for images, targets_im in progress_bar:
            images = [image.to(device) for image in images]
            targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]

            optimizer.zero_grad()

            loss_dict = model(images, targets)

            losses = 0
            if loss_config is not None:
                for key, loss in loss_dict.items():
                    if loss_config.get(key, False):
                        if key == 'loss_classifier':
                            # NOTE(review): in-place scaling, so the classifier loss recorded
                            # below carries the x3 weight as well; behaviour preserved.
                            loss *= 3
                        losses += loss
            else:
                losses = sum(loss for key, loss in loss_dict.items())

            if not torch.is_tensor(losses):
                # Nothing was summed: there is no graph to back-propagate through.
                raise ValueError(
                    f"loss_config enables none of the losses returned by the model: {list(loss_dict)}"
                )

            # Collect individual losses
            loss_classifier_list.append(loss_dict.get('loss_classifier', torch.tensor(0)).item())
            loss_box_reg_list.append(loss_dict.get('loss_box_reg', torch.tensor(0)).item())
            loss_objectness_list.append(loss_dict.get('loss_objectness', torch.tensor(0)).item())
            loss_rpn_box_reg_list.append(loss_dict.get('loss_rpn_box_reg', torch.tensor(0)).item())
            loss_keypoints_list.append(loss_dict.get('loss_keypoint', torch.tensor(0)).item())

            losses.backward()
            optimizer.step()

            total_loss += losses.item()

            # Update the description with the current loss
            progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')

        # Calculate average loss
        avg_loss = total_loss / len(data_loader)

        epoch_avg_losses.append(avg_loss)
        epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
        epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
        epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
        epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
        epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))

        # Evaluate the model on the test set
        if eval_metric == 'loss':
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
        else:
            labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_threshold=10, key_correction=False, model_type=model_type)
            print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
            avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
            print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")

        print(f"Time: {time.time() - start:.2f} [s]")

        if eval_metric == 'f1_score':
            metric_used = f1_score
        elif eval_metric == 'precision':
            metric_used = precision
        elif eval_metric == 'recall':
            metric_used = recall
        else:
            # Negated so that "higher is better" holds for every metric
            metric_used = -avg_test_loss

        # Check if this epoch's model has the best evaluation metric
        if metric_used > best_metric_value:
            best_metric_value = metric_used
            best_epoch = epoch + 1 + start_epoch
            best_model_state = copy.deepcopy(model.state_dict())

        if epoch > 0 and f1_score > early_stop_f1_score:
            epochs_with_high_f1 += 1

        epoch_precision.append(precision)
        epoch_recall.append(recall)
        epoch_f1_score.append(f1_score)
        epoch_test_loss.append(avg_test_loss)

        # The name carries the current epoch number and optimizer type, not those of the best epoch
        name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{name_details}_{information_training}"

        if epochs_with_high_f1 >= 1:
            # Early stop: the F1 target has been exceeded
            _save_best_state(best_model_state, name_model)
            write_results(name_model, metrics_list, start_epoch)
            break

        if (epoch + 1 + start_epoch) % save_every == 0:
            _save_best_state(best_model_state, name_model)
            if best_model_state is not None:
                # Training continues from the best state after each periodic save
                model.load_state_dict(best_model_state)
            write_results(name_model, metrics_list, start_epoch)

        # NOTE(review): this counter is only reset when the learning rate is lowered,
        # so the two "bad" epochs need not be consecutive; behaviour preserved.
        if avg_test_loss > previous_test_loss:
            bad_test_loss_epochs += 1
        previous_test_loss = avg_test_loss

    print(f"\nTotal time: {(time.time() - start_tot) / 60:.2f} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metric_value:.4f}")
    if best_model_state is not None:
        _save_best_state(best_model_state, name_model)
        model.load_state_dict(best_model_state)
        write_results(name_model, metrics_list, start_epoch)
        print(f"Name of the best model: {name_model}")
    return model