# [page residue, not source code] Spaces: Runtime error
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Shariq Farooq Bhat
import torch
import torch.nn as nn
def log_binom(n, k, eps=1e-7):
    """Approximate ``log(n choose k)`` with the leading Stirling terms.

    Evaluates ``n*log(n) - k*log(k) - (n-k)*log(n-k)`` element-wise.
    ``eps`` is added to both ``n`` and ``k`` and once more inside the last
    logarithm, so no logarithm is taken at exactly zero when ``k == 0`` or
    ``k == n``.

    Args:
        n (torch.Tensor): number of trials; must broadcast against ``k``.
        k (torch.Tensor): number of successes. Values outside ``[0, n]``
            make one of the logarithms NaN.
        eps (float, optional): Small number for numerical stability.
            Defaults to 1e-7.

    Returns:
        torch.Tensor: the approximation, with the broadcast shape of
            ``n`` and ``k``.
    """
    n = n + eps
    k = k + eps
    # The offsets cancel in (n - k), hence the extra eps inside the last log.
    return (n * torch.log(n)
            - k * torch.log(k)
            - (n - k) * torch.log(n - k + eps))
class LogBinomial(nn.Module):
    """Map a success probability to a distribution over ``n_classes`` bins.

    For every bin index ``k`` in ``0..K-1`` the module evaluates the log of
    a binomial probability with ``K-1`` trials,
    ``log_binom(K-1, k) + k*log(x) + (K-1-k)*log(1-x)``, divides it by a
    temperature and applies ``act`` along the channel axis.
    """

    def __init__(self, n_classes=256, act=torch.softmax):
        """Compute log binomial distribution for n_classes

        Args:
            n_classes (int, optional): number of output classes. Defaults to 256.
            act (callable, optional): normalisation called as
                ``act(tensor, dim=1)``. Defaults to torch.softmax.
        """
        super().__init__()
        self.K = n_classes
        self.act = act
        # Bin indices 0..K-1 laid out along the channel axis: (1, K, 1, 1).
        self.register_buffer(
            'k_idx', torch.arange(0, n_classes).view(1, -1, 1, 1))
        # Number of trials (K-1) as a floating-point (1, 1, 1, 1) tensor.
        self.register_buffer(
            'K_minus_1',
            torch.tensor([float(self.K - 1)]).view(1, -1, 1, 1))

    def forward(self, x, t=1., eps=1e-4):
        """Compute log binomial distribution for x

        Args:
            x (torch.Tensor - NCHW): probabilities. A 3-dimensional input
                gets a channel axis inserted at position 1.
            t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1..
            eps (float, optional): Small number for numerical stability. Defaults to 1e-4.

        Returns:
            torch.Tensor -NCHW: log binomial distribution logbinomial(p;t)
                after ``act`` has been applied over dim 1.
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)  # make it nchw

        # Clamp both x and 1 - x away from zero so the logs stay finite.
        one_minus_x = torch.clamp(1 - x, eps, 1)
        x = torch.clamp(x, eps, 1)

        y = (log_binom(self.K_minus_1, self.k_idx)
             + self.k_idx * torch.log(x)
             + (self.K - 1 - self.k_idx) * torch.log(one_minus_x))
        return self.act(y / t, dim=1)
class ConditionalLogBinomial(nn.Module):
    """Predict a per-pixel probability and temperature, then apply LogBinomial.

    A two-layer 1x1-convolution head reads the concatenated main and
    condition features and emits four positive channels (Softplus): two are
    normalised into a probability ``p``, two into a temperature ``t``.
    """

    def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
        """Conditional Log Binomial distribution

        Args:
            in_features (int): number of input channels in main feature
            condition_dim (int): number of input channels in condition feature
            n_classes (int, optional): Number of classes. Defaults to 256.
            bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
            p_eps (float, optional): small eps value. Defaults to 1e-4.
            max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
            min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
            act (callable, optional): normalisation forwarded to LogBinomial.
                Defaults to torch.softmax.
        """
        super().__init__()
        self.p_eps = p_eps
        self.max_temp = max_temp
        self.min_temp = min_temp
        self.log_binomial_transform = LogBinomial(n_classes, act=act)

        total_channels = in_features + condition_dim
        bottleneck = total_channels // bottleneck_factor
        self.mlp = nn.Sequential(
            nn.Conv2d(total_channels, bottleneck,
                      kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            # 2 for p linear norm, 2 for t linear norm
            nn.Conv2d(bottleneck, 2 + 2, kernel_size=1, stride=1, padding=0),
            nn.Softplus(),
        )

    def forward(self, x, cond):
        """Forward pass

        Args:
            x (torch.Tensor - NCHW): Main feature
            cond (torch.Tensor - NCHW): condition feature

        Returns:
            torch.Tensor: Output log binomial distribution
        """
        pt = self.mlp(torch.cat((x, cond), dim=1))
        p, t = pt[:, :2, ...], pt[:, 2:, ...]

        # Linear normalisation a / (a + b); the eps offset keeps the
        # denominator strictly positive.
        p = p + self.p_eps
        p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...])

        t = t + self.p_eps
        t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...])
        t = t.unsqueeze(1)  # restore the channel axis
        # Rescale the normalised value into [min_temp, max_temp].
        t = (self.max_temp - self.min_temp) * t + self.min_temp

        return self.log_binomial_transform(p, t)
class ConditionalLogBinomialV2(nn.Module):
    """Predict two ``n_classes``-channel maps from main + condition features.

    A two-layer 1x1-convolution head ending in a Sigmoid emits
    ``2 * n_classes`` channels, which ``forward`` splits into ``prob`` and
    ``shift``. ``log_binomial_transform``, ``p_eps``, ``max_temp`` and
    ``min_temp`` are stored but not used by ``forward``; they are kept so
    the module's attributes and registered buffers stay unchanged.
    """

    def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax):
        """Conditional Log Binomial distribution

        Args:
            in_features (int): number of input channels in main feature
            condition_dim (int): number of input channels in condition feature
            n_classes (int, optional): Number of classes. Defaults to 256.
            bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2.
            p_eps (float, optional): small eps value. Defaults to 1e-4.
            max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50.
            min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7.
            act (callable, optional): normalisation forwarded to LogBinomial.
                Defaults to torch.softmax.
        """
        super().__init__()
        self.p_eps = p_eps
        self.max_temp = max_temp
        self.min_temp = min_temp
        self.log_binomial_transform = LogBinomial(n_classes, act=act)

        total_channels = in_features + condition_dim
        bottleneck = total_channels // bottleneck_factor
        self.mlp = nn.Sequential(
            nn.Conv2d(total_channels, bottleneck,
                      kernel_size=1, stride=1, padding=0),
            nn.GELU(),
            # n_classes channels for prob, n_classes channels for shift
            nn.Conv2d(bottleneck, n_classes * 2,
                      kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )
        self.n_classes = n_classes

    def forward(self, x, cond):
        """Forward pass

        Args:
            x (torch.Tensor - NCHW): Main feature
            cond (torch.Tensor - NCHW): condition feature

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(prob, shift)``, the first
                and the remaining ``n_classes`` channels of the sigmoid
                output.
        """
        pt = self.mlp(torch.cat((x, cond), dim=1))
        prob = pt[:, :self.n_classes, ...]
        shift = pt[:, self.n_classes:, ...]
        return prob, shift