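# train.py -- training script for a lightweight OpenPose-style multi-person pose
# estimation network with a MobileNet backbone. Ground truth consists of keypoint
# heatmaps and part affinity fields (PAFs); every network stage is supervised
# with a masked L2 loss on both.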
import argparse
import os

import cv2
import torch
from torch.nn import DataParallel
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.coco import CocoTrainDataset
from datasets.transformations import ConvertKeypoints, Scale, Rotate, CropPad, Flip
from modules.get_parameters import get_parameters_conv, get_parameters_bn, get_parameters_conv_depthwise
from models.with_mobilenet import PoseEstimationWithMobileNet
from modules.loss import l2_loss
from modules.load_state import load_state, load_from_mobilenet
from val import evaluate

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)  # to prevent freezes of DataLoader workers


def train(prepared_train_labels, train_images_folder, num_refinement_stages, base_lr, batch_size, batches_per_iter,
          num_workers, checkpoint_path, weights_only, from_mobilenet, checkpoints_folder, log_after,
          val_labels, val_images_folder, val_output_name, checkpoint_after, val_after):
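    """Train the pose estimation network.

    Supports gradient accumulation over `batches_per_iter` batches, resuming
    from a checkpoint (optionally weights only, or from a MobileNet classifier
    checkpoint), and periodic loss logging, checkpointing and validation.
    """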
    net = PoseEstimationWithMobileNet(num_refinement_stages)

    stride = 8          # network output stride: targets are built at 1/8 input resolution
    sigma = 7           # Gaussian sigma for ground-truth keypoint heatmaps
    path_thickness = 1  # thickness of ground-truth part affinity fields
    dataset = CocoTrainDataset(prepared_train_labels, train_images_folder,
                               stride, sigma, path_thickness,
                               transform=transforms.Compose([
                                   ConvertKeypoints(),
                                   Scale(),
                                   Rotate(pad=(128, 128, 128)),
                                   CropPad(pad=(128, 128, 128)),
                                   Flip()]))
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
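
    # Per-parameter-group settings: biases train at twice the base learning rate
    # and skip weight decay; batch-norm and depthwise-conv weights also skip
    # weight decay; refinement-stage convolutions use 4x/8x the base rate.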
    optimizer = optim.Adam([
        {'params': get_parameters_conv(net.model, 'weight')},
        {'params': get_parameters_conv_depthwise(net.model, 'weight'), 'weight_decay': 0},
        {'params': get_parameters_bn(net.model, 'weight'), 'weight_decay': 0},
        {'params': get_parameters_bn(net.model, 'bias'), 'lr': base_lr * 2, 'weight_decay': 0},
        {'params': get_parameters_conv(net.cpm, 'weight'), 'lr': base_lr},
        {'params': get_parameters_conv(net.cpm, 'bias'), 'lr': base_lr * 2, 'weight_decay': 0},
        {'params': get_parameters_conv_depthwise(net.cpm, 'weight'), 'weight_decay': 0},
        {'params': get_parameters_conv(net.initial_stage, 'weight'), 'lr': base_lr},
        {'params': get_parameters_conv(net.initial_stage, 'bias'), 'lr': base_lr * 2, 'weight_decay': 0},
        {'params': get_parameters_conv(net.refinement_stages, 'weight'), 'lr': base_lr * 4},
        {'params': get_parameters_conv(net.refinement_stages, 'bias'), 'lr': base_lr * 8, 'weight_decay': 0},
        {'params': get_parameters_bn(net.refinement_stages, 'weight'), 'weight_decay': 0},
        {'params': get_parameters_bn(net.refinement_stages, 'bias'), 'lr': base_lr * 2, 'weight_decay': 0},
    ], lr=base_lr, weight_decay=5e-4)
    num_iter = 0
    current_epoch = 0
    drop_after_epoch = [100, 200, 260]
    # drop the learning rate by ~3x at each milestone epoch
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=drop_after_epoch, gamma=0.333)
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path)

        if from_mobilenet:
            # initialize the backbone from a MobileNet classifier checkpoint
            load_from_mobilenet(net, checkpoint)
        else:
            load_state(net, checkpoint)
            if not weights_only:
                # full resume: restore optimizer, scheduler and counters as well
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                num_iter = checkpoint['iter']
                current_epoch = checkpoint['current_epoch']

    net = DataParallel(net).cuda()
    net.train()
    for epochId in range(current_epoch, 280):
        total_losses = [0, 0] * (num_refinement_stages + 1)  # heatmaps loss, paf loss per stage
        batch_per_iter_idx = 0
        for batch_data in train_loader:
            if batch_per_iter_idx == 0:
                optimizer.zero_grad()

            images = batch_data['image'].cuda()
            keypoint_masks = batch_data['keypoint_mask'].cuda()
            paf_masks = batch_data['paf_mask'].cuda()
            keypoint_maps = batch_data['keypoint_maps'].cuda()
            paf_maps = batch_data['paf_maps'].cuda()

            stages_output = net(images)
            losses = []
            for loss_idx in range(len(total_losses) // 2):
                # each stage outputs a (heatmaps, pafs) pair; both get a masked L2 loss
                losses.append(l2_loss(stages_output[loss_idx * 2], keypoint_maps, keypoint_masks, images.shape[0]))
                losses.append(l2_loss(stages_output[loss_idx * 2 + 1], paf_maps, paf_masks, images.shape[0]))
                total_losses[loss_idx * 2] += losses[-2].item() / batches_per_iter
                total_losses[loss_idx * 2 + 1] += losses[-1].item() / batches_per_iter

            loss = losses[0]
            for loss_idx in range(1, len(losses)):
                loss += losses[loss_idx]
            loss /= batches_per_iter  # scale so accumulated gradients match a full batch
            loss.backward()
            batch_per_iter_idx += 1
            if batch_per_iter_idx == batches_per_iter:
                # gradients from `batches_per_iter` batches have been accumulated
                optimizer.step()
                batch_per_iter_idx = 0
                num_iter += 1
            else:
                continue  # keep accumulating; skip the bookkeeping below
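
            # The blocks below run only on completed optimizer steps; the
            # `continue` above skips them while gradients are still accumulating.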
            if num_iter % log_after == 0:
                print('Iter: {}'.format(num_iter))
                for loss_idx in range(len(total_losses) // 2):
                    print('\n'.join(['stage{}_pafs_loss: {}', 'stage{}_heatmaps_loss: {}']).format(
                        loss_idx + 1, total_losses[loss_idx * 2 + 1] / log_after,
                        loss_idx + 1, total_losses[loss_idx * 2] / log_after))
                for loss_idx in range(len(total_losses)):
                    total_losses[loss_idx] = 0
            if num_iter % checkpoint_after == 0:
                snapshot_name = '{}/checkpoint_iter_{}.pth'.format(checkpoints_folder, num_iter)
                torch.save({'state_dict': net.module.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'iter': num_iter,
                            'current_epoch': epochId},
                           snapshot_name)
            if num_iter % val_after == 0:
                print('Validation...')
                evaluate(val_labels, val_output_name, val_images_folder, net)
                net.train()  # evaluate() switches the network to eval mode

        # step the LR schedule once per epoch, after the optimizer updates
        scheduler.step()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prepared-train-labels', type=str, required=True,
                        help='path to the file with prepared annotations')
    parser.add_argument('--train-images-folder', type=str, required=True, help='path to COCO train images folder')
    parser.add_argument('--num-refinement-stages', type=int, default=1, help='number of refinement stages')
    parser.add_argument('--base-lr', type=float, default=4e-5, help='initial learning rate')
    parser.add_argument('--batch-size', type=int, default=80, help='batch size')
    parser.add_argument('--batches-per-iter', type=int, default=1, help='number of batches to accumulate gradient from')
    parser.add_argument('--num-workers', type=int, default=8, help='number of workers')
    parser.add_argument('--checkpoint-path', type=str, required=True,
                        help='path to the checkpoint to continue training from')
    parser.add_argument('--from-mobilenet', action='store_true',
                        help='load weights from mobilenet feature extractor')
    parser.add_argument('--weights-only', action='store_true',
                        help='just initialize layers with pre-trained weights and start training from the beginning')
    parser.add_argument('--experiment-name', type=str, default='default',
                        help='experiment name to create folder for checkpoints')
    parser.add_argument('--log-after', type=int, default=100, help='number of iterations to print train loss')
    parser.add_argument('--val-labels', type=str, required=True, help='path to json with keypoints val labels')
    parser.add_argument('--val-images-folder', type=str, required=True, help='path to COCO val images folder')
    parser.add_argument('--val-output-name', type=str, default='detections.json',
                        help='name of output json file with detected keypoints')
    parser.add_argument('--checkpoint-after', type=int, default=5000,
                        help='number of iterations to save checkpoint')
    parser.add_argument('--val-after', type=int, default=5000,
                        help='number of iterations to run validation')
    args = parser.parse_args()

    checkpoints_folder = '{}_checkpoints'.format(args.experiment_name)
    if not os.path.exists(checkpoints_folder):
        os.makedirs(checkpoints_folder)
    train(args.prepared_train_labels, args.train_images_folder, args.num_refinement_stages, args.base_lr,
          args.batch_size, args.batches_per_iter, args.num_workers, args.checkpoint_path, args.weights_only,
          args.from_mobilenet, checkpoints_folder, args.log_after, args.val_labels, args.val_images_folder,
          args.val_output_name, args.checkpoint_after, args.val_after)
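
# Example invocation (paths and checkpoint name are illustrative):
#   python train.py \
#       --prepared-train-labels prepared_train_annotation.pkl \
#       --train-images-folder <COCO_HOME>/train2017/ \
#       --val-labels val_subset.json \
#       --val-images-folder <COCO_HOME>/val2017/ \
#       --checkpoint-path mobilenet_weights.pth --from-mobilenet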