import os

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Backbone wrapper that returns every intermediate activation.

    The backbone is built through ``torch.hub.load`` from the local
    ``efficientnet_repo`` directory that sits next to this file. Its
    ``global_pool`` and ``classifier`` attributes are replaced with
    ``nn.Identity`` so they pass their input through unchanged.
    """

    def __init__(self):
        """Build the backbone and neutralise its pooling/classifier head."""
        super().__init__()

        basemodel_name = 'tf_efficientnet_b5_ap'

        # The original message used '()' instead of a '{}' placeholder, so
        # the model name was silently dropped from the output.
        print('Loading base model ({})...'.format(basemodel_name), end='')
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        # pretrained=False: no pretrained weights are requested here.
        basemodel = torch.hub.load(
            repo_path, basemodel_name, pretrained=False, source='local'
        )
        print('Done.')

        # Replace the last two layers with pass-through modules rather than
        # deleting them, so the attributes still exist on the backbone.
        print('Removing last two layers (global_pool & classifier).')
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x: torch.Tensor) -> list:
        """Run ``x`` through the backbone, collecting every activation.

        Each direct child of ``self.original_model`` is applied, in
        registration order, to the most recent entry of the feature list.
        The child registered under the name ``'blocks'`` is expanded: each
        of its own direct children is applied separately and contributes
        its own entry.

        Args:
            x: Input passed to the first child module.

        Returns:
            A list whose first element is ``x`` itself, followed by one
            output per applied module, in application order.
        """
        features = [x]
        # NOTE(review): `_modules` is iterated directly (as in the original)
        # to keep list positions exactly stable; presumably a consumer
        # indexes into this list by position -- confirm before changing.
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                for block in module._modules.values():
                    features.append(block(features[-1]))
            else:
                features.append(module(features[-1]))
        return features