'''
 * Copyright (c) 2023 Salesforce, Inc.
 * All rights reserved.
 * SPDX-License-Identifier: Apache License 2.0
 * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
 * By Can Qin
 * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
 * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
 * Modified from MMCV repo: From https://github.com/open-mmlab/mmcv
 * Copyright (c) OpenMMLab. All rights reserved.
'''

import torch
from torch import nn
from torch.nn import functional as F


class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Every spatial position is softly assigned to each codeword (softmax over
    scaled squared distances) and the residuals ``x - codeword`` are summed
    over all positions, weighted by those assignments.

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels: int, num_codes: int) -> None:
        super().__init__()
        self.channels, self.num_codes = channels, num_codes

        # Codewords are drawn uniformly from [-std, std].
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std))
        # [num_codes]; per-codeword smoothing factor drawn from [-1, 0].
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0))

    @staticmethod
    def scaled_l2(x: torch.Tensor, codewords: torch.Tensor,
                  scale: torch.Tensor) -> torch.Tensor:
        """Return the scaled squared L2 distance of every feature to every
        codeword.

        Args:
            x: features of shape (batch_size, num_positions, channels).
            codewords: tensor of shape (num_codes, channels).
            scale: tensor of shape (num_codes,).

        Returns:
            Tensor of shape (batch_size, num_positions, num_codes).
        """
        num_codes, channels = codewords.size()
        # Broadcasting (B, N, 1, C) against (1, 1, K, C) yields (B, N, K, C);
        # an explicit expand of x is not needed.
        residual = x.unsqueeze(2) - codewords.view(1, 1, num_codes, channels)
        return scale.view(1, 1, num_codes) * residual.pow(2).sum(dim=3)

    @staticmethod
    def aggregate(assignment_weights: torch.Tensor, x: torch.Tensor,
                  codewords: torch.Tensor) -> torch.Tensor:
        """Sum the weighted residuals over all positions.

        Args:
            assignment_weights: tensor of shape
                (batch_size, num_positions, num_codes).
            x: features of shape (batch_size, num_positions, channels).
            codewords: tensor of shape (num_codes, channels).

        Returns:
            Tensor of shape (batch_size, num_codes, channels).
        """
        num_codes, channels = codewords.size()
        # (B, N, K, C) via broadcasting, as in scaled_l2.
        residual = x.unsqueeze(2) - codewords.view(1, 1, num_codes, channels)
        return (assignment_weights.unsqueeze(3) * residual).sum(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a feature map.

        Args:
            x: tensor of shape (batch_size, channels, height, width); it does
                not have to be contiguous in memory.

        Returns:
            Tensor of shape (batch_size, num_codes, channels).
        """
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        # reshape (not view) so that non-contiguous inputs, e.g. transposed
        # or sliced feature maps, are accepted instead of raising.
        x = x.reshape(batch_size, self.channels, -1).transpose(1, 2)
        x = x.contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate: [batch_size, num_codes, channels]
        return self.aggregate(assignment_weights, x, self.codewords)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str