import math
import abc
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import multinomial, categorical
from torch.nn.modules.conv import _ConvNd

try:
    from . import helpers as h
    from . import ai
    from . import scheduling as S
except ImportError:
    import helpers as h
    import ai
    import scheduling as S


class InferModule(nn.Module):
    def __init__(self, *args, normal=False, ibp_init=False, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.infered = False
        self.normal = normal
        self.ibp_init = ibp_init

    def infer(self, in_shape, global_args=None):
        """ this is really actually stateful. """
        if self.infered:
            return self
        self.infered = True
        super(InferModule, self).__init__()
        self.inShape = list(in_shape)
        self.outShape = list(self.init(list(in_shape), *self.args, global_args=global_args, **self.kwargs))
        if self.outShape is None:
            raise ValueError("init should set the out_shape")
        self.reset_parameters()
        return self

    def reset_parameters(self):
        if not hasattr(self, 'weight') or self.weight is None:
            return
        n = h.product(self.weight.size()) / self.outShape[0]
        stdv = 1 / math.sqrt(n)
        if self.ibp_init:
            torch.nn.init.orthogonal_(self.weight.data)
        elif self.normal:
            self.weight.data.normal_(0, stdv)
            self.weight.data.clamp_(-1, 1)
        else:
            self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            if self.ibp_init:
                self.bias.data.zero_()
            elif self.normal:
                self.bias.data.normal_(0, stdv)
                self.bias.data.clamp_(-1, 1)
            else:
                self.bias.data.uniform_(-stdv, stdv)

    def clip_norm(self):
        if not hasattr(self, "weight"):
            return
        if not hasattr(self, "weight_g"):
            if torch.__version__[0] == "0":  # pre-1.0 PyTorch
                nn.utils.weight_norm(self, dim=None)
            else:
                nn.utils.weight_norm(self)
        self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm)
        if torch.__version__[0] != "0":
            self.weight_v.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)
        if hasattr(self, "bias"):
            self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)

    def regularize(self, p):
        reg = 0
        if torch.__version__[0] == "0":
            for param in self.parameters():
                reg += param.norm(p)
        else:
            if hasattr(self, "weight_g"):
                reg += self.weight_g.norm().sum()
                reg += self.weight_v.norm().sum()
            elif hasattr(self, "weight"):
                reg += self.weight.norm().sum()
            if hasattr(self, "bias"):
                reg += self.bias.view(-1).norm(p=p).sum()
        return reg

    def remove_norm(self):
        if hasattr(self, "weight_g"):
            torch.nn.utils.remove_weight_norm(self)

    def showNet(self, t=""):
        print(t + self.__class__.__name__)

    def printNet(self, f):
        print(self.__class__.__name__, file=f)

    @abc.abstractmethod
    def forward(self, *args, **kargs):
        pass

    def __call__(self, *args, onyx=False, **kargs):
        if onyx:
            return self.forward(*args, onyx=onyx, **kargs)
        else:
            return super(InferModule, self).__call__(*args, **kargs)

    @abc.abstractmethod
    def neuronCount(self):
        pass

    def depth(self):
        return 0


def getShapeConv(in_shape, conv_shape, stride=1, padding=0):
    inChan, inH, inW = in_shape
    outChan, kH, kW = conv_shape[:3]
    outH = 1 + int((2 * padding + inH - kH) / stride)
    outW = 1 + int((2 * padding + inW - kW) / stride)
    return (outChan, outH, outW)


def getShapeConvTranspose(in_shape, conv_shape, stride=1, padding=0, out_padding=0):
    inChan, inH, inW = in_shape
    outChan, kH, kW = conv_shape[:3]
    outH = (inH - 1) * stride - 2 * padding + kH + out_padding
    outW = (inW - 1) * stride - 2 * padding + kW + out_padding
    return (outChan, outH, outW)

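# A quick sanity check of the formula above (hypothetical shapes, not from the
# original source): a 3x32x32 input through a 16-filter 5x5 convolution with
# stride 2 and padding 2 gives 1 + (2*2 + 32 - 5)//2 = 16 in each spatial dim:
#
#   getShapeConv((3, 32, 32), (16, 5, 5), stride=2, padding=2) == (16, 16, 16)
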
class Linear(InferModule):
    def init(self, in_shape, out_shape, **kargs):
        self.in_neurons = h.product(in_shape)
        if isinstance(out_shape, int):
            out_shape = [out_shape]
        self.out_neurons = h.product(out_shape)
        self.weight = torch.nn.Parameter(torch.Tensor(self.in_neurons, self.out_neurons))
        self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons))
        return out_shape

    def forward(self, x, **kargs):
        s = x.size()
        x = x.view(s[0], h.product(s[1:]))
        return (x.matmul(self.weight) + self.bias).view(s[0], *self.outShape)

    def neuronCount(self):
        return 0

    def showNet(self, t=""):
        print(t + "Linear out=" + str(self.out_neurons))

    def printNet(self, f):
        print("Linear(" + str(self.out_neurons) + ")", file=f)
        print(h.printListsNumpy(list(self.weight.transpose(1, 0).data)), file=f)
        print(h.printNumpy(self.bias), file=f)


class Activation(InferModule):
    def init(self, in_shape, global_args=None, activation="ReLU", **kargs):
        self.activation = ["ReLU", "Sigmoid", "Tanh", "Softplus", "ELU", "SELU"].index(activation)
        self.activation_name = activation
        return in_shape

    def regularize(self, p):
        return 0

    def forward(self, x, **kargs):
        return [lambda x: x.relu(),
                lambda x: x.sigmoid(),
                lambda x: x.tanh(),
                lambda x: x.softplus(),
                lambda x: x.elu(),
                lambda x: x.selu()][self.activation](x)

    def neuronCount(self):
        return h.product(self.outShape)

    def depth(self):
        return 1

    def showNet(self, t=""):
        print(t + self.activation_name)

    def printNet(self, f):
        pass


class ReLU(Activation):
    pass


def activation(*args, batch_norm=False, **kargs):
    a = Activation(*args, **kargs)
    return Seq(BatchNorm(), a) if batch_norm else a


class Identity(InferModule):  # for feigning model equivalence when removing an op
    def init(self, in_shape, global_args=None, **kargs):
        return in_shape

    def forward(self, x, **kargs):
        return x

    def neuronCount(self):
        return 0

    def printNet(self, f):
        pass

    def regularize(self, p):
        return 0

    def showNet(self, *args, **kargs):
        pass


class Dropout(InferModule):
    def init(self, in_shape, p=0.5, use_2d=False, alpha_dropout=False, **kargs):
        self.p = S.Const.initConst(p)
        self.use_2d = use_2d
        self.alpha_dropout = alpha_dropout
        return in_shape

    def forward(self, x, time=0, **kargs):
        if self.training:
            with torch.no_grad():
                p = self.p.getVal(time=time)
                mask = (F.dropout2d if self.use_2d else F.dropout)(h.ones(x.size()), p=p, training=True)
            if self.alpha_dropout:
                with torch.no_grad():
                    keep_prob = 1 - p
                    alpha = -1.7580993408473766  # -(SELU alpha * SELU scale); see note below
                    a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5)
                    b = -a * alpha * (1 - keep_prob)
                    mask = mask * a
                return x * mask + b
            else:
                return x * mask
        else:
            return x

    def neuronCount(self):
        return 0

    def showNet(self, t=""):
        print(t + "Dropout p=" + str(self.p))

    def printNet(self, f):
        print("Dropout(" + str(self.p) + ")", file=f)


class PrintActivation(Identity):
    def init(self, in_shape, global_args=None, activation="ReLU", **kargs):
        self.activation = activation
        return in_shape

    def printNet(self, f):
        print(self.activation, file=f)


class PrintReLU(PrintActivation):
    pass

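# Note on the alpha-dropout constants in Dropout.forward above: the magic
# number 1.7580993408473766 is SELU alpha (~1.6732632) times SELU scale
# (~1.0507010), the same constant used by F.alpha_dropout; a and b then rescale
# the masked activations so mean and variance are approximately preserved.
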
class Conv2D(InferModule):
    def init(self, in_shape, out_channels, kernel_size, stride=1, global_args=None,
             bias=True, padding=0, activation="ReLU", **kargs):
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = activation
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0]))
        else:
            self.bias = None  # h.zeros(weights_shape[0])

        outshape = getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding)
        return outshape

    def forward(self, input, **kargs):
        return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding=self.padding)

    def printNet(self, f):  # NOTE: the printed description is only fully accurate when stride == 1
        print("Conv2D", file=f)
        sz = list(self.prev)
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(
            self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)),
            [self.stride, self.stride], self.padding), file=f)
        print(h.printListsNumpy([[list(p) for p in l] for l in self.weight.permute(2, 3, 1, 0).data]), file=f)
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file=f)

    def showNet(self, t=""):
        sz = list(self.prev)
        print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(
            self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)),
            [self.stride, self.stride], self.padding))

    def neuronCount(self):
        return 0


class ConvTranspose2D(InferModule):
    def init(self, in_shape, out_channels, kernel_size, stride=1, global_args=None,
             bias=True, padding=0, out_padding=0, activation="ReLU", **kargs):
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_padding = out_padding
        self.activation = activation
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        # transposed-conv weights are (in_channels, out_channels, kH, kW);
        # the bias must have out_channels entries, i.e. weights_shape[1]
        weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[1]))
        else:
            self.bias = None  # h.zeros(weights_shape[1])

        outshape = getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size),
                                         stride, padding, out_padding)
        return outshape

    def forward(self, input, **kargs):
        return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride,
                                      padding=self.padding, output_padding=self.out_padding)

    def printNet(self, f):  # NOTE: the printed description is only fully accurate when stride == 1
        print("ConvTranspose2D", file=f)
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(
            self.out_channels, [self.kernel_size, self.kernel_size], list(self.prev)), file=f)
        print(h.printListsNumpy([[list(p) for p in l] for l in self.weight.permute(2, 3, 1, 0).data]), file=f)
        print(h.printNumpy(self.bias), file=f)

    def neuronCount(self):
        return 0

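# Sanity check for the transposed shape formula (hypothetical shapes): a 4x4
# kernel with stride 2 and padding 1 maps 16x16 -> 8x8 under getShapeConv
# (1 + (2*1 + 16 - 4)//2 = 8), and getShapeConvTranspose maps it back:
# (8 - 1)*2 - 2*1 + 4 = 16.
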
class MaxPool2D(InferModule):
    def init(self, in_shape, kernel_size, stride=None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        # use self.stride here: the raw stride argument may still be None
        return getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride)

    def forward(self, x, **kargs):
        return x.max_pool2d(self.kernel_size, self.stride)

    def printNet(self, f):
        print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format(
            self.stride, self.kernel_size, list(self.prev[1:] + self.prev[:1])), file=f)

    def neuronCount(self):
        return h.product(self.outShape)


class AvgPool2D(InferModule):
    def init(self, in_shape, kernel_size, stride=None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        out_size = getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding=1)
        return out_size

    def forward(self, x, **kargs):
        if h.product(x.size()[2:]) == 1:
            return x
        return x.avg_pool2d(kernel_size=self.kernel_size, stride=self.stride, padding=1)

    def printNet(self, f):
        print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format(
            self.stride, self.kernel_size, list(self.prev[1:] + self.prev[:1])), file=f)

    def neuronCount(self):
        return h.product(self.outShape)


class AdaptiveAvgPool2D(InferModule):
    def init(self, in_shape, out_shape, **kargs):
        self.prev = in_shape
        self.out_shape = list(out_shape)
        return [in_shape[0]] + self.out_shape

    def forward(self, x, **kargs):
        return x.adaptive_avg_pool2d(self.out_shape)

    def printNet(self, f):
        print("AdaptiveAvgPool2D out_shape={} input_shape={}".format(
            list(self.out_shape), list(self.prev[1:] + self.prev[:1])), file=f)

    def neuronCount(self):
        return h.product(self.outShape)


class Normalize(InferModule):
    def init(self, in_shape, mean, std, **kargs):
        self.mean_v = mean
        self.std_v = std
        self.mean = h.dten(mean)
        self.std = 1 / h.dten(std)
        return in_shape

    def forward(self, x, **kargs):
        mean_ex = self.mean.view(self.mean.shape[0], 1, 1).expand(*x.size()[1:])
        std_ex = self.std.view(self.std.shape[0], 1, 1).expand(*x.size()[1:])
        return (x - mean_ex) * std_ex

    def neuronCount(self):
        return 0

    def printNet(self, f):
        print("Normalize mean={} std={}".format(self.mean_v, self.std_v), file=f)

    def showNet(self, t=""):
        print(t + "Normalize mean={} std={}".format(self.mean_v, self.std_v))


class Flatten(InferModule):
    def init(self, in_shape, **kargs):
        return h.product(in_shape)

    def forward(self, x, **kargs):
        s = x.size()
        return x.view(s[0], h.product(s[1:]))

    def neuronCount(self):
        return 0


class BatchNorm(InferModule):
    def init(self, in_shape, track_running_stats=True, momentum=0.1, eps=1e-5, **kargs):
        self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.beta = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.eps = eps
        self.track_running_stats = track_running_stats
        self.momentum = momentum
        self.running_mean = None
        self.running_var = None
        self.num_batches_tracked = 0
        return in_shape

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, x, **kargs):
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        new_mean = x.vanillaTensorPart().detach().mean(dim=0)
        new_var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False)
        if torch.isnan(new_var * 0).any():
            return x

        if self.training:
            if self.running_mean is None:
                self.running_mean = new_mean
            else:
                self.running_mean = (1 - exponential_average_factor) * self.running_mean \
                    + exponential_average_factor * new_mean
            if self.running_var is None:
                self.running_var = new_var
            else:
                q = (1 - exponential_average_factor) * self.running_var
                r = exponential_average_factor * new_var
                self.running_var = q + r

        if self.track_running_stats and self.running_mean is not None and self.running_var is not None:
            new_mean = self.running_mean
            new_var = self.running_var

        diver = 1 / (new_var + self.eps).sqrt()
        if torch.isnan(diver).any():
            print("Really shouldn't happen ever")
            return x
        else:
            out = (x - new_mean) * diver * self.gamma + self.beta
            return out

    def neuronCount(self):
        return 0

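# The running-statistics update in BatchNorm.forward above is a standard
# exponential moving average. With factor f (momentum, or 1/num_batches_tracked
# when momentum is None, i.e. a cumulative average over all batches seen):
#
#   running_mean <- (1 - f) * running_mean + f * batch_mean
#   running_var  <- (1 - f) * running_var  + f * batch_var
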
class Unflatten2d(InferModule):
    def init(self, in_shape, w, **kargs):
        self.w = w
        self.outChan = int(h.product(in_shape) / (w * w))
        return (self.outChan, self.w, self.w)

    def forward(self, x, **kargs):
        s = x.size()
        return x.view(s[0], self.outChan, self.w, self.w)

    def neuronCount(self):
        return 0


class View(InferModule):
    def init(self, in_shape, out_shape, **kargs):
        assert h.product(in_shape) == h.product(out_shape)
        return out_shape

    def forward(self, x, **kargs):
        s = x.size()
        return x.view(s[0], *self.outShape)

    def neuronCount(self):
        return 0


class Seq(InferModule):
    def init(self, in_shape, *layers, **kargs):
        self.layers = layers
        self.net = nn.Sequential(*layers)
        self.prev = in_shape
        for s in layers:
            in_shape = s.infer(in_shape, **kargs).outShape
        return in_shape

    def forward(self, x, **kargs):
        for l in self.layers:
            x = l(x, **kargs)
        return x

    def clip_norm(self):
        for l in self.layers:
            l.clip_norm()

    def regularize(self, p):
        return sum(n.regularize(p) for n in self.layers)

    def remove_norm(self):
        for l in self.layers:
            l.remove_norm()

    def printNet(self, f):
        for l in self.layers:
            l.printNet(f)

    def showNet(self, *args, **kargs):
        for l in self.layers:
            l.showNet(*args, **kargs)

    def neuronCount(self):
        return sum([l.neuronCount() for l in self.layers])

    def depth(self):
        return sum([l.depth() for l in self.layers])


def FFNN(layers, last_lin=False, last_zono=False, **kargs):
    starts = layers
    ends = []
    if last_lin:
        ends = ([CorrelateAll(only_train=False)] if last_zono else []) \
            + [PrintActivation(activation="Affine"), Linear(layers[-1], **kargs)]
        starts = layers[:-1]
    return Seq(*([Seq(PrintActivation(**kargs), Linear(s, **kargs), activation(**kargs)) for s in starts] + ends))


def Conv(*args, **kargs):
    return Seq(Conv2D(*args, **kargs), activation(**kargs))


def ConvTranspose(*args, **kargs):
    return Seq(ConvTranspose2D(*args, **kargs), activation(**kargs))


MP = MaxPool2D


def LeNet(conv_layers, ly=[], bias=True, normal=False, **kargs):
    def transfer(tp):
        if isinstance(tp, InferModule):
            return tp
        if isinstance(tp[0], str):
            return MaxPool2D(*tp[1:])
        return Conv(out_channels=tp[0], kernel_size=tp[1],
                    stride=tp[-1] if len(tp) == 4 else 1, bias=bias, normal=normal, **kargs)

    conv = [transfer(s) for s in conv_layers]
    return Seq(*conv, FFNN(ly, **kargs, bias=bias)) if len(ly) > 0 else Seq(*conv)


def InvLeNet(ly, w, conv_layers, bias=True, normal=False, **kargs):
    def transfer(tp):
        return ConvTranspose(out_channels=tp[0], kernel_size=tp[1], stride=tp[2],
                             padding=tp[3], out_padding=tp[4], bias=False, normal=normal)

    return Seq(FFNN(ly, bias=bias), Unflatten2d(w), *[transfer(s) for s in conv_layers])


class FromByteImg(InferModule):
    def init(self, in_shape, **kargs):
        return in_shape

    def forward(self, x, **kargs):
        return x.to_dtype() / 256.

    def neuronCount(self):
        return 0

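# Minimal construction sketch (hypothetical shapes and layer sizes, not from
# the original source). Layers are declared shape-free; parameters only come
# into existence once infer() has propagated an input shape through the net:
#
#   net = LeNet([(6, 5), ("MP", 2), (16, 5), ("MP", 2)], ly=[120, 84, 10])
#   net = net.infer([1, 28, 28])          # e.g. a 1x28x28 MNIST image
#
# Note that forward() expects this project's tensor-like objects (which expose
# conv2d, max_pool2d, ... as methods), not raw torch.Tensors.
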
class Skip(InferModule):
    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        assert net1.outShape[1:] == net2.outShape[1:]
        return [net1.outShape[0] + net2.outShape[0]] + net1.outShape[1:]

    def forward(self, x, **kargs):
        r1 = self.net1(x, **kargs)
        r2 = self.net2(x, **kargs)
        return r1.cat(r2, dim=1)

    def regularize(self, p):
        return self.net1.regularize(p) + self.net2.regularize(p)

    def clip_norm(self):
        self.net1.clip_norm()
        self.net2.clip_norm()

    def remove_norm(self):
        self.net1.remove_norm()
        self.net2.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def printNet(self, f):
        print("SkipNet1", file=f)
        self.net1.printNet(f)
        print("SkipNet2", file=f)
        self.net2.printNet(f)
        print("SkipCat dim=1", file=f)

    def showNet(self, t=""):
        print(t + "SkipNet1")
        self.net1.showNet(" " + t)
        print(t + "SkipNet2")
        self.net2.showNet(" " + t)
        print(t + "SkipCat dim=1")


class ParSum(InferModule):
    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        assert net1.outShape == net2.outShape
        return net1.outShape

    def forward(self, x, **kargs):
        r1 = self.net1(x, **kargs)
        r2 = self.net2(x, **kargs)
        return x.addPar(r1, r2)

    def clip_norm(self):
        self.net1.clip_norm()
        self.net2.clip_norm()

    def remove_norm(self):
        self.net1.remove_norm()
        self.net2.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def depth(self):
        return max(self.net1.depth(), self.net2.depth())

    def printNet(self, f):
        print("ParNet1", file=f)
        self.net1.printNet(f)
        print("ParNet2", file=f)
        self.net2.printNet(f)
        print("ParCat dim=1", file=f)

    def showNet(self, t=""):
        print(t + "ParNet1")
        self.net1.showNet(" " + t)
        print(t + "ParNet2")
        self.net2.showNet(" " + t)
        print(t + "ParSum")


class ToZono(Identity):
    def init(self, in_shape, customRelu=None, only_train=False, **kargs):
        self.customRelu = customRelu
        self.only_train = only_train
        return in_shape

    def forward(self, x, **kargs):
        return self.abstract_forward(x, **kargs) if self.training or not self.only_train else x

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono', customRelu=self.customRelu)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train))


class CorrelateAll(ToZono):
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono', correlate=True, customRelu=self.customRelu)


class ToHZono(ToZono):
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('zono_to_hybrid', customRelu=self.customRelu)


class Concretize(ToZono):
    def init(self, in_shape, only_train=True, **kargs):
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('concretize')


# stochastic correlation
class CorrRand(Concretize):
    def init(self, in_shape, num_correlate, only_train=True, **kargs):
        self.only_train = only_train
        self.num_correlate = num_correlate
        return in_shape

    # **kargs accepted so that ToZono.forward can pass keyword args through
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train)
              + " num_correlate=" + str(self.num_correlate))


class CorrMaxK(CorrRand):
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf("correlateMaxK", self.num_correlate)

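# The correlation-control layers above and below do not change the concrete
# computation: abstract_forward only reshapes the abstract state via
# abstractApplyLeaf, and with only_train=True (the Concretize default) they
# are skipped entirely in eval mode by ToZono.forward.
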
class CorrMaxPool2D(Concretize):
    def init(self, in_shape, kernel_size, only_train=True, max_type=ai.MaxTypes.head_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size=self.kernel_size,
                                   stride=self.kernel_size, max_type=self.max_type)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train)
              + " kernel_size=" + str(self.kernel_size) + " max_type=" + str(self.max_type))


class CorrMaxPool3D(Concretize):
    def init(self, in_shape, kernel_size, only_train=True, max_type=ai.MaxTypes.only_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size=self.kernel_size,
                                   stride=self.kernel_size, max_type=self.max_type,
                                   max_pool=F.max_pool3d)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train)
              + " kernel_size=" + str(self.kernel_size) + " max_type=" + str(self.max_type))


class CorrFix(Concretize):
    def init(self, in_shape, k, only_train=True, **kargs):
        self.k = k
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        sz = x.size()
        """
        # for more control in the future
        indxs_1 = torch.arange(start=0, end=sz[1], step=math.ceil(sz[1] / self.dims[1]))
        indxs_2 = torch.arange(start=0, end=sz[2], step=math.ceil(sz[2] / self.dims[2]))
        indxs_3 = torch.arange(start=0, end=sz[3], step=math.ceil(sz[3] / self.dims[3]))
        indxs = torch.stack(torch.meshgrid((indxs_1, indxs_2, indxs_3)), dim=3).view(-1, 3)
        """
        szm = h.product(sz[1:])
        indxs = torch.arange(start=0, end=szm, step=math.ceil(szm / self.k))
        indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0])
        return x.abstractApplyLeaf("correlate", indxs)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k=" + str(self.k))


class DecorrRand(Concretize):
    def init(self, in_shape, num_decorrelate, only_train=True, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate)


class DecorrMin(Concretize):
    def init(self, in_shape, num_decorrelate, only_train=True, num_to_keep=False, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        self.num_to_keep = num_to_keep
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep=self.num_to_keep)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train)
              + " k=" + str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep))


class DeepLoss(ToZono):
    def init(self, in_shape, bw=0.01, act=F.relu, **kargs):
        # weight must be between 0 and 1
        self.only_train = True
        self.bw = S.Const.initConst(bw)
        self.act = act
        return in_shape

    def abstract_forward(self, x, **kargs):
        if x.isPoint():
            return x
        return ai.TaggedDomain(x, self.MLoss(self, x))

    class MLoss():
        def __init__(self, obj, x):
            self.obj = obj
            self.x = x

        def loss(self, a, *args, lr=1, time=0, **kargs):
            bw = self.obj.bw.getVal(time=time)
            pre_loss = a.loss(*args, time=time, **kargs, lr=lr * (1 - bw))
            if bw <= 0.0:
                return pre_loss
            return (1 - bw) * pre_loss + bw * self.x.deep_loss(act=self.obj.act)

    def showNet(self, t=""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train)
              + " bw=" + str(self.bw) + " act=" + str(self.act))

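# DeepLoss mixes the downstream loss with a regularizer computed on the
# abstract state: for blend weight bw = self.bw.getVal(time) the result is
#
#   (1 - bw) * pre_loss + bw * x.deep_loss(act)
#
# so bw = 0 recovers the plain loss and larger bw weights the deep_loss term
# more heavily.
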
class IdentLoss(DeepLoss):
    def abstract_forward(self, x, **kargs):
        return x


def SkipNet(net1, net2, ffnn, **kargs):
    return Seq(Skip(net1, net2), FFNN(ffnn, **kargs))


def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs):
    if not downsample:
        k_first = 3
        skip_stride = 1
        k_skip = 1
    else:
        k_first = 4
        skip_stride = 2
        k_skip = 2

    # conv2d280(input)
    blockA = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0,
                    bias=bias, normal=True, **kargs)
    # conv2d282(relu(conv2d278(input)))
    blockB = Seq(Conv(out_filters, kernel_size=k_first, stride=skip_stride, padding=1,
                      bias=bias, normal=True, **kargs),
                 Conv2D(out_filters, kernel_size=k, stride=1, padding=1,
                        bias=bias, normal=True, **kargs))
    return Seq(ParSum(blockA, blockB), activation(**kargs))


def BasicBlock(in_planes, planes, stride=1, bias=False, skip_net=False, **kargs):
    block = Seq(Conv(planes, kernel_size=3, stride=stride, padding=1, bias=bias, normal=True, **kargs),
                Conv2D(planes, kernel_size=3, stride=1, padding=1, bias=bias, normal=True, **kargs))
    if stride != 1 or in_planes != planes:
        block = ParSum(block, Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs))
    elif not skip_net:
        block = ParSum(block, Identity())
    return Seq(block, activation(**kargs))


# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
def ResNet(blocksList, extra=[], bias=False, **kargs):
    layers = []
    in_planes = 64
    planes = 64
    stride = 0
    for num_blocks in blocksList:
        if stride < 2:
            stride += 1
        strides = [stride] + [1] * (num_blocks - 1)
        for stride in strides:
            layers.append(BasicBlock(in_planes, planes, stride, bias=bias, **kargs))
            in_planes = planes
        planes *= 2
    print("RESlayers: ", len(layers))
    for e, l in extra:
        layers[l] = Seq(layers[l], e)
    return Seq(Conv(64, kernel_size=3, stride=1, padding=1, bias=bias, normal=True, printShape=True), *layers)


def DenseNet(growthRate, depth, reduction, num_classes, bottleneck=True):
    def Bottleneck(growthRate):
        interChannels = 4 * growthRate
        n = Seq(ReLU(),
                Conv2D(interChannels, kernel_size=1, bias=True, ibp_init=True),
                ReLU(),
                Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init=True))
        return Skip(Identity(), n)

    def SingleLayer(growthRate):
        n = Seq(ReLU(),
                Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init=True))
        return Skip(Identity(), n)

    def Transition(nOutChannels):
        return Seq(ReLU(),
                   Conv2D(nOutChannels, kernel_size=1, bias=True, ibp_init=True),
                   AvgPool2D(kernel_size=2))

    def make_dense(growthRate, nDenseBlocks, bottleneck):
        return Seq(*[Bottleneck(growthRate) if bottleneck else SingleLayer(growthRate)
                     for i in range(nDenseBlocks)])

    nDenseBlocks = (depth - 4) // 3
    if bottleneck:
        nDenseBlocks //= 2

    nChannels = 2 * growthRate
    conv1 = Conv2D(nChannels, kernel_size=3, padding=1, bias=True, ibp_init=True)
    dense1 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks * growthRate
    nOutChannels = int(math.floor(nChannels * reduction))
    trans1 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense2 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks * growthRate
    nOutChannels = int(math.floor(nChannels * reduction))
    trans2 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense3 = make_dense(growthRate, nDenseBlocks, bottleneck)
    return Seq(conv1, dense1, trans1, dense2, trans2, dense3,
               ReLU(), AvgPool2D(kernel_size=8),
               CorrelateAll(only_train=False, ignore_point=True),
               Linear(num_classes, ibp_init=True))
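
# End-to-end usage sketch (assumed workflow; block counts are hypothetical):
#
#   net = ResNet([2, 2, 2, 2])     # ResNet-18-style block layout
#   net = net.infer([3, 32, 32])   # CIFAR-sized input
#   print(net.neuronCount(), net.depth())
#
# DenseNet(growthRate=12, depth=40, reduction=0.5, num_classes=10) builds the
# corresponding DenseNet; as everywhere in this module, parameters exist only
# after infer() has run.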