import math


def adjust_learning_rate(optimizer,
                         base_lr,
                         p,
                         itr,
                         max_itr,
                         restart=1,
                         warm_up_steps=1000,
                         is_cosine_decay=False,
                         min_lr=1e-5,
                         encoder_lr_ratio=1.0,
                         freeze_params=None):
    """Set the learning rate of every param group for the current iteration.

    Linearly warms up from ``min_lr`` to ``base_lr`` over ``warm_up_steps``
    iterations, then decays back toward ``min_lr`` with either cosine
    annealing (``is_cosine_decay=True``) or polynomial decay of power ``p``.
    ``restart > 1`` splits training into that many warm-restart cycles.
    Groups whose names contain "encoder." are scaled by ``encoder_lr_ratio``;
    groups matching an entry of ``freeze_params`` get zero learning rate and
    weight decay. Returns the scheduled base learning rate.
    """
    if freeze_params is None:  # avoid a shared mutable default argument
        freeze_params = []

    # With warm restarts, schedule within the current cycle only.
    if restart > 1:
        each_max_itr = int(math.ceil(float(max_itr) / restart))
        itr = itr % each_max_itr
        warm_up_steps /= restart
        max_itr = each_max_itr

    if itr < warm_up_steps:
        # Linear warm-up: min_lr -> base_lr.
        now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps
    else:
        # Decay phase: measure progress relative to the end of warm-up.
        itr = itr - warm_up_steps
        max_itr = max_itr - warm_up_steps
        if is_cosine_decay:
            # Cosine annealing: base_lr -> min_lr over the remaining steps.
            now_lr = min_lr + (base_lr - min_lr) * (
                math.cos(math.pi * itr / (max_itr + 1)) + 1.) * 0.5
        else:
            # Polynomial ("poly") decay with power p.
            now_lr = min_lr + (base_lr - min_lr) * (1 - itr / (max_itr + 1))**p

    for param_group in optimizer.param_groups:
        if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]:
            # Encoder parameters train with a scaled learning rate.
            param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr
        else:
            param_group['lr'] = now_lr

        # Zero out learning rate and weight decay for frozen groups.
        for freeze_param in freeze_params:
            if freeze_param in param_group["name"]:
                param_group['lr'] = 0
                param_group['weight_decay'] = 0
                break

    return now_lr
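
# A minimal training-loop sketch (hedged): `model` and `max_itr` are
# hypothetical, and torch.optim.SGD stands in for any optimizer. Note that
# adjust_learning_rate reads param_group["name"], so the optimizer must be
# built from the named groups produced by get_trainable_params below:
#
#   params = get_trainable_params(model, base_lr=0.01, weight_decay=1e-4)
#   optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)
#   for itr in range(max_itr):
#       now_lr = adjust_learning_rate(optimizer, base_lr=0.01, p=0.9,
#                                     itr=itr, max_itr=max_itr,
#                                     is_cosine_decay=True)
#       # ... forward / backward / optimizer.step() as usual
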
def get_trainable_params(model,
                         base_lr,
                         weight_decay,
                         use_frozen_bn=False,
                         exclusive_wd_dict=None,
                         no_wd_keys=None):
    """Build one named optimizer param group per trainable parameter.

    Each group records the parameter's name so that ``adjust_learning_rate``
    can match encoder and frozen groups by name. Weight decay is overridden
    per parameter: ``exclusive_wd_dict`` maps name substrings to explicit
    decay values; biases and (unless ``use_frozen_bn`` keeps encoder norm
    layers decayed) other 1-D parameters get no decay; and multi-dimensional
    parameters whose names contain an entry of ``no_wd_keys`` get no decay.
    """
    if exclusive_wd_dict is None:  # avoid shared mutable default arguments
        exclusive_wd_dict = {}
    if no_wd_keys is None:
        no_wd_keys = []

    params = []
    memo = set()
    total_param = 0
    for key, value in model.named_parameters():
        # Skip parameters shared between modules so each is counted once.
        if value in memo:
            continue
        memo.add(value)
        total_param += value.numel()
        if not value.requires_grad:
            continue

        # Per-parameter weight decay; explicit overrides take precedence.
        wd = weight_decay
        for exclusive_key in exclusive_wd_dict:
            if exclusive_key in key:
                wd = exclusive_wd_dict[exclusive_key]
                break
        if len(value.shape) == 1:
            # 1-D parameters are biases or norm-layer affine weights.
            if 'bias' in key:
                wd = 0.
            elif not use_frozen_bn:
                wd = 0.
            elif 'encoder.' not in key:
                wd = 0.
        else:
            for no_wd_key in no_wd_keys:
                if no_wd_key in key:
                    wd = 0.
                    break
        params += [{
            "params": [value],
            "lr": base_lr,
            "weight_decay": wd,
            "name": key
        }]

    print('Total Param: {:.2f}M'.format(total_param / 1e6))
    return params
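
# A minimal sketch of the weight-decay overrides (hedged): the key strings
# 'head.' and 'pos_embed' are illustrative, not taken from a real model.
# Parameters under 'head.' get an explicit decay value, while any parameter
# whose name contains 'pos_embed' gets none:
#
#   params = get_trainable_params(model,
#                                 base_lr=0.01,
#                                 weight_decay=1e-4,
#                                 exclusive_wd_dict={'head.': 5e-4},
#                                 no_wd_keys=['pos_embed'])
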
def freeze_params(module):
    """Freeze every parameter of ``module`` so no gradients are computed."""
    for p in module.parameters():
        p.requires_grad = False
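
# Example (hedged; `model.encoder` is a hypothetical submodule): freezing a
# backbone before building the optimizer, so its parameters are excluded by
# the requires_grad check in get_trainable_params:
#
#   freeze_params(model.encoder)
#   params = get_trainable_params(model, base_lr=0.01, weight_decay=1e-4)
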
def calculate_params(state_dict):
    """Print the total number of parameters stored in ``state_dict``."""
    memo = set()
    total_param = 0
    for value in state_dict.values():
        # Skip tensors that appear under more than one key.
        if value in memo:
            continue
        memo.add(value)
        total_param += value.numel()
    print('Total Param: {:.2f}M'.format(total_param / 1e6))
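
# Example (hedged): counting the parameters of a saved checkpoint. The file
# path is hypothetical and the checkpoint is assumed to store a plain
# state_dict rather than a wrapper dict:
#
#   import torch
#   calculate_params(torch.load('checkpoint.pth', map_location='cpu'))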