bill-jiang's picture
Init
4409449
raw
history blame contribute delete
No virus
1.12 kB
import torch.nn as nn
def remove_padding(tensors, lengths):
    """Trim each item of *tensors* to the matching entry of *lengths*.

    Items and lengths are paired with ``zip``, so any surplus items in the
    longer argument are ignored. Each result is ``item[:length]``; presumably
    the items are padded sequences and the lengths their true sizes (verify
    against callers).

    :param tensors: iterable of sliceable sequences (e.g. tensors)
    :param lengths: iterable of slice end positions, one per item
    :return: a new list containing the trimmed slices, in order
    """
    trimmed = []
    for item, keep in zip(tensors, lengths):
        trimmed.append(item[:keep])
    return trimmed
class AutoParams(nn.Module):
    """``nn.Module`` base class that copies constructor kwargs onto ``self``.

    Subclasses may declare, as class attributes:

    * ``needed_params`` -- an iterable of attribute names that must be passed
      as keyword arguments; each value is stored on the instance.
    * ``optional_params`` -- a mapping from attribute name to default value;
      the keyword argument is stored when given and not ``None``, otherwise
      the default is stored.

    Either attribute may be omitted, in which case it is treated as empty.
    """

    def __init__(self, **kargs):
        """Store required and optional keyword arguments as attributes.

        :param kargs: keyword arguments named after entries of
            ``needed_params`` / ``optional_params``; other keys are ignored.
        :raises ValueError: if a name listed in ``needed_params`` is not
            present in ``kargs``.
        """
        # Initialise nn.Module first so nn.Module / nn.Parameter values can be
        # assigned (and registered) below; assigning them before
        # Module.__init__() raises AttributeError.
        super().__init__()

        # A subclass may define neither attribute; treat a missing one as
        # empty rather than wrapping the loops in a catch-all handler that
        # would also hide the ValueError below.
        for param in getattr(self, "needed_params", ()):
            if param not in kargs:
                raise ValueError(f"{param} is needed.")
            setattr(self, param, kargs[param])

        for param, default in getattr(self, "optional_params", {}).items():
            # A missing key and an explicit ``None`` both fall back to the
            # declared default.
            value = kargs.get(param)
            setattr(self, param, default if value is None else value)
# Taken from the JoeyNMT repository.
def freeze_params(module: nn.Module) -> None:
    """Turn off gradient tracking for every parameter of *module*.

    Each parameter yielded by ``module.parameters()`` (which, by default,
    includes parameters of submodules) has ``requires_grad`` set to
    ``False`` in place, so autograd no longer computes gradients for it.

    :param module: the module whose parameters are frozen
    """
    for weight in module.parameters():
        weight.requires_grad = False