|
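"""Abstract domains for sound propagation through networks.

Provides HybridZonotope (a box part plus correlated zonotope error terms),
Zonotope, ListDomain (a product of domains), and TaggedDomain, together with
abstract transformers for ReLU, softplus, convolutions, pooling and the other
operations used by the surrounding code.
"""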
import future |
|
import builtins |
|
import past |
|
import six |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import torch.autograd |
|
|
|
from functools import reduce |
|
|
|
try:
    from . import helpers as h
except ImportError:
    import helpers as h
|
|
|
|
|
|
|
def catNonNullErrors(op, ref_errs=None): |
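    """Lift a binary op over error-term tensors that may differ in term count.

    The returned function applies `op` directly when both inputs carry the same
    number of error terms (dim 0). Otherwise it applies `op` on the first `sz`
    shared terms (`sz` taken from `ref_errs` if given, else the smaller count)
    and pads each remaining block with zeros from the other side before
    concatenating the pieces along dim 0.
    """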
|
def doop(er1, er2): |
|
erS, erL = (er1, er2) |
|
sS, sL = (erS.size()[0], erL.size()[0]) |
|
|
|
if sS == sL: |
|
return op(erS,erL) |
|
|
|
if ref_errs is not None: |
|
sz = ref_errs.size()[0] |
|
else: |
|
sz = min(sS, sL) |
|
|
|
p1 = op(erS[:sz], erL[:sz]) |
|
erSrem = erS[sz:] |
|
        erLrem = erL[sz:]
|
p2 = op(erSrem, h.zeros(erSrem.shape)) |
|
p3 = op(h.zeros(erLrem.shape), erLrem) |
|
return torch.cat((p1,p2,p3), dim=0) |
|
return doop |
|
|
|
def creluBoxy(dom): |
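    """ReLU transformer that boxes unstable neurons.

    With no error terms this reduces to an interval ReLU on head +- beta.
    Otherwise, neurons whose concrete bounds straddle zero are replaced by the
    box [0, mx] (new center and beta both mx / 2, error terms dropped), while
    stable neurons keep their affine form (positive) or are zeroed (negative).
    """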
|
if dom.errors is None: |
|
if dom.beta is None: |
|
return dom.new(F.relu(dom.head), None, None) |
|
er = dom.beta |
|
mx = F.relu(dom.head + er) |
|
mn = F.relu(dom.head - er) |
|
return dom.new((mn + mx) / 2, (mx - mn) / 2 , None) |
|
|
|
aber = torch.abs(dom.errors) |
|
|
|
sm = torch.sum(aber, 0) |
|
|
|
if not dom.beta is None: |
|
sm += dom.beta |
|
|
|
mx = dom.head + sm |
|
mn = dom.head - sm |
|
|
|
should_box = mn.lt(0) * mx.gt(0) |
|
gtz = dom.head.gt(0).to_dtype() |
|
mx /= 2 |
|
newhead = h.ifThenElse(should_box, mx, gtz * dom.head) |
|
newbeta = h.ifThenElse(should_box, mx, gtz * (dom.beta if not dom.beta is None else 0)) |
|
newerr = (1 - should_box.to_dtype()) * gtz * dom.errors |
|
|
|
return dom.new(newhead, newbeta , newerr) |
|
|
|
|
|
def creluBoxySound(dom): |
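    """Same boxing ReLU as creluBoxy, but pads the new beta with 2e-6.

    The small constant guards the over-approximation against floating point
    rounding so the result stays sound.
    """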
|
if dom.errors is None: |
|
if dom.beta is None: |
|
return dom.new(F.relu(dom.head), None, None) |
|
er = dom.beta |
|
mx = F.relu(dom.head + er) |
|
mn = F.relu(dom.head - er) |
|
return dom.new((mn + mx) / 2, (mx - mn) / 2 + 2e-6 , None) |
|
|
|
aber = torch.abs(dom.errors) |
|
|
|
sm = torch.sum(aber, 0) |
|
|
|
if not dom.beta is None: |
|
sm += dom.beta |
|
|
|
mx = dom.head + sm |
|
mn = dom.head - sm |
|
|
|
should_box = mn.lt(0) * mx.gt(0) |
|
gtz = dom.head.gt(0).to_dtype() |
|
mx /= 2 |
|
newhead = h.ifThenElse(should_box, mx, gtz * dom.head) |
|
newbeta = h.ifThenElse(should_box, mx + 2e-6, gtz * (dom.beta if not dom.beta is None else 0)) |
|
newerr = (1 - should_box.to_dtype()) * gtz * dom.errors |
|
|
|
return dom.new(newhead, newbeta, newerr) |
|
|
|
|
|
def creluSwitch(dom): |
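    """ReLU transformer that picks, per neuron, between boxing and shifting.

    For an unstable neuron with bounds [mn, mx]: if -mn > mx the zonotope is
    boxed to [0, mx] and its error terms are dropped; otherwise the center is
    shifted up by -mn / 2 and beta is widened by the same amount, keeping the
    error terms.
    """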
|
if dom.errors is None: |
|
if dom.beta is None: |
|
return dom.new(F.relu(dom.head), None, None) |
|
er = dom.beta |
|
mx = F.relu(dom.head + er) |
|
mn = F.relu(dom.head - er) |
|
return dom.new((mn + mx) / 2, (mx - mn) / 2 , None) |
|
|
|
aber = torch.abs(dom.errors) |
|
|
|
sm = torch.sum(aber, 0) |
|
|
|
if not dom.beta is None: |
|
sm += dom.beta |
|
|
|
mn = dom.head - sm |
|
mx = sm |
|
mx += dom.head |
|
|
|
should_box = mn.lt(0) * mx.gt(0) |
|
gtz = dom.head.gt(0) |
|
|
|
mn.neg_() |
|
should_boxer = mn.gt(mx) |
|
|
|
mn /= 2 |
|
newhead = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, dom.head + mn ), gtz.to_dtype() * dom.head) |
|
zbet = dom.beta if not dom.beta is None else 0 |
|
newbeta = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, mn + zbet), gtz.to_dtype() * zbet) |
|
newerr = h.ifThenElseL(should_box, 1 - should_boxer, gtz).to_dtype() * dom.errors |
|
|
|
return dom.new(newhead, newbeta , newerr) |
|
|
|
def creluSmooth(dom): |
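    """ReLU transformer that blends the shift and box approximations.

    The blend weight t = relu(-mn) / (relu(mx) + relu(-mn) + eps) interpolates
    between keeping the error terms (shift) and boxing to [0, mx]; neurons with
    mx <= 0 are zeroed outright.
    """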
|
if dom.errors is None: |
|
if dom.beta is None: |
|
return dom.new(F.relu(dom.head), None, None) |
|
er = dom.beta |
|
mx = F.relu(dom.head + er) |
|
mn = F.relu(dom.head - er) |
|
return dom.new((mn + mx) / 2, (mx - mn) / 2 , None) |
|
|
|
aber = torch.abs(dom.errors) |
|
|
|
sm = torch.sum(aber, 0) |
|
|
|
if not dom.beta is None: |
|
sm += dom.beta |
|
|
|
mn = dom.head - sm |
|
mx = sm |
|
mx += dom.head |
|
|
|
|
|
nmn = F.relu(-1 * mn) |
|
|
|
zbet = (dom.beta if not dom.beta is None else 0) |
|
newheadS = dom.head + nmn / 2 |
|
newbetaS = zbet + nmn / 2 |
|
newerrS = dom.errors |
|
|
|
mmx = F.relu(mx) |
|
|
|
newheadB = mmx / 2 |
|
newbetaB = newheadB |
|
newerrB = 0 |
|
|
|
eps = 0.0001 |
|
t = nmn / (mmx + nmn + eps) |
|
|
|
shouldnt_zero = mx.gt(0).to_dtype() |
|
|
|
newhead = shouldnt_zero * ( (1 - t) * newheadS + t * newheadB) |
|
newbeta = shouldnt_zero * ( (1 - t) * newbetaS + t * newbetaB) |
|
newerr = shouldnt_zero * ( (1 - t) * newerrS + t * newerrB) |
|
|
|
return dom.new(newhead, newbeta , newerr) |
|
|
|
|
|
def creluNIPS(dom): |
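    """Zonotope ReLU transformer in the style of the NIPS'18 relaxation.

    Unstable neurons are relaxed with the linear map lam * x + mu, where
    lam = mx / (mx - mn) and mu = -lam * mn / 2: existing error terms are
    scaled by lam and mu is added to both the center and beta. Stable neurons
    pass through unchanged (mn >= 0) or collapse to zero (mx <= 0).
    """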
|
if dom.errors is None: |
|
if dom.beta is None: |
|
return dom.new(F.relu(dom.head), None, None) |
|
er = dom.beta |
|
mx = F.relu(dom.head + er) |
|
mn = F.relu(dom.head - er) |
|
return dom.new((mn + mx) / 2, (mx - mn) / 2 , None) |
|
|
|
sm = torch.sum(torch.abs(dom.errors), 0) |
|
|
|
if not dom.beta is None: |
|
sm += dom.beta |
|
|
|
mn = dom.head - sm |
|
mx = dom.head + sm |
|
|
|
mngz = mn >= 0.0 |
|
|
|
zs = h.zeros(dom.head.shape) |
|
|
|
diff = mx - mn |
|
|
|
lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs) |
|
mu = lam * mn * (-0.5) |
|
|
|
betaz = zs if dom.beta is None else dom.beta |
|
|
|
newhead = torch.where(mngz, dom.head , lam * dom.head + mu) |
|
    mngz = mngz | (diff <= 0.0)
|
newbeta = torch.where(mngz, betaz , lam * betaz + mu ) |
|
newerr = torch.where(mngz, dom.errors, lam * dom.errors ) |
|
return dom.new(newhead, newbeta, newerr) |
|
|
|
|
|
|
|
|
|
class MaxTypes: |
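    """Per-neuron scoring strategies used when choosing what to correlate or pool.

    `ub` scores by the concrete upper bound, `only_beta` by the box radius
    (zeros when beta is absent), and `head_beta` by center plus box radius.
    """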
|
|
|
@staticmethod |
|
def ub(x): |
|
return x.ub() |
|
|
|
@staticmethod |
|
def only_beta(x): |
|
return x.beta if x.beta is not None else x.head * 0 |
|
|
|
@staticmethod |
|
def head_beta(x): |
|
return MaxTypes.only_beta(x) + x.head |
|
|
|
class HybridZonotope: |
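    """Abstract element of the form head + beta * [-1, 1] + sum_i errors[i] * eps_i.

    `head` is the center (batch x ...), `beta` an optional per-coordinate box
    radius of the same shape, and `errors` an optional tensor of shape
    (num_error_terms, batch, ...) holding correlated zonotope error terms with
    eps_i in [-1, 1]. `customRelu` selects the ReLU transformer.

    Illustrative sketch only (values chosen arbitrarily; the surrounding code
    normally builds these through the helpers module):

        head = torch.zeros(1, 3)          # batch of one, three neurons
        beta = torch.full((1, 3), 0.1)    # independent box noise of radius 0.1
        dom = HybridZonotope(head, beta, None)
        dom.ub()                          # -> tensor of 0.1s
        dom.lb()                          # -> tensor of -0.1s
    """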
|
|
|
def isSafe(self, target): |
|
od,_ = torch.min(h.preDomRes(self,target).lb(), 1) |
|
return od.gt(0.0).long() |
|
|
|
def isPoint(self): |
|
return False |
|
|
|
def labels(self): |
|
target = torch.max(self.ub(), 1)[1] |
|
l = list(h.preDomRes(self,target).lb()[0]) |
|
return [target.item()] + [ i for i,v in zip(range(len(l)), l) if v <= 0] |
|
|
|
def relu(self): |
|
return self.customRelu(self) |
|
|
|
def __init__(self, head, beta, errors, customRelu = creluBoxy, **kargs): |
|
self.head = head |
|
self.errors = errors |
|
self.beta = beta |
|
self.customRelu = creluBoxy if customRelu is None else customRelu |
|
|
|
def new(self, *args, customRelu = None, **kargs): |
|
return self.__class__(*args, **kargs, customRelu = self.customRelu if customRelu is None else customRelu).checkSizes() |
|
|
|
def zono_to_hybrid(self, *args, **kargs): |
|
return self.new(self.head, self.beta, self.errors, **kargs) |
|
|
|
def hybrid_to_zono(self, *args, correlate=True, customRelu = None, **kargs): |
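        """Convert to a pure Zonotope, optionally promoting beta into error terms.

        With `correlate=True` every beta entry becomes its own fresh error term
        (built from the h.getEi basis); otherwise beta is passed through as is.
        """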
|
beta = self.beta |
|
errors = self.errors |
|
if correlate and beta is not None: |
|
batches = beta.shape[0] |
|
num_elem = h.product(beta.shape[1:]) |
|
ei = h.getEi(batches, num_elem) |
|
|
|
if len(beta.shape) > 2: |
|
ei = ei.contiguous().view(num_elem, *beta.shape) |
|
err = ei * beta |
|
errors = torch.cat((err, errors), dim=0) if errors is not None else err |
|
beta = None |
|
|
|
        return Zonotope(self.head, beta, errors if errors is not None else (self.beta * 0).unsqueeze(0), customRelu = self.customRelu if customRelu is None else customRelu)
|
|
|
|
|
|
|
def abstractApplyLeaf(self, foo, *args, **kargs): |
|
return getattr(self, foo)(*args, **kargs) |
|
|
|
def decorrelate(self, cc_indx_batch_err): |
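        """Keep only the error terms indexed by `cc_indx_batch_err`, per batch element.

        Every other error term is folded into beta by its absolute magnitude,
        yielding a coarser element with fewer correlated terms.
        """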
|
if self.errors is None: |
|
return self |
|
|
|
batch_size = self.head.shape[0] |
|
num_error_terms = self.errors.shape[0] |
|
|
|
|
|
|
|
beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta |
|
errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors |
|
|
|
inds_i = torch.arange(self.head.shape[0], device=h.device).unsqueeze(1).long() |
|
errors = errors.to_dtype().permute(1,0, *list(range(len(self.errors.shape)))[2:]) |
|
|
|
sm = errors.clone() |
|
sm[inds_i, cc_indx_batch_err] = 0 |
|
|
|
beta = beta.to_dtype() + sm.abs().sum(dim=1) |
|
|
|
errors = errors[inds_i, cc_indx_batch_err] |
|
errors = errors.permute(1,0, *list(range(len(self.errors.shape)))[2:]).contiguous() |
|
return self.new(self.head, beta, errors) |
|
|
|
def dummyDecorrelate(self, num_decorrelate): |
|
if num_decorrelate == 0 or self.errors is None: |
|
return self |
|
elif num_decorrelate >= self.errors.shape[0]: |
|
beta = self.beta |
|
if self.errors is not None: |
|
errs = self.errors.abs().sum(dim=0) |
|
if beta is None: |
|
beta = errs |
|
else: |
|
beta += errs |
|
return self.new(self.head, beta, None) |
|
return None |
|
|
|
def stochasticDecorrelate(self, num_decorrelate, choices = None, num_to_keep=False): |
|
dummy = self.dummyDecorrelate(num_decorrelate) |
|
if dummy is not None: |
|
return dummy |
|
num_error_terms = self.errors.shape[0] |
|
batch_size = self.head.shape[0] |
|
|
|
ucc_mask = h.ones([batch_size, self.errors.shape[0]]).long() |
|
cc_indx_batch_err = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_decorrelate if num_to_keep else num_error_terms - num_decorrelate, replacement=False)) if choices is None else choices |
|
return self.decorrelate(cc_indx_batch_err) |
|
|
|
def decorrelateMin(self, num_decorrelate, num_to_keep=False): |
|
dummy = self.dummyDecorrelate(num_decorrelate) |
|
if dummy is not None: |
|
return dummy |
|
|
|
num_error_terms = self.errors.shape[0] |
|
batch_size = self.head.shape[0] |
|
|
|
error_sum_b_e = self.errors.abs().view(self.errors.shape[0], batch_size, -1).sum(dim=2).permute(1,0) |
|
cc_indx_batch_err = error_sum_b_e.topk(num_decorrelate if num_to_keep else num_error_terms - num_decorrelate)[1] |
|
return self.decorrelate(cc_indx_batch_err) |
|
|
|
def correlate(self, cc_indx_batch_beta): |
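        """Promote the box radius at selected coordinates into fresh shared error terms.

        `cc_indx_batch_beta` holds, per batch element, the flattened coordinate
        indices whose beta entries become new error terms and are zeroed in beta.
        """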
|
num_correlate = h.product(cc_indx_batch_beta.shape[1:]) |
|
|
|
beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta |
|
errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors |
|
|
|
batch_size = beta.shape[0] |
|
new_errors = h.zeros([num_correlate] + list(self.head.shape)).to_dtype() |
|
|
|
inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long() |
|
|
|
nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long() |
|
|
|
new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(batch_size, num_correlate, -1) |
|
new_errors[inds_i, nc.unsqueeze(0).expand([batch_size]+list(nc.shape)).squeeze(2), cc_indx_batch_beta] = beta.view(batch_size,-1)[inds_i, cc_indx_batch_beta] |
|
|
|
new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(num_correlate, batch_size, *beta.shape[1:]) |
|
errors = torch.cat((errors, new_errors), dim=0) |
|
|
|
beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0 |
|
|
|
return self.new(self.head, beta, errors) |
|
|
|
def stochasticCorrelate(self, num_correlate, choices = None): |
|
if num_correlate == 0: |
|
return self |
|
|
|
domshape = self.head.shape |
|
batch_size = domshape[0] |
|
num_pixs = h.product(domshape[1:]) |
|
num_correlate = min(num_correlate, num_pixs) |
|
ucc_mask = h.ones([batch_size, num_pixs ]).long() |
|
|
|
cc_indx_batch_beta = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_correlate, replacement=False)) if choices is None else choices |
|
return self.correlate(cc_indx_batch_beta) |
|
|
|
|
|
def correlateMaxK(self, num_correlate): |
|
if num_correlate == 0: |
|
return self |
|
|
|
domshape = self.head.shape |
|
batch_size = domshape[0] |
|
num_pixs = h.product(domshape[1:]) |
|
num_correlate = min(num_correlate, num_pixs) |
|
|
|
concrete_max_image = self.ub().view(batch_size, -1) |
|
|
|
cc_indx_batch_beta = concrete_max_image.topk(num_correlate)[1] |
|
return self.correlate(cc_indx_batch_beta) |
|
|
|
def correlateMaxPool(self, *args, max_type = MaxTypes.ub , max_pool = F.max_pool2d, **kargs): |
|
domshape = self.head.shape |
|
batch_size = domshape[0] |
|
num_pixs = h.product(domshape[1:]) |
|
|
|
concrete_max_image = max_type(self) |
|
|
|
cc_indx_batch_beta = max_pool(concrete_max_image, *args, return_indices=True, **kargs)[1].view(batch_size, -1) |
|
|
|
return self.correlate(cc_indx_batch_beta) |
|
|
|
    def checkSizes(self):
        if self.errors is not None:
            if not self.errors.size()[1:] == self.head.size():
                raise Exception("Such bad sizes on error:", self.errors.shape, " head:", self.head.shape)
            if torch.isnan(self.errors).any():
                raise Exception("Such nan in errors")
        if self.beta is not None:
            if not self.beta.size() == self.head.size():
                raise Exception("Such bad sizes on beta")

            if torch.isnan(self.beta).any():
                raise Exception("Such nan in beta")
            if self.beta.lt(0.0).any():
                self.beta = self.beta.abs()

        return self
|
|
|
def __mul__(self, flt): |
|
return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt) |
|
|
|
def __truediv__(self, flt): |
|
flt = 1. / flt |
|
return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt) |
|
|
|
def __add__(self, other): |
|
if isinstance(other, HybridZonotope): |
|
return self.new(self.head + other.head, h.msum(self.beta, other.beta, lambda a,b: a + b), h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a + b))) |
|
else: |
|
|
|
return self.new(self.head + other, self.beta, self.errors) |
|
|
|
def addPar(self, a, b): |
|
return self.new(a.head + b.head, h.msum(a.beta, b.beta, lambda a,b: a + b), h.msum(a.errors, b.errors, catNonNullErrors(lambda a,b: a + b, self.errors))) |
|
|
|
def __sub__(self, other): |
|
if isinstance(other, HybridZonotope): |
|
return self.new(self.head - other.head |
|
, h.msum(self.beta, other.beta, lambda a,b: a + b) |
|
, h.msum(self.errors, None if other.errors is None else -other.errors, catNonNullErrors(lambda a,b: a + b))) |
|
else: |
|
|
|
return self.new(self.head - other, self.beta, self.errors) |
|
|
|
def bmm(self, other): |
|
hd = self.head.bmm(other) |
|
bet = None if self.beta is None else self.beta.bmm(other.abs()) |
|
|
|
if self.errors is None: |
|
er = None |
|
else: |
|
er = self.errors.matmul(other) |
|
return self.new(hd, bet, er) |
|
|
|
|
|
def getBeta(self): |
|
return self.head * 0 if self.beta is None else self.beta |
|
|
|
    def getErrors(self):
        return (self.head * 0).unsqueeze(0) if self.errors is None else self.errors
|
|
|
def merge(self, other, ref = None): |
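        """Over-approximate the union (join) of this element with `other`.

        Coordinates where one element's bounds already contain the other keep
        the containing element's representation; elsewhere a covering box is
        built around both. `ref`, when given, supplies the reference error
        terms used to align the two error tensors.
        """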
|
s_beta = self.getBeta() |
|
|
|
sbox_u = self.head + s_beta |
|
sbox_l = self.head - s_beta |
|
o_u = other.ub() |
|
o_l = other.lb() |
|
o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l) |
|
|
|
s_err_mx = self.errors.abs().sum(dim=0) |
|
|
|
if not isinstance(other, HybridZonotope): |
|
new_head = (self.head + other.center()) / 2 |
|
new_beta = torch.max(sbox_u + s_err_mx,o_u) - new_head |
|
return self.new(torch.where(o_in_s, self.head, new_head), torch.where(o_in_s, self.beta,new_beta), o_in_s.float() * self.errors) |
|
|
|
|
|
s_u = sbox_u + s_err_mx |
|
s_l = sbox_l - s_err_mx |
|
|
|
obox_u = o_u - other.head |
|
obox_l = o_l + other.head |
|
|
|
s_in_o = (s_u <= obox_u) & (s_l >= obox_l) |
|
|
|
|
|
new_head = (self.head + other.center()) / 2 |
|
new_beta = torch.max(sbox_u + self.getErrors().abs().sum(dim=0),o_u) - new_head |
|
|
|
return self.new(torch.where(o_in_s, self.head, torch.where(s_in_o, other.head, new_head)) |
|
, torch.where(o_in_s, s_beta,torch.where(s_in_o, other.getBeta(), new_beta)) |
|
, h.msum(o_in_s.float() * self.errors, s_in_o.float() * other.errors, catNonNullErrors(lambda a,b: a + b, ref_errs = ref.errors if ref is not None else ref))) |
|
|
|
|
|
    def conv(self, conv, weight, bias = None, **kargs):
        errs = self.errors  # local name; avoids shadowing the helpers alias `h`
        inter = errs if errs is None else errs.view(-1, *errs.size()[2:])
        hd = conv(self.head, weight, bias=bias, **kargs)
        res = errs if errs is None else conv(inter, weight, bias=None, **kargs)

        return self.new( hd
                       , None if self.beta is None else conv(self.beta, weight.abs(), bias = None, **kargs)
                       , errs if errs is None else res.view(errs.size()[0], errs.size()[1], *res.size()[1:]))
|
|
|
|
|
def conv1d(self, *args, **kargs): |
|
return self.conv(lambda x, *args, **kargs: x.conv1d(*args,**kargs), *args, **kargs) |
|
|
|
def conv2d(self, *args, **kargs): |
|
return self.conv(lambda x, *args, **kargs: x.conv2d(*args,**kargs), *args, **kargs) |
|
|
|
def conv3d(self, *args, **kargs): |
|
return self.conv(lambda x, *args, **kargs: x.conv3d(*args,**kargs), *args, **kargs) |
|
|
|
def conv_transpose1d(self, *args, **kargs): |
|
return self.conv(lambda x, *args, **kargs: x.conv_transpose1d(*args,**kargs), *args, **kargs) |
|
|
|
def conv_transpose2d(self, *args, **kargs): |
|
return self.conv(lambda x, *args, **kargs: x.conv_transpose2d(*args,**kargs), *args, **kargs) |
|
|
|
def conv_transpose3d(self, *args, **kargs): |
|
return self.conv(lambda x, *args, **kargs: x.conv_transpose3d(*args,**kargs), *args, **kargs) |
|
|
|
def matmul(self, other): |
|
return self.new(self.head.matmul(other), None if self.beta is None else self.beta.matmul(other.abs()), None if self.errors is None else self.errors.matmul(other)) |
|
|
|
def unsqueeze(self, i): |
|
return self.new(self.head.unsqueeze(i), None if self.beta is None else self.beta.unsqueeze(i), None if self.errors is None else self.errors.unsqueeze(i + 1)) |
|
|
|
def squeeze(self, dim): |
|
return self.new(self.head.squeeze(dim), |
|
None if self.beta is None else self.beta.squeeze(dim), |
|
None if self.errors is None else self.errors.squeeze(dim + 1 if dim >= 0 else dim)) |
|
|
|
def double(self): |
|
return self.new(self.head.double(), self.beta.double() if self.beta is not None else None, self.errors.double() if self.errors is not None else None) |
|
|
|
def float(self): |
|
return self.new(self.head.float(), self.beta.float() if self.beta is not None else None, self.errors.float() if self.errors is not None else None) |
|
|
|
def to_dtype(self): |
|
return self.new(self.head.to_dtype(), self.beta.to_dtype() if self.beta is not None else None, self.errors.to_dtype() if self.errors is not None else None) |
|
|
|
def sum(self, dim=1): |
|
return self.new(torch.sum(self.head,dim=dim), None if self.beta is None else torch.sum(self.beta,dim=dim), None if self.errors is None else torch.sum(self.errors, dim= dim + 1 if dim >= 0 else dim)) |
|
|
|
def view(self,*newshape): |
|
return self.new(self.head.view(*newshape), |
|
None if self.beta is None else self.beta.view(*newshape), |
|
None if self.errors is None else self.errors.view(self.errors.size()[0], *newshape)) |
|
|
|
def gather(self,dim, index): |
|
return self.new(self.head.gather(dim, index), |
|
None if self.beta is None else self.beta.gather(dim, index), |
|
None if self.errors is None else self.errors.gather(dim + 1, index.expand([self.errors.size()[0]] + list(index.size())))) |
|
|
|
def concretize(self): |
|
if self.errors is None: |
|
return self |
|
|
|
return self.new(self.head, torch.sum(self.concreteErrors().abs(),0), None) |
|
|
|
def cat(self,other, dim=0): |
|
return self.new(self.head.cat(other.head, dim = dim), |
|
h.msum(other.beta, self.beta, lambda a,b: a.cat(b, dim = dim)), |
|
h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a.cat(b, dim+1)))) |
|
|
|
|
|
def split(self, split_size, dim = 0): |
|
heads = list(self.head.split(split_size, dim)) |
|
betas = list(self.beta.split(split_size, dim)) if not self.beta is None else None |
|
errorss = list(self.errors.split(split_size, dim + 1)) if not self.errors is None else None |
|
|
|
def makeFromI(i): |
|
return self.new( heads[i], |
|
None if betas is None else betas[i], |
|
None if errorss is None else errorss[i]) |
|
return tuple(makeFromI(i) for i in range(len(heads))) |
|
|
|
|
|
|
|
def concreteErrors(self): |
|
if self.beta is None and self.errors is None: |
|
raise Exception("shouldn't have both beta and errors be none") |
|
if self.errors is None: |
|
return self.beta.unsqueeze(0) |
|
if self.beta is None: |
|
return self.errors |
|
return torch.cat([self.beta.unsqueeze(0),self.errors], dim=0) |
|
|
|
|
|
def applyMonotone(self, foo, *args, **kargs): |
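        """Apply a monotone elementwise function soundly via its interval bounds.

        The function is evaluated at the concrete lower and upper bounds and the
        result is stored as a box; if error terms were present, up to the same
        number are re-introduced afterwards with correlateMaxK.
        """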
|
if self.beta is None and self.errors is None: |
|
            return self.new(foo(self.head, *args, **kargs), None, None)
|
|
|
beta = self.concreteErrors().abs().sum(dim=0) |
|
|
|
tp = foo(self.head + beta, *args, **kargs) |
|
bt = foo(self.head - beta, *args, **kargs) |
|
|
|
new_hybrid = self.new((tp + bt) / 2, (tp - bt) / 2 , None) |
|
|
|
|
|
if self.errors is not None: |
|
return new_hybrid.correlateMaxK(self.errors.shape[0]) |
|
return new_hybrid |
|
|
|
def avg_pool2d(self, *args, **kargs): |
|
nhead = F.avg_pool2d(self.head, *args, **kargs) |
|
return self.new(nhead, |
|
None if self.beta is None else F.avg_pool2d(self.beta, *args, **kargs), |
|
None if self.errors is None else F.avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape)) |
|
|
|
def adaptive_avg_pool2d(self, *args, **kargs): |
|
nhead = F.adaptive_avg_pool2d(self.head, *args, **kargs) |
|
return self.new(nhead, |
|
None if self.beta is None else F.adaptive_avg_pool2d(self.beta, *args, **kargs), |
|
None if self.errors is None else F.adaptive_avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape)) |
|
|
|
def elu(self): |
|
return self.applyMonotone(F.elu) |
|
|
|
def selu(self): |
|
return self.applyMonotone(F.selu) |
|
|
|
def sigm(self): |
|
return self.applyMonotone(F.sigmoid) |
|
|
|
def softplus(self): |
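        """Sound softplus transformer.

        With no error terms this is an interval softplus. Otherwise the function
        is linearized around the center with slope softplus'(head); the
        linearization error appears to be bounded by the second-order remainder
        term evalG / 2, which is added to beta.
        """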
|
if self.errors is None: |
|
if self.beta is None: |
|
return self.new(F.softplus(self.head), None , None) |
|
tp = F.softplus(self.head + self.beta) |
|
bt = F.softplus(self.head - self.beta) |
|
return self.new((tp + bt) / 2, (tp - bt) / 2 , None) |
|
|
|
errors = self.concreteErrors() |
|
o = h.ones(self.head.size()) |
|
|
|
def sp(hd): |
|
return F.softplus(hd) |
|
def spp(hd): |
|
ehd = torch.exp(hd) |
|
return ehd.div(ehd + o) |
|
def sppp(hd): |
|
ehd = torch.exp(hd) |
|
md = ehd + o |
|
return ehd.div(md.mul(md)) |
|
|
|
fa = sp(self.head) |
|
fpa = spp(self.head) |
|
|
|
a = self.head |
|
|
|
k = torch.sum(errors.abs(), 0) |
|
|
|
def evalG(r): |
|
return r.mul(r).mul(sppp(a + r)) |
|
|
|
m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k))) |
|
m = h.ifThenElse( a.abs().lt(k), torch.max(m, torch.max(evalG(a), evalG(-a))), m) |
|
m /= 2 |
|
|
|
return self.new(fa, m if self.beta is None else m + self.beta.mul(fpa), None if self.errors is None else self.errors.mul(fpa)) |
|
|
|
def center(self): |
|
return self.head |
|
|
|
def vanillaTensorPart(self): |
|
return self.head |
|
|
|
def lb(self): |
|
return self.head - self.concreteErrors().abs().sum(dim=0) |
|
|
|
def ub(self): |
|
return self.head + self.concreteErrors().abs().sum(dim=0) |
|
|
|
def size(self): |
|
return self.head.size() |
|
|
|
def diameter(self): |
|
abal = torch.abs(self.concreteErrors()).transpose(0,1) |
|
return abal.sum(1).sum(1) |
|
|
|
def loss(self, target, **args): |
|
r = -h.preDomRes(self, target).lb() |
|
return F.softplus(r.max(1)[0]) |
|
|
|
def deep_loss(self, act = F.relu, *args, **kargs): |
|
batch_size = self.head.shape[0] |
|
inds = torch.arange(batch_size, device=h.device).unsqueeze(1).long() |
|
|
|
def dl(l,u): |
|
ls, lsi = torch.sort(l, dim=1) |
|
ls_u = u[inds, lsi] |
|
|
|
def slidingMax(a): |
|
k = a.shape[1] |
|
ml = a.min(dim=1)[0].unsqueeze(1) |
|
|
|
inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1) |
|
mpl = F.max_pool1d(inp.unsqueeze(1) , kernel_size = k, stride=1, padding = 0, return_indices=False).squeeze(1) |
|
return mpl[:,:-1] + ml |
|
|
|
return act(slidingMax(ls_u) - ls).sum(dim=1) |
|
|
|
l = self.lb().view(batch_size, -1) |
|
u = self.ub().view(batch_size, -1) |
|
return ( dl(l,u) + dl(-u,-l) ) / (2 * l.shape[1]) |
|
|
|
|
|
|
|
class Zonotope(HybridZonotope): |
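    """HybridZonotope specialization that keeps a pure zonotope representation.

    After the nonlinear transformers, applySuper converts any box part beta
    back into fresh error terms, so results never carry a separate beta.
    """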
|
def applySuper(self, ret): |
|
batches = ret.head.size()[0] |
|
num_elem = h.product(ret.head.size()[1:]) |
|
ei = h.getEi(batches, num_elem) |
|
|
|
if len(ret.head.size()) > 2: |
|
ei = ei.contiguous().view(num_elem, *ret.head.size()) |
|
|
|
ret.errors = torch.cat( (ret.errors, ei * ret.beta) ) if not ret.beta is None else ret.errors |
|
ret.beta = None |
|
return ret.checkSizes() |
|
|
|
def zono_to_hybrid(self, *args, customRelu = None, **kargs): |
|
return HybridZonotope(self.head, self.beta, self.errors, customRelu = self.customRelu if customRelu is None else customRelu) |
|
|
|
def hybrid_to_zono(self, *args, **kargs): |
|
return self.new(self.head, self.beta, self.errors, **kargs) |
|
|
|
def applyMonotone(self, *args, **kargs): |
|
return self.applySuper(super(Zonotope,self).applyMonotone(*args, **kargs)) |
|
|
|
def softplus(self): |
|
return self.applySuper(super(Zonotope,self).softplus()) |
|
|
|
def relu(self): |
|
return self.applySuper(super(Zonotope,self).relu()) |
|
|
|
def splitRelu(self, *args, **kargs): |
|
return [self.applySuper(a) for a in super(Zonotope, self).splitRelu(*args, **kargs)] |
|
|
|
|
|
def mysign(x): |
|
e = x.eq(0).to_dtype() |
|
r = x.sign().to_dtype() |
|
return r + e |
|
|
|
def mulIfEq(grad,out,target): |
|
pred = out.max(1, keepdim=True)[1] |
|
is_eq = pred.eq(target.view_as(pred)).to_dtype() |
|
is_eq = is_eq.view([-1] + [1 for _ in grad.size()[1:]]).expand_as(grad) |
|
return is_eq |
|
|
|
|
|
def stdLoss(out, target): |
|
if torch.__version__[0] == "0": |
|
return F.cross_entropy(out, target, reduce = False) |
|
else: |
|
return F.cross_entropy(out, target, reduction='none') |
|
|
|
|
|
|
|
class ListDomain(object): |
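    """Product domain: a list of abstract elements transformed in lockstep.

    Every operation is forwarded elementwise to the wrapped elements; losses
    are summed across them.
    """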
|
|
|
def __init__(self, al, *args, **kargs): |
|
self.al = list(al) |
|
|
|
def new(self, *args, **kargs): |
|
return self.__class__(*args, **kargs) |
|
|
|
    def isSafe(self,*args,**kargs):
        raise Exception("Domain Not Suitable For Testing")

    def labels(self):
        raise Exception("Domain Not Suitable For Testing")
|
|
|
def isPoint(self): |
|
return all(a.isPoint() for a in self.al) |
|
|
|
def __mul__(self, flt): |
|
return self.new(a.__mul__(flt) for a in self.al) |
|
|
|
def __truediv__(self, flt): |
|
return self.new(a.__truediv__(flt) for a in self.al) |
|
|
|
def __add__(self, other): |
|
if isinstance(other, ListDomain): |
|
return self.new(a.__add__(o) for a,o in zip(self.al, other.al)) |
|
else: |
|
return self.new(a.__add__(other) for a in self.al) |
|
|
|
def merge(self, other, ref = None): |
|
if ref is None: |
|
return self.new(a.merge(o) for a,o in zip(self.al,other.al) ) |
|
return self.new(a.merge(o, ref = r) for a,o,r in zip(self.al,other.al, ref.al)) |
|
|
|
def addPar(self, a, b): |
|
return self.new(s.addPar(av,bv) for s,av,bv in zip(self.al, a.al, b.al)) |
|
|
|
def __sub__(self, other): |
|
if isinstance(other, ListDomain): |
|
return self.new(a.__sub__(o) for a,o in zip(self.al, other.al)) |
|
else: |
|
return self.new(a.__sub__(other) for a in self.al) |
|
|
|
def abstractApplyLeaf(self, *args, **kargs): |
|
return self.new(a.abstractApplyLeaf(*args, **kargs) for a in self.al) |
|
|
|
def bmm(self, other): |
|
return self.new(a.bmm(other) for a in self.al) |
|
|
|
def matmul(self, other): |
|
return self.new(a.matmul(other) for a in self.al) |
|
|
|
def conv(self, *args, **kargs): |
|
return self.new(a.conv(*args, **kargs) for a in self.al) |
|
|
|
def conv1d(self, *args, **kargs): |
|
return self.new(a.conv1d(*args, **kargs) for a in self.al) |
|
|
|
def conv2d(self, *args, **kargs): |
|
return self.new(a.conv2d(*args, **kargs) for a in self.al) |
|
|
|
def conv3d(self, *args, **kargs): |
|
return self.new(a.conv3d(*args, **kargs) for a in self.al) |
|
|
|
def max_pool2d(self, *args, **kargs): |
|
return self.new(a.max_pool2d(*args, **kargs) for a in self.al) |
|
|
|
def avg_pool2d(self, *args, **kargs): |
|
return self.new(a.avg_pool2d(*args, **kargs) for a in self.al) |
|
|
|
def adaptive_avg_pool2d(self, *args, **kargs): |
|
return self.new(a.adaptive_avg_pool2d(*args, **kargs) for a in self.al) |
|
|
|
def unsqueeze(self, *args, **kargs): |
|
return self.new(a.unsqueeze(*args, **kargs) for a in self.al) |
|
|
|
def squeeze(self, *args, **kargs): |
|
return self.new(a.squeeze(*args, **kargs) for a in self.al) |
|
|
|
def view(self, *args, **kargs): |
|
return self.new(a.view(*args, **kargs) for a in self.al) |
|
|
|
def gather(self, *args, **kargs): |
|
return self.new(a.gather(*args, **kargs) for a in self.al) |
|
|
|
def sum(self, *args, **kargs): |
|
return self.new(a.sum(*args,**kargs) for a in self.al) |
|
|
|
def double(self): |
|
return self.new(a.double() for a in self.al) |
|
|
|
def float(self): |
|
return self.new(a.float() for a in self.al) |
|
|
|
def to_dtype(self): |
|
return self.new(a.to_dtype() for a in self.al) |
|
|
|
def vanillaTensorPart(self): |
|
return self.al[0].vanillaTensorPart() |
|
|
|
def center(self): |
|
return self.new(a.center() for a in self.al) |
|
|
|
def ub(self): |
|
return self.new(a.ub() for a in self.al) |
|
|
|
def lb(self): |
|
return self.new(a.lb() for a in self.al) |
|
|
|
def relu(self): |
|
return self.new(a.relu() for a in self.al) |
|
|
|
def splitRelu(self, *args, **kargs): |
|
return self.new(a.splitRelu(*args, **kargs) for a in self.al) |
|
|
|
def softplus(self): |
|
return self.new(a.softplus() for a in self.al) |
|
|
|
def elu(self): |
|
return self.new(a.elu() for a in self.al) |
|
|
|
def selu(self): |
|
return self.new(a.selu() for a in self.al) |
|
|
|
def sigm(self): |
|
return self.new(a.sigm() for a in self.al) |
|
|
|
def cat(self, other, *args, **kargs): |
|
return self.new(a.cat(o, *args, **kargs) for a,o in zip(self.al, other.al)) |
|
|
|
|
|
    def split(self, *args, **kargs):
        return [self.new(z) for z in zip(*(a.split(*args, **kargs) for a in self.al))]
|
|
|
def size(self): |
|
return self.al[0].size() |
|
|
|
def loss(self, *args, **kargs): |
|
return sum(a.loss(*args, **kargs) for a in self.al) |
|
|
|
def deep_loss(self, *args, **kargs): |
|
return sum(a.deep_loss(*args, **kargs) for a in self.al) |
|
|
|
def checkSizes(self): |
|
for a in self.al: |
|
a.checkSizes() |
|
return self |
|
|
|
|
|
class TaggedDomain(object): |
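    """Wrap a single abstract element together with a tag.

    All operations are forwarded to the wrapped element `a`; the tag travels
    along unchanged and is used by `loss`, which delegates to `self.tag.loss`.
    """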
|
|
|
|
|
def __init__(self, a, tag = None): |
|
self.tag = tag |
|
self.a = a |
|
|
|
def isSafe(self,*args,**kargs): |
|
return self.a.isSafe(*args, **kargs) |
|
|
|
def isPoint(self): |
|
return self.a.isPoint() |
|
|
|
    def labels(self):
        raise Exception("Domain Not Suitable For Testing")
|
|
|
def __mul__(self, flt): |
|
return TaggedDomain(self.a.__mul__(flt), self.tag) |
|
|
|
def __truediv__(self, flt): |
|
return TaggedDomain(self.a.__truediv__(flt), self.tag) |
|
|
|
def __add__(self, other): |
|
if isinstance(other, TaggedDomain): |
|
return TaggedDomain(self.a.__add__(other.a), self.tag) |
|
else: |
|
return TaggedDomain(self.a.__add__(other), self.tag) |
|
|
|
def addPar(self, a,b): |
|
return TaggedDomain(self.a.addPar(a.a, b.a), self.tag) |
|
|
|
def __sub__(self, other): |
|
if isinstance(other, TaggedDomain): |
|
return TaggedDomain(self.a.__sub__(other.a), self.tag) |
|
else: |
|
return TaggedDomain(self.a.__sub__(other), self.tag) |
|
|
|
def bmm(self, other): |
|
return TaggedDomain(self.a.bmm(other), self.tag) |
|
|
|
def matmul(self, other): |
|
return TaggedDomain(self.a.matmul(other), self.tag) |
|
|
|
def conv(self, *args, **kargs): |
|
return TaggedDomain(self.a.conv(*args, **kargs) , self.tag) |
|
|
|
def conv1d(self, *args, **kargs): |
|
return TaggedDomain(self.a.conv1d(*args, **kargs), self.tag) |
|
|
|
def conv2d(self, *args, **kargs): |
|
return TaggedDomain(self.a.conv2d(*args, **kargs), self.tag) |
|
|
|
def conv3d(self, *args, **kargs): |
|
return TaggedDomain(self.a.conv3d(*args, **kargs), self.tag) |
|
|
|
def max_pool2d(self, *args, **kargs): |
|
return TaggedDomain(self.a.max_pool2d(*args, **kargs), self.tag) |
|
|
|
def avg_pool2d(self, *args, **kargs): |
|
return TaggedDomain(self.a.avg_pool2d(*args, **kargs), self.tag) |
|
|
|
def adaptive_avg_pool2d(self, *args, **kargs): |
|
return TaggedDomain(self.a.adaptive_avg_pool2d(*args, **kargs), self.tag) |
|
|
|
|
|
def unsqueeze(self, *args, **kargs): |
|
return TaggedDomain(self.a.unsqueeze(*args, **kargs), self.tag) |
|
|
|
def squeeze(self, *args, **kargs): |
|
return TaggedDomain(self.a.squeeze(*args, **kargs), self.tag) |
|
|
|
def abstractApplyLeaf(self, *args, **kargs): |
|
return TaggedDomain(self.a.abstractApplyLeaf(*args, **kargs), self.tag) |
|
|
|
def view(self, *args, **kargs): |
|
return TaggedDomain(self.a.view(*args, **kargs), self.tag) |
|
|
|
def gather(self, *args, **kargs): |
|
return TaggedDomain(self.a.gather(*args, **kargs), self.tag) |
|
|
|
def sum(self, *args, **kargs): |
|
return TaggedDomain(self.a.sum(*args,**kargs), self.tag) |
|
|
|
def double(self): |
|
return TaggedDomain(self.a.double(), self.tag) |
|
|
|
def float(self): |
|
return TaggedDomain(self.a.float(), self.tag) |
|
|
|
def to_dtype(self): |
|
return TaggedDomain(self.a.to_dtype(), self.tag) |
|
|
|
def vanillaTensorPart(self): |
|
return self.a.vanillaTensorPart() |
|
|
|
def center(self): |
|
return TaggedDomain(self.a.center(), self.tag) |
|
|
|
def ub(self): |
|
return TaggedDomain(self.a.ub(), self.tag) |
|
|
|
def lb(self): |
|
return TaggedDomain(self.a.lb(), self.tag) |
|
|
|
def relu(self): |
|
return TaggedDomain(self.a.relu(), self.tag) |
|
|
|
def splitRelu(self, *args, **kargs): |
|
return TaggedDomain(self.a.splitRelu(*args, **kargs), self.tag) |
|
|
|
def diameter(self): |
|
return self.a.diameter() |
|
|
|
def softplus(self): |
|
return TaggedDomain(self.a.softplus(), self.tag) |
|
|
|
def elu(self): |
|
return TaggedDomain(self.a.elu(), self.tag) |
|
|
|
def selu(self): |
|
return TaggedDomain(self.a.selu(), self.tag) |
|
|
|
def sigm(self): |
|
return TaggedDomain(self.a.sigm(), self.tag) |
|
|
|
|
|
def cat(self, other, *args, **kargs): |
|
return TaggedDomain(self.a.cat(other.a, *args, **kargs), self.tag) |
|
|
|
def split(self, *args, **kargs): |
|
return [TaggedDomain(z, self.tag) for z in self.a.split(*args, **kargs)] |
|
|
|
def size(self): |
|
|
|
return self.a.size() |
|
|
|
def loss(self, *args, **kargs): |
|
return self.tag.loss(self.a, *args, **kargs) |
|
|
|
def deep_loss(self, *args, **kargs): |
|
return self.a.deep_loss(*args, **kargs) |
|
|
|
def checkSizes(self): |
|
self.a.checkSizes() |
|
return self |
|
|
|
def merge(self, other, ref = None): |
|
return TaggedDomain(self.a.merge(other.a, ref = None if ref is None else ref.a), self.tag) |
|
|