# diffai / components.py
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import multinomial, categorical
import torch.optim as optim
import math
try:
from . import helpers as h
from . import ai
from . import scheduling as S
except ImportError:
import helpers as h
import ai
import scheduling as S
import abc
from torch.nn.modules.conv import _ConvNd
from enum import Enum
class InferModule(nn.Module):
def __init__(self, *args, normal = False, ibp_init = False, **kwargs):
self.args = args
self.kwargs = kwargs
self.infered = False
self.normal = normal
self.ibp_init = ibp_init
def infer(self, in_shape, global_args = None):
""" this is really actually stateful. """
if self.infered:
return self
self.infered = True
super(InferModule, self).__init__()
self.inShape = list(in_shape)
out_shape = self.init(list(in_shape), *self.args, global_args = global_args, **self.kwargs)
if out_shape is None:
raise ValueError("init should set the out_shape")
self.outShape = list(out_shape)
self.reset_parameters()
return self
def reset_parameters(self):
if not hasattr(self,'weight') or self.weight is None:
return
n = h.product(self.weight.size()) / self.outShape[0]
stdv = 1 / math.sqrt(n)
if self.ibp_init:
torch.nn.init.orthogonal_(self.weight.data)
elif self.normal:
self.weight.data.normal_(0, stdv)
self.weight.data.clamp_(-1, 1)
else:
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
if self.ibp_init:
self.bias.data.zero_()
elif self.normal:
self.bias.data.normal_(0, stdv)
self.bias.data.clamp_(-1, 1)
else:
self.bias.data.uniform_(-stdv, stdv)
def clip_norm(self):
if not hasattr(self, "weight"):
return
if not hasattr(self,"weight_g"):
if torch.__version__[0] == "0":
nn.utils.weight_norm(self, dim=None)
else:
nn.utils.weight_norm(self)
self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm)
if torch.__version__[0] != "0":
self.weight_v.data.clamp_(-h.max_c_for_norm * 10000,h.max_c_for_norm * 10000)
if hasattr(self, "bias"):
self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)
def regularize(self, p):
reg = 0
if torch.__version__[0] == "0":
for param in self.parameters():
reg += param.norm(p)
else:
if hasattr(self, "weight_g"):
reg += self.weight_g.norm().sum()
reg += self.weight_v.norm().sum()
elif hasattr(self, "weight"):
reg += self.weight.norm().sum()
if hasattr(self, "bias"):
reg += self.bias.view(-1).norm(p=p).sum()
return reg
def remove_norm(self):
if hasattr(self,"weight_g"):
torch.nn.utils.remove_weight_norm(self)
def showNet(self, t = ""):
print(t + self.__class__.__name__)
def printNet(self, f):
print(self.__class__.__name__, file=f)
@abc.abstractmethod
def forward(self, *args, **kargs):
pass
def __call__(self, *args, onyx=False, **kargs):
if onyx:
return self.forward(*args, onyx=onyx, **kargs)
else:
return super(InferModule, self).__call__(*args, **kargs)
@abc.abstractmethod
def neuronCount(self):
pass
def depth(self):
return 0
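# Note: InferModule defers nn.Module.__init__ and parameter creation to
# infer(in_shape): subclasses are constructed with hyperparameters only and
# become fully initialized modules once the input shape is known (see the
# usage sketch after the Linear class below).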
def getShapeConv(in_shape, conv_shape, stride = 1, padding = 0):
inChan, inH, inW = in_shape
outChan, kH, kW = conv_shape[:3]
outH = 1 + int((2 * padding + inH - kH) / stride)
outW = 1 + int((2 * padding + inW - kW) / stride)
return (outChan, outH, outW)
def getShapeConvTranspose(in_shape, conv_shape, stride = 1, padding = 0, out_padding=0):
inChan, inH, inW = in_shape
outChan, kH, kW = conv_shape[:3]
outH = (inH - 1 ) * stride - 2 * padding + kH + out_padding
outW = (inW - 1 ) * stride - 2 * padding + kW + out_padding
return (outChan, outH, outW)
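# Hedged sanity check (illustrative only, not part of the original file): for a
# 3x32x32 input, a 16-filter 3x3 convolution with stride 2 and padding 1 gives
# a 16x16x16 output, and a matching transposed convolution with out_padding 1
# maps that back to a 3x32x32 shape.
def _example_conv_shapes():
    assert getShapeConv((3, 32, 32), (16, 3, 3), stride=2, padding=1) == (16, 16, 16)
    assert getShapeConvTranspose((16, 16, 16), (3, 3, 3), stride=2, padding=1, out_padding=1) == (3, 32, 32)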
class Linear(InferModule):
def init(self, in_shape, out_shape, **kargs):
self.in_neurons = h.product(in_shape)
if isinstance(out_shape, int):
out_shape = [out_shape]
self.out_neurons = h.product(out_shape)
self.weight = torch.nn.Parameter(torch.Tensor(self.in_neurons, self.out_neurons))
self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons))
return out_shape
def forward(self, x, **kargs):
s = x.size()
x = x.view(s[0], h.product(s[1:]))
return (x.matmul(self.weight) + self.bias).view(s[0], *self.outShape)
def neuronCount(self):
return 0
def showNet(self, t = ""):
print(t + "Linear out=" + str(self.out_neurons))
def printNet(self, f):
print("Linear(" + str(self.out_neurons) + ")" )
print(h.printListsNumpy(list(self.weight.transpose(1,0).data)), file= f)
print(h.printNumpy(self.bias), file= f)
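# Hedged usage sketch (illustrative only; assumes the helpers module's product()
# accepts shape lists and torch.Size slices, as it is used above): shapes are
# fixed lazily by infer(), after which the module acts like a plain nn.Module.
def _example_linear():
    lin = Linear(10).infer([3, 8, 8])       # 3*8*8 = 192 flattened input features
    out = lin(torch.rand(5, 3, 8, 8))       # forward flattens to (5, 192), then matmul
    return out.shape                        # expected: torch.Size([5, 10])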
class Activation(InferModule):
def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
self.activation = [ "ReLU","Sigmoid", "Tanh", "Softplus", "ELU", "SELU"].index(activation)
self.activation_name = activation
return in_shape
def regularize(self, p):
return 0
def forward(self, x, **kargs):
return [lambda x:x.relu(), lambda x:x.sigmoid(), lambda x:x.tanh(), lambda x:x.softplus(), lambda x:x.elu(), lambda x:x.selu()][self.activation](x)
def neuronCount(self):
return h.product(self.outShape)
def depth(self):
return 1
def showNet(self, t = ""):
print(t + self.activation_name)
def printNet(self, f):
pass
class ReLU(Activation):
pass
def activation(*args, batch_norm = False, **kargs):
a = Activation(*args, **kargs)
return Seq(BatchNorm(), a) if batch_norm else a
class Identity(InferModule): # for feigning model equivalence when removing an op
def init(self, in_shape, global_args = None, **kargs):
return in_shape
def forward(self, x, **kargs):
return x
def neuronCount(self):
return 0
def printNet(self, f):
pass
def regularize(self, p):
return 0
def showNet(self, *args, **kargs):
pass
class Dropout(InferModule):
def init(self, in_shape, p=0.5, use_2d = False, alpha_dropout = False, **kargs):
self.p = S.Const.initConst(p)
self.use_2d = use_2d
self.alpha_dropout = alpha_dropout
return in_shape
def forward(self, x, time = 0, **kargs):
if self.training:
with torch.no_grad():
p = self.p.getVal(time = time)
mask = (F.dropout2d if self.use_2d else F.dropout)(h.ones(x.size()),p=p, training=True)
if self.alpha_dropout:
with torch.no_grad():
keep_prob = 1 - p
alpha = -1.7580993408473766 # = -(SELU lambda * SELU alpha), the saturation value dropped units take in alpha dropout
a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5)
b = -a * alpha * (1 - keep_prob)
mask = mask * a
return x * mask + b
else:
return x * mask
else:
return x
def neuronCount(self):
return 0
def showNet(self, t = ""):
print(t + "Dropout p=" + str(self.p))
def printNet(self, f):
print("Dropout(" + str(self.p) + ")" )
class PrintActivation(Identity):
def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
self.activation = activation
return in_shape
def printNet(self, f):
print(self.activation, file = f)
class PrintReLU(PrintActivation):
pass
class Conv2D(InferModule):
def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, activation = "ReLU", **kargs):
self.prev = in_shape
self.in_channels = in_shape[0]
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.activation = activation
self.use_softplus = h.default(global_args, 'use_softplus', False)
weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size)
self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0]))
else:
self.bias = None # h.zeros(weights_shape[0])
outshape = getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding)
return outshape
def forward(self, input, **kargs):
return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding )
def printNet(self, f): # the printed description is only complete if we forward with stride=1
print("Conv2D", file = f)
sz = list(self.prev)
print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ), file = f)
print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)
def showNet(self, t = ""):
sz = list(self.prev)
print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ))
def neuronCount(self):
return 0
class ConvTranspose2D(InferModule):
def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, out_padding=0, activation = "ReLU", **kargs):
self.prev = in_shape
self.in_channels = in_shape[0]
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.out_padding = out_padding
self.activation = activation
self.use_softplus = h.default(global_args, 'use_softplus', False)
weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size)
self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(self.out_channels)) # conv_transpose2d expects an out_channels-sized bias (weights_shape[0] is in_channels here)
else:
self.bias = None # h.zeros(weights_shape[0])
outshape = getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size), stride, padding, out_padding)
return outshape
def forward(self, input, **kargs):
return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding, output_padding=self.out_padding)
def printNet(self, f): # the printed description is only complete if we forward with stride=1
print("ConvTranspose2D", file = f)
print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(self.out_channels, list(self.kernel_size), list(self.prev) ), file = f)
print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
print(h.printNumpy(self.bias), file= f)
def neuronCount(self):
return 0
class MaxPool2D(InferModule):
def init(self, in_shape, kernel_size, stride = None, **kargs):
self.prev = in_shape
self.kernel_size = kernel_size
self.stride = kernel_size if stride is None else stride
return getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride)
def forward(self, x, **kargs):
return x.max_pool2d(self.kernel_size, self.stride)
def printNet(self, f):
print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format(list(self.stride), list(self.shape[2:]), list(self.prev[1:]+self.prev[:1]) ), file = f)
def neuronCount(self):
return h.product(self.outShape)
class AvgPool2D(InferModule):
def init(self, in_shape, kernel_size, stride = None, **kargs):
self.prev = in_shape
self.kernel_size = kernel_size
self.stride = kernel_size if stride is None else stride
out_size = getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding = 1)
return out_size
def forward(self, x, **kargs):
if h.product(x.size()[2:]) == 1:
return x
return x.avg_pool2d(kernel_size = self.kernel_size, stride = self.stride, padding = 1)
def printNet(self, f):
print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format(list(self.stride), list(self.shape[2:]), list(self.prev[1:]+self.prev[:1]) ), file = f)
def neuronCount(self):
return h.product(self.outShape)
class AdaptiveAvgPool2D(InferModule):
def init(self, in_shape, out_shape, **kargs):
self.prev = in_shape
self.out_shape = list(out_shape)
return [in_shape[0]] + self.out_shape
def forward(self, x, **kargs):
return x.adaptive_avg_pool2d(self.out_shape)
def printNet(self, f):
print("AdaptiveAvgPool2D out_Shape={} input_shape={}".format(list(self.out_shape), list(self.prev[1:]+self.prev[:1]) ), file = f)
def neuronCount(self):
return h.product(self.outShape)
class Normalize(InferModule):
def init(self, in_shape, mean, std, **kargs):
self.mean_v = mean
self.std_v = std
self.mean = h.dten(mean)
self.std = 1 / h.dten(std)
return in_shape
def forward(self, x, **kargs):
mean_ex = self.mean.view(self.mean.shape[0],1,1).expand(*x.size()[1:])
std_ex = self.std.view(self.std.shape[0],1,1).expand(*x.size()[1:])
return (x - mean_ex) * std_ex
def neuronCount(self):
return 0
def printNet(self, f):
print("Normalize mean={} std={}".format(self.mean_v, self.std_v), file = f)
def showNet(self, t = ""):
print(t + "Normalize mean={} std={}".format(self.mean_v, self.std_v))
class Flatten(InferModule):
def init(self, in_shape, **kargs):
return h.product(in_shape)
def forward(self, x, **kargs):
s = x.size()
return x.view(s[0], h.product(s[1:]))
def neuronCount(self):
return 0
class BatchNorm(InferModule):
def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs):
self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape))
self.beta = torch.nn.Parameter(torch.Tensor(*in_shape))
self.eps = eps
self.track_running_stats = track_running_stats
self.momentum = momentum
self.running_mean = None
self.running_var = None
self.num_batches_tracked = 0
return in_shape
def reset_parameters(self):
self.gamma.data.fill_(1)
self.beta.data.zero_()
def forward(self, x, **kargs):
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
new_mean = x.vanillaTensorPart().detach().mean(dim=0)
new_var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False)
if torch.isnan(new_var * 0).any():
return x
if self.training:
self.running_mean = (1 - exponential_average_factor) * self.running_mean + exponential_average_factor * new_mean if self.running_mean is not None else new_mean
if self.running_var is None:
self.running_var = new_var
else:
q = (1 - exponential_average_factor) * self.running_var
r = exponential_average_factor * new_var
self.running_var = q + r
if self.track_running_stats and self.running_mean is not None and self.running_var is not None:
new_mean = self.running_mean
new_var = self.running_var
diver = 1 / (new_var + self.eps).sqrt()
if torch.isnan(diver).any():
print("Really shouldn't happen ever")
return x
else:
out = (x - new_mean) * diver * self.gamma + self.beta
return out
def neuronCount(self):
return 0
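# Note: unlike nn.BatchNorm2d, this BatchNorm keeps a gamma/beta parameter per
# element of the input shape and computes mean/variance over the batch
# dimension only (dim=0), so spatial positions are not pooled together.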
class Unflatten2d(InferModule):
def init(self, in_shape, w, **kargs):
self.w = w
self.outChan = int(h.product(in_shape) / (w * w))
return (self.outChan, self.w, self.w)
def forward(self, x, **kargs):
s = x.size()
return x.view(s[0], self.outChan, self.w, self.w)
def neuronCount(self):
return 0
class View(InferModule):
def init(self, in_shape, out_shape, **kargs):
assert(h.product(in_shape) == h.product(out_shape))
return out_shape
def forward(self, x, **kargs):
s = x.size()
return x.view(s[0], *self.outShape)
def neuronCount(self):
return 0
class Seq(InferModule):
def init(self, in_shape, *layers, **kargs):
self.layers = layers
self.net = nn.Sequential(*layers)
self.prev = in_shape
for s in layers:
in_shape = s.infer(in_shape, **kargs).outShape
return in_shape
def forward(self, x, **kargs):
for l in self.layers:
x = l(x, **kargs)
return x
def clip_norm(self):
for l in self.layers:
l.clip_norm()
def regularize(self, p):
return sum(n.regularize(p) for n in self.layers)
def remove_norm(self):
for l in self.layers:
l.remove_norm()
def printNet(self, f):
for l in self.layers:
l.printNet(f)
def showNet(self, *args, **kargs):
for l in self.layers:
l.showNet(*args, **kargs)
def neuronCount(self):
return sum([l.neuronCount() for l in self.layers ])
def depth(self):
return sum([l.depth() for l in self.layers ])
def FFNN(layers, last_lin = False, last_zono = False, **kargs):
starts = layers
ends = []
if last_lin:
ends = ([CorrelateAll(only_train=False)] if last_zono else []) + [PrintActivation(activation = "Affine"), Linear(layers[-1],**kargs)]
starts = layers[:-1]
return Seq(*([ Seq(PrintActivation(**kargs), Linear(s, **kargs), activation(**kargs)) for s in starts] + ends))
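# Hedged example (illustrative sizes, not a configuration from this repo): a
# fully connected network for flattened 28x28 inputs with a linear last layer.
def _example_ffnn():
    net = FFNN([100, 100, 10], last_lin=True)
    return net.infer([1, 28, 28]).outShape  # expected: [10]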
def Conv(*args, **kargs):
return Seq(Conv2D(*args, **kargs), activation(**kargs))
def ConvTranspose(*args, **kargs):
return Seq(ConvTranspose2D(*args, **kargs), activation(**kargs))
MP = MaxPool2D
def LeNet(conv_layers, ly = [], bias = True, normal=False, **kargs):
def transfer(tp):
if isinstance(tp, InferModule):
return tp
if isinstance(tp[0], str):
return MaxPool2D(*tp[1:])
return Conv(out_channels = tp[0], kernel_size = tp[1], stride = tp[-1] if len(tp) == 4 else 1, bias=bias, normal=normal, **kargs)
conv = [transfer(s) for s in conv_layers]
return Seq(*conv, FFNN(ly, **kargs, bias=bias)) if len(ly) > 0 else Seq(*conv)
def InvLeNet(ly, w, conv_layers, bias = True, normal=False, **kargs):
def transfer(tp):
return ConvTranspose(out_channels = tp[0], kernel_size = tp[1], stride = tp[2], padding = tp[3], out_padding = tp[4], bias=False, normal=normal)
return Seq(FFNN(ly, bias=bias), Unflatten2d(w), *[transfer(s) for s in conv_layers])
class FromByteImg(InferModule):
def init(self, in_shape, **kargs):
return in_shape
def forward(self, x, **kargs):
return x.to_dtype()/ 256.
def neuronCount(self):
return 0
class Skip(InferModule):
def init(self, in_shape, net1, net2, **kargs):
self.net1 = net1.infer(in_shape, **kargs)
self.net2 = net2.infer(in_shape, **kargs)
assert(net1.outShape[1:] == net2.outShape[1:])
return [ net1.outShape[0] + net2.outShape[0] ] + net1.outShape[1:]
def forward(self, x, **kargs):
r1 = self.net1(x, **kargs)
r2 = self.net2(x, **kargs)
return r1.cat(r2, dim=1)
def regularize(self, p):
return self.net1.regularize(p) + self.net2.regularize(p)
def clip_norm(self):
self.net1.clip_norm()
self.net2.clip_norm()
def remove_norm(self):
self.net1.remove_norm()
self.net2.remove_norm()
def neuronCount(self):
return self.net1.neuronCount() + self.net2.neuronCount()
def printNet(self, f):
print("SkipNet1", file=f)
self.net1.printNet(f)
print("SkipNet2", file=f)
self.net2.printNet(f)
print("SkipCat dim=1", file=f)
def showNet(self, t = ""):
print(t+"SkipNet1")
self.net1.showNet(" "+t)
print(t+"SkipNet2")
self.net2.showNet(" "+t)
print(t+"SkipCat dim=1")
class ParSum(InferModule):
def init(self, in_shape, net1, net2, **kargs):
self.net1 = net1.infer(in_shape, **kargs)
self.net2 = net2.infer(in_shape, **kargs)
assert(net1.outShape == net2.outShape)
return net1.outShape
def forward(self, x, **kargs):
r1 = self.net1(x, **kargs)
r2 = self.net2(x, **kargs)
return x.addPar(r1,r2)
def clip_norm(self):
self.net1.clip_norm()
self.net2.clip_norm()
def remove_norm(self):
self.net1.remove_norm()
self.net2.remove_norm()
def neuronCount(self):
return self.net1.neuronCount() + self.net2.neuronCount()
def depth(self):
return max(self.net1.depth(), self.net2.depth())
def printNet(self, f):
print("ParNet1", file=f)
self.net1.printNet(f)
print("ParNet2", file=f)
self.net2.printNet(f)
print("ParCat dim=1", file=f)
def showNet(self, t = ""):
print(t + "ParNet1")
self.net1.showNet(" "+t)
print(t + "ParNet2")
self.net2.showNet(" "+t)
print(t + "ParSum")
class ToZono(Identity):
def init(self, in_shape, customRelu = None, only_train = False, **kargs):
self.customRelu = customRelu
self.only_train = only_train
return in_shape
def forward(self, x, **kargs):
return self.abstract_forward(x, **kargs) if self.training or not self.only_train else x
def abstract_forward(self, x, **kargs):
return x.abstractApplyLeaf('hybrid_to_zono', customRelu = self.customRelu)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train))
class CorrelateAll(ToZono):
def abstract_forward(self, x, **kargs):
return x.abstractApplyLeaf('hybrid_to_zono',correlate=True, customRelu = self.customRelu)
class ToHZono(ToZono):
def abstract_forward(self, x, **kargs):
return x.abstractApplyLeaf('zono_to_hybrid',customRelu = self.customRelu)
class Concretize(ToZono):
def init(self, in_shape, only_train = True, **kargs):
self.only_train = only_train
return in_shape
def abstract_forward(self, x, **kargs):
return x.abstractApplyLeaf('concretize')
# stochastic correlation
class CorrRand(Concretize):
def init(self, in_shape, num_correlate, only_train = True, **kargs):
self.only_train = only_train
self.num_correlate = num_correlate
return in_shape
def abstract_forward(self, x):
return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " num_correlate="+ str(self.num_correlate))
class CorrMaxK(CorrRand):
def abstract_forward(self, x):
return x.abstractApplyLeaf("correlateMaxK", self.num_correlate)
class CorrMaxPool2D(Concretize):
def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.head_beta, **kargs):
self.only_train = only_train
self.kernel_size = kernel_size
self.max_type = max_type
return in_shape
def abstract_forward(self, x):
return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +str(self.max_type))
class CorrMaxPool3D(Concretize):
def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.only_beta, **kargs):
self.only_train = only_train
self.kernel_size = kernel_size
self.max_type = max_type
return in_shape
def abstract_forward(self, x):
return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type, max_pool = F.max_pool3d)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" + str(self.max_type))
class CorrFix(Concretize):
def init(self,in_shape, k, only_train = True, **kargs):
self.k = k
self.only_train = only_train
return in_shape
def abstract_forward(self, x):
sz = x.size()
"""
# for more control in the future
indxs_1 = torch.arange(start = 0, end = sz[1], step = math.ceil(sz[1] / self.dims[1]) )
indxs_2 = torch.arange(start = 0, end = sz[2], step = math.ceil(sz[2] / self.dims[2]) )
indxs_3 = torch.arange(start = 0, end = sz[3], step = math.ceil(sz[3] / self.dims[3]) )
indxs = torch.stack(torch.meshgrid((indxs_1,indxs_2,indxs_3)), dim=3).view(-1,3)
"""
szm = h.product(sz[1:])
indxs = torch.arange(start = 0, end = szm, step = math.ceil(szm / self.k))
indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0])
return x.abstractApplyLeaf("correlate", indxs)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.k))
class DecorrRand(Concretize):
def init(self, in_shape, num_decorrelate, only_train = True, **kargs):
self.only_train = only_train
self.num_decorrelate = num_decorrelate
return in_shape
def abstract_forward(self, x):
return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate)
class DecorrMin(Concretize):
def init(self, in_shape, num_decorrelate, only_train = True, num_to_keep = False, **kargs):
self.only_train = only_train
self.num_decorrelate = num_decorrelate
self.num_to_keep = num_to_keep
return in_shape
def abstract_forward(self, x):
return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep = self.num_to_keep)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep) )
class DeepLoss(ToZono):
def init(self, in_shape, bw = 0.01, act = F.relu, **kargs): # weight must be between 0 and 1
self.only_train = True
self.bw = S.Const.initConst(bw)
self.act = act
return in_shape
def abstract_forward(self, x, **kargs):
if x.isPoint():
return x
return ai.TaggedDomain(x, self.MLoss(self, x))
class MLoss():
def __init__(self, obj, x):
self.obj = obj
self.x = x
def loss(self, a, *args, lr = 1, time = 0, **kargs):
bw = self.obj.bw.getVal(time = time)
pre_loss = a.loss(*args, time = time, **kargs, lr = lr * (1 - bw))
if bw <= 0.0:
return pre_loss
return (1 - bw) * pre_loss + bw * self.x.deep_loss(act = self.obj.act)
def showNet(self, t = ""):
print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " bw="+ str(self.bw) + " act=" + str(self.act) )
class IdentLoss(DeepLoss):
def abstract_forward(self, x, **kargs):
return x
def SkipNet(net1, net2, ffnn, **kargs):
return Seq(Skip(net1,net2), FFNN(ffnn, **kargs))
def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs):
if not downsample:
k_first = 3
skip_stride = 1
k_skip = 1
else:
k_first = 4
skip_stride = 2
k_skip = 2
# skip path (conv2d280(input)): a single, possibly strided, convolution applied directly to the input
blockA = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0, bias=bias, normal=True, **kargs)
# main path (conv2d282(relu(conv2d278(input)))): conv -> activation -> conv
blockB = Seq( Conv(out_filters, kernel_size = k_first, stride = skip_stride, padding = 1, bias=bias, normal=True, **kargs)
, Conv2D(out_filters, kernel_size = k, stride = 1, padding = 1, bias=bias, normal=True, **kargs))
return Seq(ParSum(blockA, blockB), activation(**kargs))
def BasicBlock(in_planes, planes, stride=1, bias = False, skip_net = False, **kargs):
block = Seq( Conv(planes, kernel_size = 3, stride = stride, padding = 1, bias=bias, normal=True, **kargs)
, Conv2D(planes, kernel_size = 3, stride = 1, padding = 1, bias=bias, normal=True, **kargs))
if stride != 1 or in_planes != planes:
block = ParSum(block, Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs))
elif not skip_net:
block = ParSum(block, Identity())
return Seq(block, activation(**kargs))
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
def ResNet(blocksList, extra = [], bias = False, **kargs):
layers = []
in_planes = 64
planes = 64
stride = 0
for num_blocks in blocksList:
if stride < 2:
stride += 1
strides = [stride] + [1]*(num_blocks-1)
for stride in strides:
layers.append(BasicBlock(in_planes, planes, stride, bias = bias, **kargs))
in_planes = planes
planes *= 2
print("RESlayers: ", len(layers))
for e,l in extra:
layers[l] = Seq(layers[l], e)
return Seq(Conv(64, kernel_size=3, stride=1, padding = 1, bias=bias, normal=True, printShape=True),
*layers)
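# Hedged example (illustrative; a ResNet-18-style [2, 2, 2, 2] block layout,
# though this builder is simplified and has no final pooling or classifier):
def _example_resnet():
    net = ResNet([2, 2, 2, 2])
    return net.infer([3, 32, 32]).outShape  # should be [512, 4, 4] for 32x32 inputs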
def DenseNet(growthRate, depth, reduction, num_classes, bottleneck = True):
def Bottleneck(growthRate):
interChannels = 4*growthRate
n = Seq( ReLU(),
Conv2D(interChannels, kernel_size=1, bias=True, ibp_init = True),
ReLU(),
Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True)
)
return Skip(Identity(), n)
def SingleLayer(growthRate):
n = Seq( ReLU(),
Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True))
return Skip(Identity(), n)
def Transition(nOutChannels):
return Seq( ReLU(),
Conv2D(nOutChannels, kernel_size = 1, bias = True, ibp_init = True),
AvgPool2D(kernel_size=2))
def make_dense(growthRate, nDenseBlocks, bottleneck):
return Seq(*[Bottleneck(growthRate) if bottleneck else SingleLayer(growthRate) for i in range(nDenseBlocks)])
nDenseBlocks = (depth-4) // 3
if bottleneck:
nDenseBlocks //= 2
nChannels = 2*growthRate
conv1 = Conv2D(nChannels, kernel_size=3, padding=1, bias=True, ibp_init = True)
dense1 = make_dense(growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks * growthRate
nOutChannels = int(math.floor(nChannels*reduction))
trans1 = Transition(nOutChannels)
nChannels = nOutChannels
dense2 = make_dense(growthRate, nDenseBlocks, bottleneck)
nChannels += nDenseBlocks*growthRate
nOutChannels = int(math.floor(nChannels*reduction))
trans2 = Transition(nOutChannels)
nChannels = nOutChannels
dense3 = make_dense(growthRate, nDenseBlocks, bottleneck)
return Seq(conv1, dense1, trans1, dense2, trans2, dense3,
ReLU(),
AvgPool2D(kernel_size=8),
CorrelateAll(only_train=False, ignore_point = True),
Linear(num_classes, ibp_init = True))
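# Hedged example (illustrative hyperparameters, roughly CIFAR-10 sized; not a
# configuration taken from this repo):
def _example_densenet():
    net = DenseNet(growthRate=12, depth=40, reduction=0.5, num_classes=10, bottleneck=False)
    return net.infer([3, 32, 32]).outShape  # expected: [10]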