"""This file is for Inception model borrowed from torch metrics / fidelity. This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. Reference: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py """ # Copyright The Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn.functional as F from torch_fidelity.feature_extractor_base import FeatureExtractorBase from torch_fidelity.helpers import vassert from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2 from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x try: from torchvision.models.utils import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url # Note: Compared shasum and models should be the same. FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' class FeatureExtractorInceptionV3(FeatureExtractorBase): INPUT_IMAGE_SIZE = 299 def __init__( self, name, features_list, **kwargs, ): """ InceptionV3 feature extractor for 2D RGB 24bit images. Args: name (str): Unique name of the feature extractor, must be the same as used in :func:`register_feature_extractor`. features_list (list): A list of the requested feature names, which will be produced for each input. 

        This feature extractor provides the following features:

            - '64'
            - '192'
            - '768'
            - '2048'
            - 'logits_unbiased'
            - 'logits'
        """
        super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
        self.feature_extractor_internal_dtype = torch.float64

        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE_1(1280)
        self.Mixed_7c = InceptionE_2(2048)
        self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = torch.nn.Linear(2048, 1008)

        state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
        # state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
        self.load_state_dict(state_dict)

        self.to(self.feature_extractor_internal_dtype)
        self.requires_grad_(False)
        self.eval()

    def forward(self, x):
        vassert(
            torch.is_tensor(x) and x.dtype == torch.uint8,
            'Expecting image as torch.Tensor with dtype=torch.uint8',
        )
        vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')

        features = {}
        remaining_features = self.features_list.copy()

        x = x.to(self.feature_extractor_internal_dtype)
        # N x 3 x ? x ?
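
        # Resize with a TensorFlow-1.x-compatible bilinear interpolation rather
        # than torch.nn.functional.interpolate, so that activations match the
        # original TF InceptionV3 graph these FID weights were ported from.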
        x = interpolate_bilinear_2d_like_tensorflow1x(
            x,
            size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
            align_corners=False,
        )
        # N x 3 x 299 x 299

        # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device)  # really happening in graph
        x = (x - 128) / 128  # but this gives bit-exact output _of this step_ too
        # N x 3 x 299 x 299

        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.MaxPool_1(x)
        # N x 64 x 73 x 73

        if '64' in remaining_features:
            features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove('64')
            if len(remaining_features) == 0:
                return features

        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.MaxPool_2(x)
        # N x 192 x 35 x 35

        if '192' in remaining_features:
            features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
            remaining_features.remove('192')
            if len(remaining_features) == 0:
                return features

        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17

        if '768' in remaining_features:
            features['768'] = (
                F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
            )
            remaining_features.remove('768')
            if len(remaining_features) == 0:
                return features

        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        x = self.AvgPool(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048

        if '2048' in remaining_features:
            features['2048'] = x
            remaining_features.remove('2048')
            if len(remaining_features) == 0:
                return features

        if 'logits_unbiased' in remaining_features:
            x = x.mm(self.fc.weight.T)
            # N x 1008 (num_classes)
            features['logits_unbiased'] = x
            remaining_features.remove('logits_unbiased')
            if len(remaining_features) == 0:
                return features
            x = x + self.fc.bias.unsqueeze(0)
        else:
            x = self.fc(x)
            # N x 1008 (num_classes)

        features['logits'] = x
        return features

    @staticmethod
    def get_provided_features_list():
        return '64', '192', '768', '2048', 'logits_unbiased', 'logits'

    @staticmethod
    def get_default_feature_layer_for_metric(metric):
        return {
            'isc': 'logits_unbiased',
            'fid': '2048',
            'kid': '2048',
            'prc': '2048',
        }[metric]

    @staticmethod
    def can_be_compiled():
        return True

    @staticmethod
    def get_dummy_input_for_compile():
        return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)


def get_inception_model():
    model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
    return model
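

# A minimal usage sketch (an illustration, not part of the original module):
# build the extractor via get_inception_model() and pull the pooled 2048-dim
# FID features plus the unbiased logits from a batch of uint8 images. Note
# that constructing the model downloads the weights from FID_WEIGHTS_URL.
# The batch and image sizes below are arbitrary; forward() resizes inputs to
# 299x299 internally.
if __name__ == '__main__':
    model = get_inception_model()
    # Random stand-in for real images: B x 3 x H x W, dtype uint8.
    images = (torch.rand(2, 3, 256, 256) * 255).to(torch.uint8)
    with torch.no_grad():
        feats = model(images)
    print(feats['2048'].shape)              # torch.Size([2, 2048])
    print(feats['logits_unbiased'].shape)   # torch.Size([2, 1008])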