import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules.encoder import Encoder
from .submodules.decoder import Decoder


class NNET(nn.Module):
    """Two-stage network: an ``Encoder`` whose output is fed to a ``Decoder``.

    Args:
        args: configuration object handed unchanged to ``Decoder``; the
            ``Encoder`` is constructed without arguments.
    """

    def __init__(self, args):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(args)

    def get_1x_lr_params(self):
        """Return the encoder's parameters.

        The original author's note marks this group for the reduced rate
        ("lr/10"); the actual learning rate is chosen by whichever caller
        builds the optimizer from this iterator.
        """
        return self.encoder.parameters()

    def get_10x_lr_params(self):
        """Return the decoder's parameters.

        The original author's note marks this group for the full rate ("lr"),
        i.e. 10x the encoder group; the actual value is set by the caller.
        """
        return self.decoder.parameters()

    def forward(self, img, **kwargs):
        """Encode ``img``, then decode the encoder's output.

        Any extra keyword arguments are forwarded to the decoder as-is; the
        decoder's return value is returned unchanged.
        """
        features = self.encoder(img)
        return self.decoder(features, **kwargs)