File size: 1,060 Bytes
18dd6ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
    """Backbone wrapper that returns every intermediate activation.

    The backbone is loaded through ``torch.hub`` from the ``efficientnet_repo``
    directory located next to this file.  Its ``global_pool`` and
    ``classifier`` attributes are replaced with ``nn.Identity`` so they pass
    their input through unchanged.
    """

    def __init__(self):
        """Load the backbone (without pretrained weights) and neutralise its head."""
        super(Encoder, self).__init__()
        basemodel_name = 'tf_efficientnet_b5_ap'
        # The placeholder must be '{}'; with '()' str.format() silently
        # ignored the argument and the model name was never shown.
        print('Loading base model {}...'.format(basemodel_name), end='')
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        basemodel = torch.hub.load(
            repo_path, basemodel_name, pretrained=False, source='local'
        )
        print('Done.')

        # Replace the last two layers with pass-through modules.
        print('Removing last two layers (global_pool & classifier).')
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x):
        """Run ``x`` through the backbone, collecting each stage's output.

        Args:
            x: Input passed to the first registered submodule of
                ``self.original_model``.

        Returns:
            A list whose first element is ``x`` itself, followed by the output
            of every direct submodule of ``self.original_model`` in
            registration order.  The submodule named ``'blocks'`` is expanded:
            each of its direct children contributes its own entry instead of
            a single entry for the whole container.
        """
        features = [x]
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Step through the container so each block's output is kept.
                for block in module._modules.values():
                    features.append(block(features[-1]))
            else:
                features.append(module(features[-1]))
        return features
|