import torch |
import torch.nn.functional as F |
import torch.nn as nn |
from torch.distributions import multinomial, categorical |
import torch.optim as optim |
import math |
try: |
from . import helpers as h |
from . import ai |
from . import scheduling as S |
except: |
import helpers as h |
import ai |
import scheduling as S |
import math |
import abc |
from torch.nn.modules.conv import _ConvNd |
from enum import Enum |
class InferModule(nn.Module):
    """Base class for layers that are configured lazily from their input shape.

    The constructor only records its arguments.  ``nn.Module.__init__`` and the
    subclass hook ``init`` are run later by :meth:`infer`.  Subclasses provide
    ``init(in_shape, *args, global_args=..., **kwargs)``, which must return the
    layer's output shape.
    """

    def __init__(self, *args, normal=False, ibp_init=False, **kwargs):
        """Record the constructor arguments for the later call to ``init``.

        ``normal`` and ``ibp_init`` select the initialisation scheme used by
        :meth:`reset_parameters`.
        """
        # NOTE: nn.Module.__init__ is intentionally deferred to infer().
        self.args = args
        self.kwargs = kwargs
        self.infered = False
        self.normal = normal
        self.ibp_init = ibp_init

    def infer(self, in_shape, global_args=None):
        """Build the layer for ``in_shape`` and return ``self``.

        This is stateful: the first call initialises the module, stores
        ``inShape`` / ``outShape`` and resets the parameters.  Later calls are
        no-ops that return ``self`` unchanged.

        Raises:
            TypeError: if the subclass ``init`` hook returns ``None``.
        """
        if self.infered:
            return self
        self.infered = True

        super(InferModule, self).__init__()
        self.inShape = list(in_shape)

        out_shape = self.init(list(in_shape), *self.args, global_args=global_args, **self.kwargs)
        # The check has to happen before list(): list(None) would raise an
        # unrelated "not iterable" error and hide the real problem.
        if out_shape is None:
            raise TypeError("init should set the out_shape")
        self.outShape = list(out_shape)

        self.reset_parameters()
        return self

    def reset_parameters(self):
        """Initialise ``weight`` (and ``bias`` when present); no-op without a weight."""
        if not hasattr(self, 'weight') or self.weight is None:
            return
        n = h.product(self.weight.size()) / self.outShape[0]
        stdv = 1 / math.sqrt(n)

        if self.ibp_init:
            torch.nn.init.orthogonal_(self.weight.data)
        elif self.normal:
            self.weight.data.normal_(0, stdv)
            self.weight.data.clamp_(-1, 1)
        else:
            self.weight.data.uniform_(-stdv, stdv)

        bias = getattr(self, "bias", None)
        if bias is not None:
            if self.ibp_init:
                bias.data.zero_()
            elif self.normal:
                bias.data.normal_(0, stdv)
                bias.data.clamp_(-1, 1)
            else:
                bias.data.uniform_(-stdv, stdv)

    def clip_norm(self):
        """Apply weight normalisation (once) and clamp the resulting tensors."""
        if not hasattr(self, "weight"):
            return
        legacy_torch = torch.__version__[0] == "0"
        if not hasattr(self, "weight_g"):
            if legacy_torch:
                nn.utils.weight_norm(self, dim=None)
            else:
                nn.utils.weight_norm(self)

        self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm)

        if not legacy_torch:
            self.weight_v.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)
            # bias may exist as an attribute but be None (e.g. bias=False).
            if getattr(self, "bias", None) is not None:
                self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)

    def regularize(self, p):
        """Return a norm-based penalty over this module's own parameters.

        NOTE(review): on torch >= 1 the weight terms use the default norm of
        ``Tensor.norm()`` and only the bias term uses ``p`` - confirm that this
        is intended.
        """
        reg = 0
        if torch.__version__[0] == "0":
            for param in self.parameters():
                reg += param.norm(p)
        else:
            if hasattr(self, "weight_g"):
                reg += self.weight_g.norm().sum()
                reg += self.weight_v.norm().sum()
            elif hasattr(self, "weight"):
                reg += self.weight.norm().sum()

            # bias may exist as an attribute but be None (e.g. bias=False).
            if getattr(self, "bias", None) is not None:
                reg += self.bias.view(-1).norm(p=p).sum()
        return reg

    def remove_norm(self):
        """Undo weight normalisation if it was applied by :meth:`clip_norm`."""
        if hasattr(self, "weight_g"):
            torch.nn.utils.remove_weight_norm(self)

    def showNet(self, t=""):
        """Print the class name to stdout, prefixed by ``t``."""
        print(t + self.__class__.__name__)

    def printNet(self, f):
        """Print the class name to the file-like object ``f``."""
        print(self.__class__.__name__, file=f)

    @abc.abstractmethod
    def forward(self, *args, **kargs):
        """Compute the layer output; to be overridden by subclasses."""
        pass

    def __call__(self, *args, onyx=False, **kargs):
        """Call the module; with ``onyx=True`` call ``forward`` directly.

        The direct path bypasses ``nn.Module.__call__`` and passes ``onyx`` on
        to ``forward``.
        """
        if onyx:
            return self.forward(*args, onyx=onyx, **kargs)
        return super(InferModule, self).__call__(*args, **kargs)

    @abc.abstractmethod
    def neuronCount(self):
        """Return the neuron count of this layer; to be overridden by subclasses."""
        pass

    def depth(self):
        """Return the depth contributed by this layer (0 by default)."""
        return 0
def getShapeConv(in_shape, conv_shape, stride = 1, padding = 0):
    """Return ``(out_channels, out_h, out_w)`` for a strided, padded 2-d window.

    ``in_shape`` is unpacked as ``(channels, height, width)`` and the first
    three entries of ``conv_shape`` as ``(out_channels, kernel_h, kernel_w)``.
    """
    _, in_h, in_w = in_shape
    out_chan, k_h, k_w = conv_shape[:3]

    def extent(size, kernel):
        # int() truncates the true division, exactly as the size formula needs
        # for non-negative numerators.
        return 1 + int((2 * padding + size - kernel) / stride)

    return (out_chan, extent(in_h, k_h), extent(in_w, k_w))
def getShapeConvTranspose(in_shape, conv_shape, stride = 1, padding = 0, out_padding=0):
    """Return ``(out_channels, out_h, out_w)`` for a 2-d transposed convolution.

    ``in_shape`` is unpacked as ``(channels, height, width)`` and the first
    three entries of ``conv_shape`` as ``(out_channels, kernel_h, kernel_w)``.
    """
    _, in_h, in_w = in_shape
    out_chan, k_h, k_w = conv_shape[:3]

    def extent(size, kernel):
        return (size - 1) * stride - 2 * padding + kernel + out_padding

    return (out_chan, extent(in_h, k_h), extent(in_w, k_w))
class Linear(InferModule):
    """Fully connected layer: flattens each sample and applies ``x @ weight + bias``."""

    def init(self, in_shape, out_shape, **kargs):
        """Create the ``(in_neurons, out_neurons)`` weight and the bias.

        An integer ``out_shape`` is wrapped into a one-element list.  Returns
        the (possibly wrapped) ``out_shape``.
        """
        self.in_neurons = h.product(in_shape)
        if isinstance(out_shape, int):
            out_shape = [out_shape]
        self.out_neurons = h.product(out_shape)

        self.weight = torch.nn.Parameter(torch.Tensor(self.in_neurons, self.out_neurons))
        self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons))

        return out_shape

    def forward(self, x, **kargs):
        """Flatten all non-batch dimensions, apply the affine map, reshape to ``outShape``."""
        s = x.size()
        x = x.view(s[0], h.product(s[1:]))
        return (x.matmul(self.weight) + self.bias).view(s[0], *self.outShape)

    def neuronCount(self):
        return 0

    def showNet(self, t = ""):
        print(t + "Linear out=" + str(self.out_neurons))

    def printNet(self, f):
        """Write the layer header, the transposed weight and the bias to ``f``."""
        # The header previously went to stdout, so it was missing from the dump.
        print("Linear(" + str(self.out_neurons) + ")", file=f)

        print(h.printListsNumpy(list(self.weight.transpose(1, 0).data)), file=f)
        print(h.printNumpy(self.bias), file=f)
class Activation(InferModule):
    """Element-wise activation layer selected by name."""

    def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
        """Store the index and the name of the chosen activation; shape is unchanged.

        Raises ``ValueError`` (from ``list.index``) for an unknown name.
        """
        known = ["ReLU", "Sigmoid", "Tanh", "Softplus", "ELU", "SELU"]
        self.activation = known.index(activation)
        self.activation_name = activation
        return in_shape

    def regularize(self, p):
        return 0

    def forward(self, x, **kargs):
        """Call the method on ``x`` that matches the stored activation index."""
        # Same order as the name list used in init().
        method_names = ("relu", "sigmoid", "tanh", "softplus", "elu", "selu")
        return getattr(x, method_names[self.activation])()

    def neuronCount(self):
        return h.product(self.outShape)

    def depth(self):
        return 1

    def showNet(self, t = ""):
        print(t + self.activation_name)

    def printNet(self, f):
        pass
class ReLU(Activation):
    """``Activation`` under a dedicated name; its default activation is ``"ReLU"``."""
def activation(*args, batch_norm = False, **kargs):
    """Build an ``Activation``; with ``batch_norm`` put a ``BatchNorm`` in front of it."""
    act = Activation(*args, **kargs)
    if batch_norm:
        return Seq(BatchNorm(), act)
    return act
class Identity(InferModule):
    """Layer that returns its input unchanged and prints nothing."""

    def init(self, in_shape, global_args = None, **kargs):
        # Output shape equals input shape.
        return in_shape

    def forward(self, x, **kargs):
        return x

    def neuronCount(self):
        return 0

    def printNet(self, f):
        """Intentionally writes nothing."""

    def regularize(self, p):
        return 0

    def showNet(self, *args, **kargs):
        """Intentionally prints nothing."""
class Dropout(InferModule):
    """Dropout layer whose rate is read through ``S.Const.initConst(p).getVal(time=...)``."""

    def init(self, in_shape, p=0.5, use_2d = False, alpha_dropout = False, **kargs):
        """Store the drop rate and the mode flags; shape is unchanged."""
        self.p = S.Const.initConst(p)
        self.use_2d = use_2d
        self.alpha_dropout = alpha_dropout
        return in_shape

    def forward(self, x, time = 0, **kargs):
        """Multiply ``x`` by a freshly sampled mask while training; identity otherwise."""
        if not self.training:
            return x

        # The mask is sampled on a tensor of ones and without gradient tracking.
        with torch.no_grad():
            p = self.p.getVal(time = time)
            drop = F.dropout2d if self.use_2d else F.dropout
            mask = drop(h.ones(x.size()), p=p, training=True)

        if not self.alpha_dropout:
            return x * mask

        with torch.no_grad():
            keep_prob = 1 - p
            # presumably the negative SELU saturation value used by alpha
            # dropout - TODO confirm
            alpha = -1.7580993408473766
            a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5)
            b = -a * alpha * (1 - keep_prob)
            mask = mask * a
        return x * mask + b

    def neuronCount(self):
        return 0

    def showNet(self, t = ""):
        print(t + "Dropout p=" + str(self.p))

    def printNet(self, f):
        """Write the layer description to ``f``."""
        # Previously printed to stdout, so the entry was missing from the dump.
        print("Dropout(" + str(self.p) + ")", file=f)
class PrintActivation(Identity):
    """Identity layer that writes an activation name when the net is dumped."""

    def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
        # Only the label is stored; the shape passes through unchanged.
        self.activation = activation
        return in_shape

    def printNet(self, f):
        """Write the stored activation name to ``f``."""
        print(self.activation, file = f)
class PrintReLU(PrintActivation):
    """``PrintActivation`` under a dedicated name; its default label is ``"ReLU"``."""
class Conv2D(InferModule):
    """2-d convolution with a square kernel given as a single integer."""

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, activation = "ReLU", **kargs):
        """Create the weight (and optional bias) and return the output shape."""
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = activation
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0])) if bias else None

        return getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding)

    def forward(self, input, **kargs):
        return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding )

    def _format_args(self):
        """Values shared by the printNet and showNet description lines."""
        k, s = self.kernel_size, self.stride
        reversed_input = list(reversed(list(self.prev)))
        return (self.out_channels, [k, k], reversed_input, [s, s], self.padding)

    def printNet(self, f):
        """Write the layer header, the permuted weight and the bias to ``f``."""
        print("Conv2D", file = f)
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(*self._format_args()), file = f)

        nested_weight = [[list(p) for p in l] for l in self.weight.permute(2, 3, 1, 0).data]
        print(h.printListsNumpy(nested_weight), file = f)

        # Without a bias, h.dten(out_channels) is printed in its place.
        bias = self.bias if self.bias is not None else h.dten(self.out_channels)
        print(h.printNumpy(bias), file = f)

    def showNet(self, t = ""):
        print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(*self._format_args()))

    def neuronCount(self):
        return 0
class ConvTranspose2D(InferModule):
    """2-d transposed convolution with a square kernel given as a single integer."""

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, out_padding=0, activation = "ReLU", **kargs):
        """Create the weight (and optional bias) and return the output shape."""
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_padding = out_padding
        self.activation = activation
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        # Transposed convolutions store the weight as (in, out, kH, kW).
        weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            # The bias has one entry per OUTPUT channel.  weights_shape[0] is
            # the input channel count here, which was used by mistake.
            self.bias = torch.nn.Parameter(torch.Tensor(self.out_channels))
        else:
            self.bias = None

        return getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size), stride, padding, out_padding)

    def forward(self, input, **kargs):
        return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding, output_padding=self.out_padding)

    def printNet(self, f):
        """Write the layer header, the permuted weight and the bias to ``f``."""
        print("ConvTranspose2D", file = f)
        # kernel_size is a single int, so list(kernel_size) would raise; print
        # it as a pair, the way Conv2D.printNet does.
        kernel = [self.kernel_size, self.kernel_size]
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(self.out_channels, kernel, list(self.prev) ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # Mirror Conv2D.printNet for layers built with bias=False.
        bias = self.bias if self.bias is not None else h.dten(self.out_channels)
        print(h.printNumpy(bias), file= f)

    def neuronCount(self):
        return 0
class MaxPool2D(InferModule):
    """2-d max pooling with a square window given as a single integer."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        """Store the window and stride (defaulting to ``kernel_size``); return the output shape."""
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        # Use the resolved stride: the raw argument is None by default and
        # would make the shape arithmetic fail.
        return getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride)

    def forward(self, x, **kargs):
        return x.max_pool2d(self.kernel_size, self.stride)

    def printNet(self, f):
        """Write the pooling description to ``f``."""
        # Stride and kernel are ints and there is no self.shape attribute, so
        # both are printed as pairs, the way Conv2D.printNet does.
        stride = [self.stride, self.stride]
        kernel = [self.kernel_size, self.kernel_size]
        prev = list(self.prev)
        print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format(stride, kernel, prev[1:] + prev[:1]), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
class AvgPool2D(InferModule):
    """2-d average pooling with a square window and a fixed padding of 1."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        """Store the window and stride (defaulting to ``kernel_size``); return the output shape."""
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        out_size = getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding = 1)
        return out_size

    def forward(self, x, **kargs):
        """Pool ``x``; inputs whose trailing dimensions multiply to 1 are returned as is."""
        if h.product(x.size()[2:]) == 1:
            return x
        return x.avg_pool2d(kernel_size = self.kernel_size, stride = self.stride, padding = 1)

    def printNet(self, f):
        """Write the pooling description to ``f``."""
        # Stride and kernel are ints and there is no self.shape attribute, so
        # both are printed as pairs, the way Conv2D.printNet does.
        stride = [self.stride, self.stride]
        kernel = [self.kernel_size, self.kernel_size]
        prev = list(self.prev)
        print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format(stride, kernel, prev[1:] + prev[:1]), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
class AdaptiveAvgPool2D(InferModule):
    """Adaptive average pooling to a fixed output size."""

    def init(self, in_shape, out_shape, **kargs):
        """Store the target size; the result keeps the first input dimension."""
        self.prev = in_shape
        self.out_shape = list(out_shape)
        channels = in_shape[0]
        return [channels] + self.out_shape

    def forward(self, x, **kargs):
        return x.adaptive_avg_pool2d(self.out_shape)

    def printNet(self, f):
        """Write the pooling description to ``f``."""
        target = list(self.out_shape)
        rotated_input = list(self.prev[1:] + self.prev[:1])
        print("AdaptiveAvgPool2D out_Shape={} input_shape={}".format(target, rotated_input), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
class Normalize(InferModule):
    """Subtract a per-entry mean and multiply by the reciprocal of ``std``."""

    def init(self, in_shape, mean, std, **kargs):
        """Keep the raw values for printing and the tensors used in ``forward``."""
        self.mean_v = mean
        self.std_v = std
        self.mean = h.dten(mean)
        # Stored as a reciprocal so that forward() can multiply.
        self.std = 1 / h.dten(std)
        return in_shape

    def forward(self, x, **kargs):
        """Expand mean and reciprocal std to the non-batch shape of ``x`` and apply them."""
        sample_shape = x.size()[1:]
        mean_ex = self.mean.view(self.mean.shape[0], 1, 1).expand(*sample_shape)
        std_ex = self.std.view(self.std.shape[0], 1, 1).expand(*sample_shape)
        centered = x - mean_ex
        return centered * std_ex

    def neuronCount(self):
        return 0

    def _summary(self):
        """Description line shared by printNet and showNet."""
        return "Normalize mean={} std={}".format(self.mean_v, self.std_v)

    def printNet(self, f):
        print(self._summary(), file = f)

    def showNet(self, t = ""):
        print(t + self._summary())
class Flatten(InferModule):
    """Flatten every non-batch dimension into one."""

    def init(self, in_shape, **kargs):
        """Return the flattened output shape as a one-element list.

        ``InferModule.infer`` wraps the result in ``list(...)``.  Returning the
        bare product would fail there if it is a scalar, so it is wrapped here.
        """
        return [h.product(in_shape)]

    def forward(self, x, **kargs):
        s = x.size()
        return x.view(s[0], h.product(s[1:]))

    def neuronCount(self):
        return 0
class BatchNorm(InferModule):
    """Batch normalisation with learnable ``gamma`` / ``beta`` of the input shape.

    Running statistics are kept as plain attributes on the instance.
    """

    def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs):
        """Create the affine parameters and the running-statistics state."""
        self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.beta = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.eps = eps
        self.track_running_stats = track_running_stats
        self.momentum = momentum

        self.running_mean = None
        self.running_var = None
        self.num_batches_tracked = 0
        return in_shape

    def reset_parameters(self):
        """Start from the identity transform: gamma = 1, beta = 0."""
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, x, **kargs):
        """Normalise ``x`` and update the running statistics while training.

        ``x`` is returned unchanged when the batch variance or the resulting
        scale is not finite / is NaN.
        """
        factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:
                    # Cumulative average over the batches seen so far.
                    factor = 1.0 / float(self.num_batches_tracked)
                else:
                    factor = self.momentum

        mean = x.vanillaTensorPart().detach().mean(dim=0)
        var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False)

        # var * 0 is NaN exactly where var is NaN or infinite.
        if torch.isnan(var * 0).any():
            return x

        if self.training:
            if self.running_mean is not None:
                self.running_mean = (1 - factor) * self.running_mean + factor * mean
            else:
                self.running_mean = mean

            if self.running_var is None:
                self.running_var = var
            else:
                kept = (1 - factor) * self.running_var
                fresh = factor * var
                self.running_var = kept + fresh

        if self.track_running_stats and self.running_mean is not None and self.running_var is not None:
            mean = self.running_mean
            var = self.running_var

        scale = 1 / (var + self.eps).sqrt()

        if torch.isnan(scale).any():
            print("Really shouldn't happen ever")
            return x
        return (x - mean) * scale * self.gamma + self.beta

    def neuronCount(self):
        return 0
class Unflatten2d(InferModule):
    """Reshape each sample to ``(channels, w, w)``."""

    def init(self, in_shape, w, **kargs):
        """Derive the channel count from the input size and the side length ``w``."""
        self.w = w
        total = h.product(in_shape)
        self.outChan = int(total / (w * w))
        return (self.outChan, self.w, self.w)

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, self.outChan, self.w, self.w)

    def neuronCount(self):
        return 0
class View(InferModule):
    """Reshape each sample to a fixed shape with the same number of elements."""

    def init(self, in_shape, out_shape, **kargs):
        """Check that the element counts agree and return ``out_shape``."""
        n_in = h.product(in_shape)
        n_out = h.product(out_shape)
        assert(n_in == n_out)
        return out_shape

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, *self.outShape)

    def neuronCount(self):
        return 0
class Seq(InferModule):
    """Sequential container that infers shapes layer by layer."""

    def init(self, in_shape, *layers, **kargs):
        """Register the layers and thread the shape through each ``infer`` call."""
        self.layers = layers
        # Wrapping the layers in nn.Sequential makes them sub-modules.
        self.net = nn.Sequential(*layers)
        self.prev = in_shape

        shape = in_shape
        for layer in layers:
            shape = layer.infer(shape, **kargs).outShape
        return shape

    def forward(self, x, **kargs):
        """Apply every layer in order, forwarding the keyword arguments."""
        out = x
        for layer in self.layers:
            out = layer(out, **kargs)
        return out

    def clip_norm(self):
        for layer in self.layers:
            layer.clip_norm()

    def regularize(self, p):
        return sum(layer.regularize(p) for layer in self.layers)

    def remove_norm(self):
        for layer in self.layers:
            layer.remove_norm()

    def printNet(self, f):
        for layer in self.layers:
            layer.printNet(f)

    def showNet(self, *args, **kargs):
        for layer in self.layers:
            layer.showNet(*args, **kargs)

    def neuronCount(self):
        return sum(layer.neuronCount() for layer in self.layers)

    def depth(self):
        return sum(layer.depth() for layer in self.layers)
def FFNN(layers, last_lin = False, last_zono = False, **kargs):
    """Build a feed-forward net with one ``Linear`` + activation per entry of ``layers``.

    With ``last_lin`` the final entry becomes a plain ``Linear`` without an
    activation; ``last_zono`` additionally places a ``CorrelateAll`` before it.
    """
    hidden = layers
    tail = []
    if last_lin:
        if last_zono:
            tail.append(CorrelateAll(only_train=False))
        tail.append(PrintActivation(activation = "Affine"))
        tail.append(Linear(layers[-1], **kargs))
        hidden = layers[:-1]

    blocks = [Seq(PrintActivation(**kargs), Linear(size, **kargs), activation(**kargs)) for size in hidden]
    return Seq(*(blocks + tail))
def Conv(*args, **kargs):
    """Return a ``Conv2D`` followed by an activation, both built from the same keywords."""
    conv = Conv2D(*args, **kargs)
    act = activation(**kargs)
    return Seq(conv, act)
def ConvTranspose(*args, **kargs):
    """Return a ``ConvTranspose2D`` followed by an activation, both built from the same keywords."""
    deconv = ConvTranspose2D(*args, **kargs)
    act = activation(**kargs)
    return Seq(deconv, act)
# Short alias for MaxPool2D.
MP = MaxPool2D
def LeNet(conv_layers, ly = [], bias = True, normal=False, **kargs):
    """Build convolutional layers from ``conv_layers`` followed by an optional ``FFNN``.

    Each entry of ``conv_layers`` is either a ready ``InferModule``, a tuple
    whose first element is a string (turned into ``MaxPool2D(*tp[1:])``), or a
    tuple ``(out_channels, kernel_size, ...)`` whose last element is used as
    the stride when the tuple has exactly four elements.
    """
    def transfer(tp):
        if isinstance(tp, InferModule):
            return tp
        if isinstance(tp[0], str):
            return MaxPool2D(*tp[1:])
        stride = tp[-1] if len(tp) == 4 else 1
        return Conv(out_channels = tp[0], kernel_size = tp[1], stride = stride, bias=bias, normal=normal, **kargs)

    conv = list(map(transfer, conv_layers))
    if len(ly) > 0:
        return Seq(*conv, FFNN(ly, **kargs, bias=bias))
    return Seq(*conv)
def InvLeNet(ly, w, conv_layers, bias = True, normal=False, **kargs):
    """Build an ``FFNN``, an ``Unflatten2d(w)`` and a stack of bias-free ``ConvTranspose`` layers.

    Each entry of ``conv_layers`` is read as
    ``(out_channels, kernel_size, stride, padding, out_padding)``.
    Extra keyword arguments are accepted but not used.
    """
    def transfer(tp):
        out_channels, kernel_size, stride, padding, out_padding = tp[0], tp[1], tp[2], tp[3], tp[4]
        return ConvTranspose(out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = padding, out_padding = out_padding, bias=False, normal=normal)

    head = FFNN(ly, bias=bias)
    unflatten = Unflatten2d(w)
    deconvs = [transfer(s) for s in conv_layers]
    return Seq(head, unflatten, *deconvs)
class FromByteImg(InferModule):
    """Convert the input with ``to_dtype()`` and divide it by 256."""

    def init(self, in_shape, **kargs):
        # Shape is unchanged.
        return in_shape

    def forward(self, x, **kargs):
        converted = x.to_dtype()
        return converted / 256.

    def neuronCount(self):
        return 0
class Skip(InferModule):
    """Run two nets on the same input and concatenate their outputs along dim 1."""

    def init(self, in_shape, net1, net2, **kargs):
        """Infer both branches; their shapes must agree after the first entry."""
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)

        shape1, shape2 = net1.outShape, net2.outShape
        assert(shape1[1:] == shape2[1:])
        return [shape1[0] + shape2[0]] + shape1[1:]

    def forward(self, x, **kargs):
        first = self.net1(x, **kargs)
        second = self.net2(x, **kargs)
        return first.cat(second, dim=1)

    def regularize(self, p):
        return self.net1.regularize(p) + self.net2.regularize(p)

    def clip_norm(self):
        for net in (self.net1, self.net2):
            net.clip_norm()

    def remove_norm(self):
        for net in (self.net1, self.net2):
            net.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def printNet(self, f):
        """Write both branches to ``f``, each under its own header line."""
        print("SkipNet1", file=f)
        self.net1.printNet(f)
        print("SkipNet2", file=f)
        self.net2.printNet(f)
        print("SkipCat dim=1", file=f)

    def showNet(self, t = ""):
        """Print both branches, indenting them by one extra space."""
        inner = " " + t
        print(t+"SkipNet1")
        self.net1.showNet(inner)
        print(t+"SkipNet2")
        self.net2.showNet(inner)
        print(t+"SkipCat dim=1")
class ParSum(InferModule):
    """Run two nets on the same input and combine the results with ``x.addPar``.

    NOTE(review): unlike ``Skip`` and ``Seq`` this class does not override
    ``regularize``, so the inherited implementation is used - confirm that the
    branches are meant to be left out of the penalty.
    """

    def init(self, in_shape, net1, net2, **kargs):
        """Infer both branches; their output shapes must be equal."""
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)

        shape1, shape2 = net1.outShape, net2.outShape
        assert(shape1 == shape2)
        return shape1

    def forward(self, x, **kargs):
        first = self.net1(x, **kargs)
        second = self.net2(x, **kargs)
        return x.addPar(first, second)

    def clip_norm(self):
        for net in (self.net1, self.net2):
            net.clip_norm()

    def remove_norm(self):
        for net in (self.net1, self.net2):
            net.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def depth(self):
        return max(self.net1.depth(), self.net2.depth())

    def printNet(self, f):
        """Write both branches to ``f``, each under its own header line."""
        print("ParNet1", file=f)
        self.net1.printNet(f)
        print("ParNet2", file=f)
        self.net2.printNet(f)
        print("ParCat dim=1", file=f)

    def showNet(self, t = ""):
        """Print both branches, indenting them by one extra space."""
        inner = " " + t
        print(t + "ParNet1")
        self.net1.showNet(inner)
        print(t + "ParNet2")
        self.net2.showNet(inner)
        print(t + "ParSum")
class ToZono(Identity):
    """Apply the ``'hybrid_to_zono'`` leaf operation to the input.

    With ``only_train`` the operation is skipped outside training mode.
    Subclasses customise the operation by overriding ``abstract_forward``.
    """

    def init(self, in_shape, customRelu = None, only_train = False, **kargs):
        self.customRelu = customRelu
        self.only_train = only_train
        return in_shape

    def forward(self, x, **kargs):
        """Dispatch to ``abstract_forward`` unless disabled for evaluation."""
        if self.only_train and not self.training:
            return x
        return self.abstract_forward(x, **kargs)

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono', customRelu = self.customRelu)

    def showNet(self, t = ""):
        label = self.__class__.__name__
        print(t + label + " only_train=" + str(self.only_train))
class CorrelateAll(ToZono):
    """``ToZono`` variant that passes ``correlate=True`` to ``'hybrid_to_zono'``."""

    def abstract_forward(self, x, **kargs):
        relu = self.customRelu
        return x.abstractApplyLeaf('hybrid_to_zono', correlate=True, customRelu = relu)
class ToHZono(ToZono):
    """``ToZono`` variant that applies the ``'zono_to_hybrid'`` leaf operation."""

    def abstract_forward(self, x, **kargs):
        relu = self.customRelu
        return x.abstractApplyLeaf('zono_to_hybrid', customRelu = relu)
class Concretize(ToZono):
    """Apply the ``'concretize'`` leaf operation; by default only while training."""

    def init(self, in_shape, only_train = True, **kargs):
        # Shape is unchanged; only the training-only flag is stored.
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        operation = 'concretize'
        return x.abstractApplyLeaf(operation)
class CorrRand(Concretize):
    """Apply the ``"stochasticCorrelate"`` leaf operation with ``num_correlate``."""

    def init(self, in_shape, num_correlate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_correlate = num_correlate
        return in_shape

    def abstract_forward(self, x, **kargs):
        # **kargs is accepted (and ignored) because ToZono.forward forwards
        # its keyword arguments to abstract_forward.
        return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " num_correlate="+ str(self.num_correlate))
class CorrMaxK(CorrRand):
    """Apply the ``"correlateMaxK"`` leaf operation with ``num_correlate``."""

    def abstract_forward(self, x, **kargs):
        # **kargs is accepted (and ignored) because ToZono.forward forwards
        # its keyword arguments to abstract_forward.
        return x.abstractApplyLeaf("correlateMaxK", self.num_correlate)
class CorrMaxPool2D(Concretize):
    """Apply the ``"correlateMaxPool"`` leaf operation with stride equal to the kernel size."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.head_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type
        return in_shape

    def abstract_forward(self, x, **kargs):
        # **kargs is accepted (and ignored) because ToZono.forward forwards
        # its keyword arguments to abstract_forward.
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +str(self.max_type))
class CorrMaxPool3D(Concretize):
    """Apply ``"correlateMaxPool"`` with ``F.max_pool3d`` and stride equal to the kernel size."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.only_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type
        return in_shape

    def abstract_forward(self, x, **kargs):
        # **kargs is accepted (and ignored) because ToZono.forward forwards
        # its keyword arguments to abstract_forward.
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type, max_pool = F.max_pool3d)

    def showNet(self, t = ""):
        # str() is required: max_type is not a string by default, so plain
        # concatenation raises.  CorrMaxPool2D.showNet already converts it.
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" + str(self.max_type))
class CorrFix(Concretize):
    """Apply the ``"correlate"`` leaf operation on evenly spaced flat indices."""

    def init(self,in_shape, k, only_train = True, **kargs):
        self.k = k
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        """Pick every ``ceil(size / k)``-th flat index and correlate on those.

        ``**kargs`` is accepted (and ignored) because ``ToZono.forward``
        forwards its keyword arguments to ``abstract_forward``.
        """
        sz = x.size()
        # Kept from the original as a sketch "for more control in the future"
        # (self.dims does not exist yet):
        #   indxs_1 = torch.arange(start = 0, end = sz[1], step = math.ceil(sz[1] / self.dims[1]) )
        #   indxs_2 = torch.arange(start = 0, end = sz[2], step = math.ceil(sz[2] / self.dims[2]) )
        #   indxs_3 = torch.arange(start = 0, end = sz[3], step = math.ceil(sz[3] / self.dims[3]) )
        #   indxs = torch.stack(torch.meshgrid((indxs_1,indxs_2,indxs_3)), dim=3).view(-1,3)
        szm = h.product(sz[1:])
        indxs = torch.arange(start = 0, end = szm, step = math.ceil(szm / self.k))
        # The same index row is repeated for every element of the batch.
        indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0])

        return x.abstractApplyLeaf("correlate", indxs)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.k))
class DecorrRand(Concretize):
    """Apply the ``"stochasticDecorrelate"`` leaf operation with ``num_decorrelate``."""

    def init(self, in_shape, num_decorrelate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        return in_shape

    def abstract_forward(self, x, **kargs):
        # **kargs is accepted (and ignored) because ToZono.forward forwards
        # its keyword arguments to abstract_forward.
        return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate)
class DecorrMin(Concretize):
    """Apply the ``"decorrelateMin"`` leaf operation with ``num_decorrelate``."""

    def init(self, in_shape, num_decorrelate, only_train = True, num_to_keep = False, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        self.num_to_keep = num_to_keep
        return in_shape

    def abstract_forward(self, x, **kargs):
        # **kargs is accepted (and ignored) because ToZono.forward forwards
        # its keyword arguments to abstract_forward.
        return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep = self.num_to_keep)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep) )
class DeepLoss(ToZono):
    """Tag non-point inputs with an ``MLoss`` that mixes in ``deep_loss`` with weight ``bw``.

    The layer is always training-only.
    """

    def init(self, in_shape, bw = 0.01, act = F.relu, **kargs):
        self.only_train = True
        self.bw = S.Const.initConst(bw)
        self.act = act
        return in_shape

    def abstract_forward(self, x, **kargs):
        """Return points unchanged; wrap everything else in ``ai.TaggedDomain``."""
        if x.isPoint():
            return x
        tag = self.MLoss(self, x)
        return ai.TaggedDomain(x, tag)

    class MLoss():
        """Loss wrapper that remembers the layer and the value seen at that layer."""

        def __init__(self, obj, x):
            self.obj = obj
            self.x = x

        def loss(self, a, *args, lr = 1, time = 0, **kargs):
            """Blend ``a.loss(...)`` with ``x.deep_loss`` using the weight ``bw``.

            With ``bw <= 0`` the inner loss is returned unmodified.
            """
            bw = self.obj.bw.getVal(time = time)
            inner_lr = lr * (1 - bw)
            pre_loss = a.loss(*args, time = time, **kargs, lr = inner_lr)
            if bw <= 0.0:
                return pre_loss
            deep = self.x.deep_loss(act = self.obj.act)
            return (1 - bw) * pre_loss + bw * deep

    def showNet(self, t = ""):
        label = self.__class__.__name__
        print(t + label + " only_train=" + str(self.only_train) + " bw="+ str(self.bw) + " act=" + str(self.act) )
class IdentLoss(DeepLoss):
    """``DeepLoss`` variant whose ``abstract_forward`` leaves the input untouched."""

    def abstract_forward(self, x, **kargs):
        # No tagging: the input is passed through as is.
        return x
def SkipNet(net1, net2, ffnn, **kargs):
    """Return ``Skip(net1, net2)`` followed by ``FFNN(ffnn, **kargs)``."""
    branches = Skip(net1, net2)
    head = FFNN(ffnn, **kargs)
    return Seq(branches, head)
def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs):
    """Build a two-branch block joined by ``ParSum`` and followed by an activation.

    One branch is a single ``Conv2D``; the other is a ``Conv`` followed by a
    ``Conv2D``.  ``downsample`` switches the first kernel, the branch stride
    and the skip kernel from ``(3, 1, 1)`` to ``(4, 2, 2)``.
    """
    k_first, skip_stride, k_skip = (3, 1, 1) if not downsample else (4, 2, 2)

    blockA = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0, bias=bias, normal=True, **kargs)

    first = Conv(out_filters, kernel_size = k_first, stride = skip_stride, padding = 1, bias=bias, normal=True, **kargs)
    second = Conv2D(out_filters, kernel_size = k, stride = 1, padding = 1, bias=bias, normal=True, **kargs)
    blockB = Seq(first, second)

    return Seq(ParSum(blockA, blockB), activation(**kargs))
def BasicBlock(in_planes, planes, stride=1, bias = False, skip_net = False, **kargs):
    """Build a two-convolution block, optionally summed with a shortcut, plus activation.

    A 1x1 ``Conv2D`` shortcut is used when the stride is not 1 or the plane
    counts differ.  Otherwise an ``Identity`` shortcut is added unless
    ``skip_net`` is set.
    """
    first = Conv(planes, kernel_size = 3, stride = stride, padding = 1, bias=bias, normal=True, **kargs)
    second = Conv2D(planes, kernel_size = 3, stride = 1, padding = 1, bias=bias, normal=True, **kargs)
    block = Seq(first, second)

    needs_projection = stride != 1 or in_planes != planes
    if needs_projection:
        shortcut = Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs)
        block = ParSum(block, shortcut)
    elif not skip_net:
        block = ParSum(block, Identity())

    return Seq(block, activation(**kargs))
def ResNet(blocksList, extra = [], bias = False, **kargs):
    """Build a stem ``Conv`` followed by groups of ``BasicBlock`` layers.

    ``blocksList`` gives the number of blocks per group.  The first group
    starts with stride 1 and every later group with stride 2; the remaining
    blocks of a group use stride 1.  The plane count doubles after each group.
    Each ``(module, index)`` pair in ``extra`` appends ``module`` after the
    block at ``index``.
    """
    layers = []
    in_planes = 64
    planes = 64

    for group_index, num_blocks in enumerate(blocksList):
        lead_stride = 1 if group_index == 0 else 2
        group_strides = [lead_stride] + [1] * (num_blocks - 1)
        for block_stride in group_strides:
            layers.append(BasicBlock(in_planes, planes, block_stride, bias = bias, **kargs))
            in_planes = planes
        planes *= 2

    print("RESlayers: ", len(layers))

    for e, l in extra:
        layers[l] = Seq(layers[l], e)

    stem = Conv(64, kernel_size=3, stride=1, padding = 1, bias=bias, normal=True, printShape=True)
    return Seq(stem, *layers)
def DenseNet(growthRate, depth, reduction, num_classes, bottleneck = True):
    """Build three dense stages separated by two transitions, then a classifier head.

    Each stage has ``(depth - 4) // 3`` blocks, halved when ``bottleneck`` is
    set.  Each transition reduces the running channel count by ``reduction``.
    """
    def Bottleneck(growthRate):
        interChannels = 4 * growthRate
        body = Seq(
            ReLU(),
            Conv2D(interChannels, kernel_size=1, bias=True, ibp_init = True),
            ReLU(),
            Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True),
        )
        return Skip(Identity(), body)

    def SingleLayer(growthRate):
        body = Seq(
            ReLU(),
            Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True),
        )
        return Skip(Identity(), body)

    def Transition(nOutChannels):
        return Seq(
            ReLU(),
            Conv2D(nOutChannels, kernel_size = 1, bias = True, ibp_init = True),
            AvgPool2D(kernel_size=2),
        )

    def make_dense(growthRate, nDenseBlocks, bottleneck):
        make_block = Bottleneck if bottleneck else SingleLayer
        return Seq(*[make_block(growthRate) for _ in range(nDenseBlocks)])

    blocks_per_stage = (depth - 4) // 3
    if bottleneck:
        blocks_per_stage //= 2
    stage_growth = blocks_per_stage * growthRate

    channels = 2 * growthRate
    conv1 = Conv2D(channels, kernel_size=3, padding=1, bias=True, ibp_init = True)

    dense1 = make_dense(growthRate, blocks_per_stage, bottleneck)
    channels += stage_growth
    trans1_out = int(math.floor(channels * reduction))
    trans1 = Transition(trans1_out)

    channels = trans1_out
    dense2 = make_dense(growthRate, blocks_per_stage, bottleneck)
    channels += stage_growth
    trans2_out = int(math.floor(channels * reduction))
    trans2 = Transition(trans2_out)

    dense3 = make_dense(growthRate, blocks_per_stage, bottleneck)

    return Seq(
        conv1, dense1, trans1, dense2, trans2, dense3,
        ReLU(),
        AvgPool2D(kernel_size=8),
        CorrelateAll(only_train=False, ignore_point = True),
        Linear(num_classes, ibp_init = True),
    )