import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class Encoder(nn.Module):
    """EfficientNet backbone wrapper that returns every intermediate output.

    The backbone is loaded with ``torch.hub.load`` from the local
    ``efficientnet_repo`` directory located next to this file. Its
    ``global_pool`` and ``classifier`` submodules are replaced by
    ``nn.Identity``. ``forward`` returns a list of intermediate outputs
    rather than a single tensor (presumably consumed as multi-scale skip
    features by a decoder — confirm against callers).
    """

    def __init__(self, basemodel_name='tf_efficientnet_b5_ap', pretrained=False):
        """Load the backbone and neutralize its pooling/classification head.

        Args:
            basemodel_name: Entry-point name passed to ``torch.hub.load``.
                It must be defined in the hubconf of the local
                ``efficientnet_repo`` directory. Defaults to
                ``'tf_efficientnet_b5_ap'``.
            pretrained: Forwarded to the hub entry point. Defaults to
                ``False``, matching the previous hard-coded behavior.
        """
        super().__init__()
        print(f'Loading base model ({basemodel_name})...', end='')
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        basemodel = torch.hub.load(
            repo_path, basemodel_name, pretrained=pretrained, source='local'
        )
        print('Done.')

        # Remove the last two layers (global pooling and classifier) by
        # swapping them for identities. The submodules stay registered, so
        # `forward` still visits them, but they leave their input unchanged.
        print('Removing last two layers (global_pool & classifier).')
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()
        self.original_model = basemodel

    def forward(self, x):
        """Run the backbone and collect the output of every stage.

        Args:
            x: Input fed to the first child of the backbone (presumably an
                image batch of shape (N, C, H, W) — confirm).

        Returns:
            list: ``[x]`` followed by the output of each top-level child of
            the backbone, in registration order. Each child consumes the
            previous entry. The child named ``'blocks'`` is expanded, so each
            of its own sub-modules contributes a separate entry.
        """
        features = [x]
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Expand the stage container so every stage output is kept.
                for stage in module._modules.values():
                    features.append(stage(features[-1]))
            else:
                features.append(module(features[-1]))
        return features