import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules.encoder import Encoder
from .submodules.decoder import Decoder
class NNET(nn.Module):
    """Encoder-decoder network: ``Encoder`` output is fed to ``Decoder``.

    The two sub-networks are exposed as separate parameter groups through
    :meth:`get_1x_lr_params` and :meth:`get_10x_lr_params` so that a caller
    can assign them different learning rates.
    """

    def __init__(self, args):
        """Build the encoder and the decoder.

        Args:
            args: Configuration object forwarded unchanged to ``Decoder``;
                the encoder is constructed without arguments.
        """
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(args)

    def get_1x_lr_params(self):
        """Return an iterator over the encoder's parameters.

        Intended for the lr/10 learning-rate group (per the original
        inline comment); the actual rates are set by the caller's
        optimizer configuration, which is not visible here.
        """
        return self.encoder.parameters()

    def get_10x_lr_params(self):
        """Return an iterator over the decoder's parameters.

        Intended for the full-lr learning-rate group (per the original
        inline comment); the actual rates are set by the caller's
        optimizer configuration, which is not visible here.
        """
        return self.decoder.parameters()

    def forward(self, img, **kwargs):
        """Run ``img`` through the encoder, then the decoder.

        Args:
            img: Input passed to the encoder.
            **kwargs: Extra keyword arguments forwarded verbatim to the
                decoder's forward call.

        Returns:
            Whatever the decoder returns for the encoded features.
        """
        features = self.encoder(img)
        return self.decoder(features, **kwargs)