# samtrack/aot/utils/learning.py
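"""Learning-rate scheduling and parameter-group utilities for AOT training."""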
import math


def adjust_learning_rate(optimizer,
                         base_lr,
                         p,
                         itr,
                         max_itr,
                         restart=1,
                         warm_up_steps=1000,
                         is_cosine_decay=False,
                         min_lr=1e-5,
                         encoder_lr_ratio=1.0,
                         freeze_params=[]):
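    """Set the learning rate of every optimizer param group for iteration itr.

    The schedule runs a linear warm-up from min_lr to base_lr over
    warm_up_steps iterations, then decays back to min_lr, either with a
    cosine curve (is_cosine_decay=True) or polynomially with exponent p.
    With restart > 1 the whole schedule repeats in equal cycles.
    """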
    # With restart > 1, split training into equal cycles and rerun the
    # schedule (with a proportionally shorter warm-up) in each cycle.
    if restart > 1:
        each_max_itr = int(math.ceil(float(max_itr) / restart))
        itr = itr % each_max_itr
        warm_up_steps /= restart
        max_itr = each_max_itr

    if itr < warm_up_steps:
        # Linear warm-up from min_lr to base_lr.
        now_lr = min_lr + (base_lr - min_lr) * itr / warm_up_steps
    else:
        itr = itr - warm_up_steps
        max_itr = max_itr - warm_up_steps
        if is_cosine_decay:
            # Cosine decay from base_lr down to min_lr.
            now_lr = min_lr + (base_lr - min_lr) * (
                math.cos(math.pi * itr / (max_itr + 1)) + 1.) * 0.5
        else:
            # Polynomial decay with exponent p.
            now_lr = min_lr + (base_lr - min_lr) * (1 - itr /
                                                    (max_itr + 1))**p

    for param_group in optimizer.param_groups:
        # Optionally scale the encoder's learning rate relative to the
        # rest of the model, keeping min_lr as the floor.
        if encoder_lr_ratio != 1.0 and "encoder." in param_group["name"]:
            param_group['lr'] = (now_lr - min_lr) * encoder_lr_ratio + min_lr
        else:
            param_group['lr'] = now_lr

        # Disable learning and weight decay for frozen parameter groups.
        for freeze_param in freeze_params:
            if freeze_param in param_group["name"]:
                param_group['lr'] = 0
                param_group['weight_decay'] = 0
                break

    return now_lr
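

# A minimal usage sketch (hypothetical hyperparameters; assumes the
# optimizer was built from get_trainable_params below, so every param
# group carries a "name" entry):
#
#     for step in range(total_steps):
#         adjust_learning_rate(optimizer,
#                              base_lr=2e-4,
#                              p=0.9,
#                              itr=step,
#                              max_itr=total_steps,
#                              warm_up_steps=1000,
#                              is_cosine_decay=True)
#         ...  # forward / backward / optimizer.step()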


def get_trainable_params(model,
                         base_lr,
                         weight_decay,
                         use_frozen_bn=False,
                         exclusive_wd_dict={},
                         no_wd_keys=[]):
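    """Build one optimizer param group per trainable parameter.

    Each group is tagged with the parameter's name so that
    adjust_learning_rate can match groups by substring. Weight decay is
    zeroed for biases, for normalization weights (unless use_frozen_bn
    keeps it on the encoder's frozen BNs), and for any key that matches
    no_wd_keys; per-key values in exclusive_wd_dict take precedence.
    """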
    params = []
    memo = set()
    total_param = 0
    for key, value in model.named_parameters():
        # Count each shared tensor only once.
        if value in memo:
            continue
        total_param += value.numel()
        if not value.requires_grad:
            continue
        memo.add(value)

        # Per-key overrides of the default weight decay take precedence.
        wd = weight_decay
        for exclusive_key in exclusive_wd_dict.keys():
            if exclusive_key in key:
                wd = exclusive_wd_dict[exclusive_key]
                break
        if len(value.shape) == 1:  # 1-D tensors: biases and norm weights
            if 'bias' in key:  # biases never receive weight decay
                wd = 0.
            elif not use_frozen_bn:  # without frozen BN, norm weights get none
                wd = 0.
            elif 'encoder.' not in key:  # with frozen BN, only the encoder's
                wd = 0.                  # frozen BNs keep weight decay
        else:
            for no_wd_key in no_wd_keys:
                if no_wd_key in key:
                    wd = 0.
                    break
        # One param group per tensor so the scheduler can match by name.
        params += [{
            "params": [value],
            "lr": base_lr,
            "weight_decay": wd,
            "name": key
        }]
    print('Total Param: {:.2f}M'.format(total_param / 1e6))
    return params
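

# A minimal end-to-end sketch (hypothetical hyperparameters; AdamW is one
# possible optimizer choice, not mandated by this file):
#
#     params = get_trainable_params(model, base_lr=2e-4, weight_decay=0.07)
#     optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=0.07)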


def freeze_params(module):
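    """Disable gradients for every parameter of the given module."""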
    for p in module.parameters():
        p.requires_grad = False


def calculate_params(state_dict):
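    """Print the number of unique parameters in a state dict, in millions."""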
    memo = set()
    total_param = 0
    for key, value in state_dict.items():
        # Count each shared tensor only once.
        if value in memo:
            continue
        memo.add(value)
        total_param += value.numel()
    print('Total Param: {:.2f}M'.format(total_param / 1e6))