import future
import builtins
import past
import six

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd

from functools import reduce

try:
    from . import helpers as h
except ImportError:
    import helpers as h


def catNonNullErrors(op, ref_errs=None):  # the way of things is ugly
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        if sS == sL:
            # TODO: here we know we used transformers on either side which didn't introduce new
            # error terms (this is a hack for hybrid zonotopes and doesn't work with adaptive
            # error term adding).
            return op(erS, erL)

        if ref_errs is not None:
            sz = ref_errs.size()[0]
        else:
            sz = min(sS, sL)

        p1 = op(erS[:sz], erL[:sz])
        erSrem = erS[sz:]
        erLrem = erL[sz:]  # the remainder must come from erL (was erS[sz:])
        p2 = op(erSrem, h.zeros(erSrem.shape))
        p3 = op(h.zeros(erLrem.shape), erLrem)
        return torch.cat((p1, p2, p3), dim=0)
    return doop


def creluBoxy(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mx = dom.head + sm
    mn = dom.head - sm

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0).to_dtype()
    mx /= 2
    newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
    newbeta = h.ifThenElse(should_box, mx, gtz * (dom.beta if not dom.beta is None else 0))
    newerr = (1 - should_box.to_dtype()) * gtz * dom.errors

    return dom.new(newhead, newbeta, newerr)


def creluBoxySound(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2 + 2e-6, None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mx = dom.head + sm
    mn = dom.head - sm

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0).to_dtype()
    mx /= 2
    newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
    newbeta = h.ifThenElse(should_box, mx + 2e-6, gtz * (dom.beta if not dom.beta is None else 0))
    newerr = (1 - should_box.to_dtype()) * gtz * dom.errors

    return dom.new(newhead, newbeta, newerr)


def creluSwitch(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = sm
    mx += dom.head

    should_box = mn.lt(0) * mx.gt(0)
    gtz = dom.head.gt(0)

    mn.neg_()
    should_boxer = mn.gt(mx)

    mn /= 2
    newhead = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, dom.head + mn), gtz.to_dtype() * dom.head)
    zbet = dom.beta if not dom.beta is None else 0
    newbeta = h.ifThenElse(should_box, h.ifThenElse(should_boxer, mx / 2, mn + zbet), gtz.to_dtype() * zbet)
    newerr = h.ifThenElseL(should_box, 1 - should_boxer, gtz).to_dtype() * dom.errors

    return dom.new(newhead, newbeta, newerr)
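
# Commentary on the ReLU transformers in this file: the creluX variants only differ in how
# they treat a "crossing" neuron, i.e. one whose concrete bounds satisfy lb < 0 < ub.
# creluBoxy collapses such a neuron to the box [0, ub] (head = beta = ub / 2, error terms
# zeroed); creluBoxySound adds a small slack for floating-point soundness; creluSwitch picks,
# per neuron, between that box and shifting the lower bound up to 0, whichever is tighter;
# creluSmooth (below) blends the two choices smoothly; creluNIPS uses a slope/offset linear
# relaxation (lam * x + mu).  Rough worked example with illustrative numbers only: head = 1,
# beta = None, a single error term of magnitude 2 gives lb = -1 < 0 < ub = 3, so creluBoxy
# returns head = beta = 1.5 with its error terms zeroed.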
def creluSmooth(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    aber = torch.abs(dom.errors)

    sm = torch.sum(aber, 0)

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = sm
    mx += dom.head

    nmn = F.relu(-1 * mn)

    zbet = (dom.beta if not dom.beta is None else 0)
    newheadS = dom.head + nmn / 2
    newbetaS = zbet + nmn / 2
    newerrS = dom.errors

    mmx = F.relu(mx)

    newheadB = mmx / 2
    newbetaB = newheadB
    newerrB = 0

    eps = 0.0001
    t = nmn / (mmx + nmn + eps)  # mn.lt(0).to_dtype() * F.sigmoid(nmn - nmx)

    shouldnt_zero = mx.gt(0).to_dtype()

    newhead = shouldnt_zero * ((1 - t) * newheadS + t * newheadB)
    newbeta = shouldnt_zero * ((1 - t) * newbetaS + t * newbetaB)
    newerr = shouldnt_zero * ((1 - t) * newerrS + t * newerrB)

    return dom.new(newhead, newbeta, newerr)


def creluNIPS(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    sm = torch.sum(torch.abs(dom.errors), 0)

    if not dom.beta is None:
        sm += dom.beta

    mn = dom.head - sm
    mx = dom.head + sm

    mngz = mn >= 0.0

    zs = h.zeros(dom.head.shape)

    diff = mx - mn
    lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs)
    mu = lam * mn * (-0.5)

    betaz = zs if dom.beta is None else dom.beta

    newhead = torch.where(mngz, dom.head, lam * dom.head + mu)
    mngz = mngz | (diff <= 0.0)  # logical-or on the bool mask (was "mngz += diff <= 0.0")
    newbeta = torch.where(mngz, betaz, lam * betaz + mu)  # mu is always positive on this side
    newerr = torch.where(mngz, dom.errors, lam * dom.errors)
    return dom.new(newhead, newbeta, newerr)


class MaxTypes:

    @staticmethod
    def ub(x):
        return x.ub()

    @staticmethod
    def only_beta(x):
        return x.beta if x.beta is not None else x.head * 0

    @staticmethod
    def head_beta(x):
        return MaxTypes.only_beta(x) + x.head


class HybridZonotope:

    def isSafe(self, target):
        od, _ = torch.min(h.preDomRes(self, target).lb(), 1)
        return od.gt(0.0).long()

    def isPoint(self):
        return False

    def labels(self):
        target = torch.max(self.ub(), 1)[1]
        l = list(h.preDomRes(self, target).lb()[0])
        return [target.item()] + [i for i, v in zip(range(len(l)), l) if v <= 0]

    def relu(self):
        return self.customRelu(self)

    def __init__(self, head, beta, errors, customRelu=creluBoxy, **kargs):
        self.head = head
        self.errors = errors
        self.beta = beta
        self.customRelu = creluBoxy if customRelu is None else customRelu

    def new(self, *args, customRelu=None, **kargs):
        return self.__class__(*args, **kargs, customRelu=self.customRelu if customRelu is None else customRelu).checkSizes()
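
    # Representation note: an element abstracts the set of points
    #     x = head + beta * b + sum_i errors[i] * e_i,    with b, e_i in [-1, 1],
    # where `beta` carries per-coordinate (uncorrelated, box-like) noise and `errors` carries
    # error terms shared across coordinates, so correlations between coordinates can be
    # tracked.  Concrete bounds are head +/- (beta + sum_i |errors[i]|); see lb()/ub() below.
    # Minimal construction sketch (assumes the helpers module has been imported so tensor
    # patches such as .to_dtype() are in place):
    #     z = HybridZonotope(torch.zeros(1, 3), torch.full((1, 3), 0.1), None)
    #     z.ub()  # -> roughly tensor([[0.1, 0.1, 0.1]])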
    def zono_to_hybrid(self, *args, **kargs):  # we are already a hybrid zono.
        return self.new(self.head, self.beta, self.errors, **kargs)

    def hybrid_to_zono(self, *args, correlate=True, customRelu=None, **kargs):
        beta = self.beta
        errors = self.errors
        if correlate and beta is not None:
            batches = beta.shape[0]
            num_elem = h.product(beta.shape[1:])
            ei = h.getEi(batches, num_elem)

            if len(beta.shape) > 2:
                ei = ei.contiguous().view(num_elem, *beta.shape)

            err = ei * beta
            errors = torch.cat((err, errors), dim=0) if errors is not None else err
            beta = None

        return Zonotope(self.head, beta,
                        errors if errors is not None else (self.beta * 0).unsqueeze(0),
                        customRelu=self.customRelu if customRelu is None else customRelu)  # pass the override through (was "else None")

    def abstractApplyLeaf(self, foo, *args, **kargs):
        return getattr(self, foo)(*args, **kargs)

    def decorrelate(self, cc_indx_batch_err):  # keep these errors
        if self.errors is None:
            return self

        batch_size = self.head.shape[0]
        num_error_terms = self.errors.shape[0]

        beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
        errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

        inds_i = torch.arange(self.head.shape[0], device=h.device).unsqueeze(1).long()

        errors = errors.to_dtype().permute(1, 0, *list(range(len(self.errors.shape)))[2:])

        sm = errors.clone()
        sm[inds_i, cc_indx_batch_err] = 0
        beta = beta.to_dtype() + sm.abs().sum(dim=1)

        errors = errors[inds_i, cc_indx_batch_err]
        errors = errors.permute(1, 0, *list(range(len(self.errors.shape)))[2:]).contiguous()

        return self.new(self.head, beta, errors)

    def dummyDecorrelate(self, num_decorrelate):
        if num_decorrelate == 0 or self.errors is None:
            return self
        elif num_decorrelate >= self.errors.shape[0]:
            beta = self.beta
            if self.errors is not None:
                errs = self.errors.abs().sum(dim=0)
                if beta is None:
                    beta = errs
                else:
                    beta += errs
            return self.new(self.head, beta, None)
        return None

    def stochasticDecorrelate(self, num_decorrelate, choices=None, num_to_keep=False):
        dummy = self.dummyDecorrelate(num_decorrelate)
        if dummy is not None:
            return dummy

        num_error_terms = self.errors.shape[0]
        batch_size = self.head.shape[0]

        ucc_mask = h.ones([batch_size, self.errors.shape[0]]).long()

        cc_indx_batch_err = h.cudify(torch.multinomial(
            ucc_mask.to_dtype(),
            num_decorrelate if num_to_keep else num_error_terms - num_decorrelate,
            replacement=False)) if choices is None else choices
        return self.decorrelate(cc_indx_batch_err)

    def decorrelateMin(self, num_decorrelate, num_to_keep=False):
        dummy = self.dummyDecorrelate(num_decorrelate)
        if dummy is not None:
            return dummy

        num_error_terms = self.errors.shape[0]
        batch_size = self.head.shape[0]

        error_sum_b_e = self.errors.abs().view(self.errors.shape[0], batch_size, -1).sum(dim=2).permute(1, 0)
        cc_indx_batch_err = error_sum_b_e.topk(num_decorrelate if num_to_keep else num_error_terms - num_decorrelate)[1]

        return self.decorrelate(cc_indx_batch_err)
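
    # Commentary: decorrelate (above) folds all but the chosen error terms back into beta,
    # which caps the number of error terms (and hence memory and time); correlate (below)
    # does the reverse, promoting selected entries of beta into fresh shared error terms so
    # that later affine layers can track their correlations exactly.  The stochastic* / Min /
    # MaxK / MaxPool variants differ only in how the kept or promoted indices are chosen.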
    def correlate(self, cc_indx_batch_beta):  # given in terms of the flattened matrix.
        num_correlate = h.product(cc_indx_batch_beta.shape[1:])

        beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
        errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

        batch_size = beta.shape[0]
        new_errors = h.zeros([num_correlate] + list(self.head.shape)).to_dtype()

        inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()
        nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

        new_errors = new_errors.permute(1, 0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(batch_size, num_correlate, -1)
        new_errors[inds_i, nc.unsqueeze(0).expand([batch_size] + list(nc.shape)).squeeze(2), cc_indx_batch_beta] = beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta]
        new_errors = new_errors.permute(1, 0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(num_correlate, batch_size, *beta.shape[1:])

        errors = torch.cat((errors, new_errors), dim=0)
        beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0

        return self.new(self.head, beta, errors)

    def stochasticCorrelate(self, num_correlate, choices=None):
        if num_correlate == 0:
            return self

        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])
        num_correlate = min(num_correlate, num_pixs)

        ucc_mask = h.ones([batch_size, num_pixs]).long()

        cc_indx_batch_beta = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_correlate, replacement=False)) if choices is None else choices
        return self.correlate(cc_indx_batch_beta)

    def correlateMaxK(self, num_correlate):
        if num_correlate == 0:
            return self

        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])
        num_correlate = min(num_correlate, num_pixs)

        concrete_max_image = self.ub().view(batch_size, -1)
        cc_indx_batch_beta = concrete_max_image.topk(num_correlate)[1]
        return self.correlate(cc_indx_batch_beta)

    def correlateMaxPool(self, *args, max_type=MaxTypes.ub, max_pool=F.max_pool2d, **kargs):
        domshape = self.head.shape
        batch_size = domshape[0]
        num_pixs = h.product(domshape[1:])

        concrete_max_image = max_type(self)

        cc_indx_batch_beta = max_pool(concrete_max_image, *args, return_indices=True, **kargs)[1].view(batch_size, -1)

        return self.correlate(cc_indx_batch_beta)

    def checkSizes(self):
        if not self.errors is None:
            if not self.errors.size()[1:] == self.head.size():
                raise Exception("Such bad sizes on error:", self.errors.shape, " head:", self.head.shape)
            if torch.isnan(self.errors).any():
                raise Exception("Such nan in errors")
        if not self.beta is None:
            if not self.beta.size() == self.head.size():
                raise Exception("Such bad sizes on beta")
            if torch.isnan(self.beta).any():
                raise Exception("Such nan in beta")
            if self.beta.lt(0.0).any():
                self.beta = self.beta.abs()
        return self
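
    # Commentary: the arithmetic transformers below are exact for affine operations.  The
    # center follows the operation directly, error terms follow it linearly, and beta is
    # scaled by the absolute value, since interval noise is symmetric and only magnitudes
    # matter.  For example, (z * -2.0) has head = -2 * z.head, beta = 2 * z.beta and
    # errors = -2 * z.errors.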
    def __mul__(self, flt):
        return self.new(self.head * flt,
                        None if self.beta is None else self.beta * abs(flt),
                        None if self.errors is None else self.errors * flt)

    def __truediv__(self, flt):
        flt = 1. / flt
        return self.new(self.head * flt,
                        None if self.beta is None else self.beta * abs(flt),
                        None if self.errors is None else self.errors * flt)

    def __add__(self, other):
        if isinstance(other, HybridZonotope):
            return self.new(self.head + other.head,
                            h.msum(self.beta, other.beta, lambda a, b: a + b),
                            h.msum(self.errors, other.errors, catNonNullErrors(lambda a, b: a + b)))
        else:
            # other has to be a standard variable or tensor
            return self.new(self.head + other, self.beta, self.errors)

    def addPar(self, a, b):
        return self.new(a.head + b.head,
                        h.msum(a.beta, b.beta, lambda a, b: a + b),
                        h.msum(a.errors, b.errors, catNonNullErrors(lambda a, b: a + b, self.errors)))

    def __sub__(self, other):
        if isinstance(other, HybridZonotope):
            return self.new(self.head - other.head,
                            h.msum(self.beta, other.beta, lambda a, b: a + b),
                            h.msum(self.errors, None if other.errors is None else -other.errors,
                                   catNonNullErrors(lambda a, b: a + b)))
        else:
            # other has to be a standard variable or tensor
            return self.new(self.head - other, self.beta, self.errors)

    def bmm(self, other):
        hd = self.head.bmm(other)
        bet = None if self.beta is None else self.beta.bmm(other.abs())

        if self.errors is None:
            er = None
        else:
            er = self.errors.matmul(other)

        return self.new(hd, bet, er)

    def getBeta(self):
        return self.head * 0 if self.beta is None else self.beta

    def getErrors(self):
        return (self.head * 0).unsqueeze(0) if self.errors is None else self.errors  # mirror getBeta (was "if self.beta is None")
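
    # Commentary on merge (below): an over-approximate join of two abstract elements.  Where
    # the other element already lies inside this element's box part, this element is kept per
    # coordinate (and symmetrically where this element lies inside the other); everywhere else
    # both are widened to a box around the midpoint of the two centers.  Each side's error
    # terms are kept only on the coordinates where its containment check holds and are zeroed
    # elsewhere, so some correlation information survives the join.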
    def merge(self, other, ref=None):
        # the vast majority of the time ref should be none here.  Not for parallel computation with powerset.
        s_beta = self.getBeta()  # so that beta is never none

        sbox_u = self.head + s_beta
        sbox_l = self.head - s_beta

        o_u = other.ub()
        o_l = other.lb()
        o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l)

        s_err_mx = self.errors.abs().sum(dim=0)
        if not isinstance(other, HybridZonotope):
            new_head = (self.head + other.center()) / 2
            new_beta = torch.max(sbox_u + s_err_mx, o_u) - new_head

            return self.new(torch.where(o_in_s, self.head, new_head),
                            torch.where(o_in_s, self.beta, new_beta),
                            o_in_s.float() * self.errors)

        # TODO: could be more efficient if one of these doesn't have beta or errors but that's okay for now.
        s_u = sbox_u + s_err_mx
        s_l = sbox_l - s_err_mx

        obox_u = o_u - other.head
        obox_l = o_l + other.head
        s_in_o = (s_u <= obox_u) & (s_l >= obox_l)

        # TODO: could theoretically still do something better when one is contained partially in the other
        new_head = (self.head + other.center()) / 2
        new_beta = torch.max(sbox_u + self.getErrors().abs().sum(dim=0), o_u) - new_head

        return self.new(torch.where(o_in_s, self.head, torch.where(s_in_o, other.head, new_head)),
                        torch.where(o_in_s, s_beta, torch.where(s_in_o, other.getBeta(), new_beta)),
                        h.msum(o_in_s.float() * self.errors, s_in_o.float() * other.errors,
                               catNonNullErrors(lambda a, b: a + b, ref_errs=ref.errors if ref is not None else ref)))  # these are both zero otherwise

    def conv(self, conv, weight, bias=None, **kargs):
        errs = self.errors  # local name so we don't shadow the helpers module h
        inter = errs if errs is None else errs.view(-1, *errs.size()[2:])

        hd = conv(self.head, weight, bias=bias, **kargs)
        res = errs if errs is None else conv(inter, weight, bias=None, **kargs)

        return self.new(hd,
                        None if self.beta is None else conv(self.beta, weight.abs(), bias=None, **kargs),
                        errs if errs is None else res.view(errs.size()[0], errs.size()[1], *res.size()[1:]))

    def conv1d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv1d(*args, **kargs), *args, **kargs)

    def conv2d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv2d(*args, **kargs), *args, **kargs)

    def conv3d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv3d(*args, **kargs), *args, **kargs)

    def conv_transpose1d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv_transpose1d(*args, **kargs), *args, **kargs)

    def conv_transpose2d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv_transpose2d(*args, **kargs), *args, **kargs)

    def conv_transpose3d(self, *args, **kargs):
        return self.conv(lambda x, *args, **kargs: x.conv_transpose3d(*args, **kargs), *args, **kargs)

    def matmul(self, other):
        return self.new(self.head.matmul(other),
                        None if self.beta is None else self.beta.matmul(other.abs()),
                        None if self.errors is None else self.errors.matmul(other))

    def unsqueeze(self, i):
        return self.new(self.head.unsqueeze(i),
                        None if self.beta is None else self.beta.unsqueeze(i),
                        None if self.errors is None else self.errors.unsqueeze(i + 1))

    def squeeze(self, dim):
        return self.new(self.head.squeeze(dim),
                        None if self.beta is None else self.beta.squeeze(dim),
                        None if self.errors is None else self.errors.squeeze(dim + 1 if dim >= 0 else dim))

    def double(self):
        return self.new(self.head.double(),
                        self.beta.double() if self.beta is not None else None,
                        self.errors.double() if self.errors is not None else None)

    def float(self):
        return self.new(self.head.float(),
                        self.beta.float() if self.beta is not None else None,
                        self.errors.float() if self.errors is not None else None)

    def to_dtype(self):
        return self.new(self.head.to_dtype(),
                        self.beta.to_dtype() if self.beta is not None else None,
                        self.errors.to_dtype() if self.errors is not None else None)

    def sum(self, dim=1):
        return self.new(torch.sum(self.head, dim=dim),
                        None if self.beta is None else torch.sum(self.beta, dim=dim),
                        None if self.errors is None else torch.sum(self.errors, dim=dim + 1 if dim >= 0 else dim))

    def view(self, *newshape):
        return self.new(self.head.view(*newshape),
                        None if self.beta is None else self.beta.view(*newshape),
                        None if self.errors is None else self.errors.view(self.errors.size()[0], *newshape))

    def gather(self, dim, index):
        return self.new(self.head.gather(dim, index),
                        None if self.beta is None else self.beta.gather(dim, index),
                        None if self.errors is None else self.errors.gather(dim + 1, index.expand([self.errors.size()[0]] + list(index.size()))))
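
    # Commentary: the linear and shape transformers above share one pattern: apply the
    # operation to `head`, apply its absolute-valued weights to `beta`, and apply it without
    # bias to every error term (dim 0 of `errors` indexes the terms, hence the dim + 1 and
    # index-expansion adjustments).  Rough sketch of the matmul rule for a weight matrix W:
    #     out.head   = z.head   @ W
    #     out.beta   = z.beta   @ |W|
    #     out.errors = z.errors @ W    # each error term is transformed independently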
    def concretize(self):
        if self.errors is None:
            return self

        return self.new(self.head, torch.sum(self.concreteErrors().abs(), 0), None)  # maybe make a box?

    def cat(self, other, dim=0):
        return self.new(self.head.cat(other.head, dim=dim),
                        h.msum(other.beta, self.beta, lambda a, b: a.cat(b, dim=dim)),
                        h.msum(self.errors, other.errors, catNonNullErrors(lambda a, b: a.cat(b, dim + 1))))

    def split(self, split_size, dim=0):
        heads = list(self.head.split(split_size, dim))
        betas = list(self.beta.split(split_size, dim)) if not self.beta is None else None
        errorss = list(self.errors.split(split_size, dim + 1)) if not self.errors is None else None

        def makeFromI(i):
            return self.new(heads[i],
                            None if betas is None else betas[i],
                            None if errorss is None else errorss[i])

        return tuple(makeFromI(i) for i in range(len(heads)))

    def concreteErrors(self):
        if self.beta is None and self.errors is None:
            raise Exception("shouldn't have both beta and errors be none")
        if self.errors is None:
            return self.beta.unsqueeze(0)
        if self.beta is None:
            return self.errors
        return torch.cat([self.beta.unsqueeze(0), self.errors], dim=0)

    def applyMonotone(self, foo, *args, **kargs):
        if self.beta is None and self.errors is None:
            return self.new(foo(self.head), None, None)

        beta = self.concreteErrors().abs().sum(dim=0)

        tp = foo(self.head + beta, *args, **kargs)
        bt = foo(self.head - beta, *args, **kargs)

        new_hybrid = self.new((tp + bt) / 2, (tp - bt) / 2, None)

        if self.errors is not None:
            return new_hybrid.correlateMaxK(self.errors.shape[0])
        return new_hybrid

    def avg_pool2d(self, *args, **kargs):
        nhead = F.avg_pool2d(self.head, *args, **kargs)
        return self.new(nhead,
                        None if self.beta is None else F.avg_pool2d(self.beta, *args, **kargs),
                        None if self.errors is None else F.avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1, *nhead.shape))

    def adaptive_avg_pool2d(self, *args, **kargs):
        nhead = F.adaptive_avg_pool2d(self.head, *args, **kargs)
        return self.new(nhead,
                        None if self.beta is None else F.adaptive_avg_pool2d(self.beta, *args, **kargs),
                        None if self.errors is None else F.adaptive_avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1, *nhead.shape))

    def elu(self):
        return self.applyMonotone(F.elu)

    def selu(self):
        return self.applyMonotone(F.selu)

    def sigm(self):
        return self.applyMonotone(F.sigmoid)

    def softplus(self):
        if self.errors is None:
            if self.beta is None:
                return self.new(F.softplus(self.head), None, None)
            tp = F.softplus(self.head + self.beta)
            bt = F.softplus(self.head - self.beta)
            return self.new((tp + bt) / 2, (tp - bt) / 2, None)

        errors = self.concreteErrors()
        o = h.ones(self.head.size())

        def sp(hd):
            return F.softplus(hd)  # torch.log(o + torch.exp(hd))  # not very stable

        def spp(hd):
            ehd = torch.exp(hd)
            return ehd.div(ehd + o)

        def sppp(hd):
            ehd = torch.exp(hd)
            md = ehd + o
            return ehd.div(md.mul(md))

        fa = sp(self.head)
        fpa = spp(self.head)

        a = self.head

        k = torch.sum(errors.abs(), 0)

        def evalG(r):
            return r.mul(r).mul(sppp(a + r))

        m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
        m = h.ifThenElse(a.abs().lt(k), torch.max(m, torch.max(evalG(a), evalG(-a))), m)
        m /= 2

        return self.new(fa,
                        m if self.beta is None else m + self.beta.mul(fpa),
                        None if self.errors is None else self.errors.mul(fpa))

    def center(self):
        return self.head

    def vanillaTensorPart(self):
        return self.head

    def lb(self):
        return self.head - self.concreteErrors().abs().sum(dim=0)
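
    # Commentary: loss (below) is the certified-robustness training objective.  h.preDomRes
    # is assumed to build the per-class margin differences against the target label, so its
    # lb() lower-bounds (target logit - other logit); softplus of the worst margin then
    # penalises every class that cannot be proven dominated.  deep_loss is an auxiliary
    # regulariser computed from the sorted concrete bounds.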
    def ub(self):
        return self.head + self.concreteErrors().abs().sum(dim=0)

    def size(self):
        return self.head.size()

    def diameter(self):
        abal = torch.abs(self.concreteErrors()).transpose(0, 1)
        return abal.sum(1).sum(1)  # perimeter

    def loss(self, target, **args):
        r = -h.preDomRes(self, target).lb()
        return F.softplus(r.max(1)[0])

    def deep_loss(self, act=F.relu, *args, **kargs):
        batch_size = self.head.shape[0]
        inds = torch.arange(batch_size, device=h.device).unsqueeze(1).long()

        def dl(l, u):
            ls, lsi = torch.sort(l, dim=1)
            ls_u = u[inds, lsi]

            def slidingMax(a):  # using maxpool
                k = a.shape[1]
                ml = a.min(dim=1)[0].unsqueeze(1)

                inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)

                mpl = F.max_pool1d(inp.unsqueeze(1), kernel_size=k, stride=1, padding=0, return_indices=False).squeeze(1)
                return mpl[:, :-1] + ml

            return act(slidingMax(ls_u) - ls).sum(dim=1)

        l = self.lb().view(batch_size, -1)
        u = self.ub().view(batch_size, -1)
        return (dl(l, u) + dl(-u, -l)) / (2 * l.shape[1])  # make it easier to regularize against


class Zonotope(HybridZonotope):

    def applySuper(self, ret):
        batches = ret.head.size()[0]
        num_elem = h.product(ret.head.size()[1:])
        ei = h.getEi(batches, num_elem)

        if len(ret.head.size()) > 2:
            ei = ei.contiguous().view(num_elem, *ret.head.size())

        ret.errors = torch.cat((ret.errors, ei * ret.beta)) if not ret.beta is None else ret.errors
        ret.beta = None
        return ret.checkSizes()

    def zono_to_hybrid(self, *args, customRelu=None, **kargs):  # promote to a hybrid zono with the same state.
        return HybridZonotope(self.head, self.beta, self.errors,
                              customRelu=self.customRelu if customRelu is None else customRelu)

    def hybrid_to_zono(self, *args, **kargs):  # we are already a plain zono.
        return self.new(self.head, self.beta, self.errors, **kargs)

    def applyMonotone(self, *args, **kargs):
        return self.applySuper(super(Zonotope, self).applyMonotone(*args, **kargs))

    def softplus(self):
        return self.applySuper(super(Zonotope, self).softplus())

    def relu(self):
        return self.applySuper(super(Zonotope, self).relu())

    def splitRelu(self, *args, **kargs):
        return [self.applySuper(a) for a in super(Zonotope, self).splitRelu(*args, **kargs)]


def mysign(x):
    e = x.eq(0).to_dtype()
    r = x.sign().to_dtype()
    return r + e


def mulIfEq(grad, out, target):
    pred = out.max(1, keepdim=True)[1]
    is_eq = pred.eq(target.view_as(pred)).to_dtype()
    is_eq = is_eq.view([-1] + [1 for _ in grad.size()[1:]]).expand_as(grad)
    return is_eq


def stdLoss(out, target):
    if torch.__version__[0] == "0":
        return F.cross_entropy(out, target, reduce=False)
    else:
        return F.cross_entropy(out, target, reduction='none')
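
# Commentary: ListDomain wraps a list of abstract elements and applies every transformer
# pointwise to each member, so several domains (or several specifications) can be pushed
# through the same network in one pass; losses are summed across the members.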
class ListDomain(object):

    def __init__(self, al, *args, **kargs):
        self.al = list(al)

    def new(self, *args, **kargs):
        return self.__class__(*args, **kargs)

    def isSafe(self, *args, **kargs):
        raise Exception("Domain Not Suitable For Testing")  # raising a bare string is invalid in Python 3

    def labels(self):
        raise Exception("Domain Not Suitable For Testing")

    def isPoint(self):
        return all(a.isPoint() for a in self.al)

    def __mul__(self, flt):
        return self.new(a.__mul__(flt) for a in self.al)

    def __truediv__(self, flt):
        return self.new(a.__truediv__(flt) for a in self.al)

    def __add__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__add__(o) for a, o in zip(self.al, other.al))
        else:
            return self.new(a.__add__(other) for a in self.al)

    def merge(self, other, ref=None):
        if ref is None:
            return self.new(a.merge(o) for a, o in zip(self.al, other.al))
        return self.new(a.merge(o, ref=r) for a, o, r in zip(self.al, other.al, ref.al))

    def addPar(self, a, b):
        return self.new(s.addPar(av, bv) for s, av, bv in zip(self.al, a.al, b.al))

    def __sub__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__sub__(o) for a, o in zip(self.al, other.al))
        else:
            return self.new(a.__sub__(other) for a in self.al)

    def abstractApplyLeaf(self, *args, **kargs):
        return self.new(a.abstractApplyLeaf(*args, **kargs) for a in self.al)

    def bmm(self, other):
        return self.new(a.bmm(other) for a in self.al)

    def matmul(self, other):
        return self.new(a.matmul(other) for a in self.al)

    def conv(self, *args, **kargs):
        return self.new(a.conv(*args, **kargs) for a in self.al)

    def conv1d(self, *args, **kargs):
        return self.new(a.conv1d(*args, **kargs) for a in self.al)

    def conv2d(self, *args, **kargs):
        return self.new(a.conv2d(*args, **kargs) for a in self.al)

    def conv3d(self, *args, **kargs):
        return self.new(a.conv3d(*args, **kargs) for a in self.al)

    def max_pool2d(self, *args, **kargs):
        return self.new(a.max_pool2d(*args, **kargs) for a in self.al)

    def avg_pool2d(self, *args, **kargs):
        return self.new(a.avg_pool2d(*args, **kargs) for a in self.al)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return self.new(a.adaptive_avg_pool2d(*args, **kargs) for a in self.al)

    def unsqueeze(self, *args, **kargs):
        return self.new(a.unsqueeze(*args, **kargs) for a in self.al)

    def squeeze(self, *args, **kargs):
        return self.new(a.squeeze(*args, **kargs) for a in self.al)

    def view(self, *args, **kargs):
        return self.new(a.view(*args, **kargs) for a in self.al)

    def gather(self, *args, **kargs):
        return self.new(a.gather(*args, **kargs) for a in self.al)

    def sum(self, *args, **kargs):
        return self.new(a.sum(*args, **kargs) for a in self.al)

    def double(self):
        return self.new(a.double() for a in self.al)

    def float(self):
        return self.new(a.float() for a in self.al)

    def to_dtype(self):
        return self.new(a.to_dtype() for a in self.al)

    def vanillaTensorPart(self):
        return self.al[0].vanillaTensorPart()

    def center(self):
        return self.new(a.center() for a in self.al)

    def ub(self):
        return self.new(a.ub() for a in self.al)

    def lb(self):
        return self.new(a.lb() for a in self.al)

    def relu(self):
        return self.new(a.relu() for a in self.al)

    def splitRelu(self, *args, **kargs):
        return self.new(a.splitRelu(*args, **kargs) for a in self.al)

    def softplus(self):
        return self.new(a.softplus() for a in self.al)

    def elu(self):
        return self.new(a.elu() for a in self.al)

    def selu(self):
        return self.new(a.selu() for a in self.al)

    def sigm(self):
        return self.new(a.sigm() for a in self.al)

    def cat(self, other, *args, **kargs):
        return self.new(a.cat(o, *args, **kargs) for a, o in zip(self.al, other.al))

    def split(self, *args, **kargs):
        # transpose so the i-th result holds the i-th piece of every member domain
        return [self.new(z) for z in zip(*(a.split(*args, **kargs) for a in self.al))]

    def size(self):
        return self.al[0].size()

    def loss(self, *args, **kargs):
        return sum(a.loss(*args, **kargs) for a in self.al)

    def deep_loss(self, *args, **kargs):
        return sum(a.deep_loss(*args, **kargs) for a in self.al)

    def checkSizes(self):
        for a in self.al:
            a.checkSizes()
        return self
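
# Commentary: TaggedDomain pairs an abstract element with a tag object and forwards every
# transformer to the wrapped element; the tag is consulted only in loss(), where it decides
# how the element's loss is computed (presumably for weighting members of a combined domain).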
class TaggedDomain(object):

    def __init__(self, a, tag=None):
        self.tag = tag
        self.a = a

    def isSafe(self, *args, **kargs):
        return self.a.isSafe(*args, **kargs)

    def isPoint(self):
        return self.a.isPoint()

    def labels(self):
        raise Exception("Domain Not Suitable For Testing")

    def __mul__(self, flt):
        return TaggedDomain(self.a.__mul__(flt), self.tag)

    def __truediv__(self, flt):
        return TaggedDomain(self.a.__truediv__(flt), self.tag)

    def __add__(self, other):
        if isinstance(other, TaggedDomain):
            return TaggedDomain(self.a.__add__(other.a), self.tag)
        else:
            return TaggedDomain(self.a.__add__(other), self.tag)

    def addPar(self, a, b):
        return TaggedDomain(self.a.addPar(a.a, b.a), self.tag)

    def __sub__(self, other):
        if isinstance(other, TaggedDomain):
            return TaggedDomain(self.a.__sub__(other.a), self.tag)
        else:
            return TaggedDomain(self.a.__sub__(other), self.tag)

    def bmm(self, other):
        return TaggedDomain(self.a.bmm(other), self.tag)

    def matmul(self, other):
        return TaggedDomain(self.a.matmul(other), self.tag)

    def conv(self, *args, **kargs):
        return TaggedDomain(self.a.conv(*args, **kargs), self.tag)

    def conv1d(self, *args, **kargs):
        return TaggedDomain(self.a.conv1d(*args, **kargs), self.tag)

    def conv2d(self, *args, **kargs):
        return TaggedDomain(self.a.conv2d(*args, **kargs), self.tag)

    def conv3d(self, *args, **kargs):
        return TaggedDomain(self.a.conv3d(*args, **kargs), self.tag)

    def max_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.max_pool2d(*args, **kargs), self.tag)

    def avg_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.avg_pool2d(*args, **kargs), self.tag)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return TaggedDomain(self.a.adaptive_avg_pool2d(*args, **kargs), self.tag)

    def unsqueeze(self, *args, **kargs):
        return TaggedDomain(self.a.unsqueeze(*args, **kargs), self.tag)

    def squeeze(self, *args, **kargs):
        return TaggedDomain(self.a.squeeze(*args, **kargs), self.tag)

    def abstractApplyLeaf(self, *args, **kargs):
        return TaggedDomain(self.a.abstractApplyLeaf(*args, **kargs), self.tag)

    def view(self, *args, **kargs):
        return TaggedDomain(self.a.view(*args, **kargs), self.tag)

    def gather(self, *args, **kargs):
        return TaggedDomain(self.a.gather(*args, **kargs), self.tag)

    def sum(self, *args, **kargs):
        return TaggedDomain(self.a.sum(*args, **kargs), self.tag)

    def double(self):
        return TaggedDomain(self.a.double(), self.tag)

    def float(self):
        return TaggedDomain(self.a.float(), self.tag)

    def to_dtype(self):
        return TaggedDomain(self.a.to_dtype(), self.tag)

    def vanillaTensorPart(self):
        return self.a.vanillaTensorPart()

    def center(self):
        return TaggedDomain(self.a.center(), self.tag)

    def ub(self):
        return TaggedDomain(self.a.ub(), self.tag)

    def lb(self):
        return TaggedDomain(self.a.lb(), self.tag)

    def relu(self):
        return TaggedDomain(self.a.relu(), self.tag)

    def splitRelu(self, *args, **kargs):
        return TaggedDomain(self.a.splitRelu(*args, **kargs), self.tag)

    def diameter(self):
        return self.a.diameter()

    def softplus(self):
        return TaggedDomain(self.a.softplus(), self.tag)

    def elu(self):
        return TaggedDomain(self.a.elu(), self.tag)

    def selu(self):
        return TaggedDomain(self.a.selu(), self.tag)

    def sigm(self):
        return TaggedDomain(self.a.sigm(), self.tag)

    def cat(self, other, *args, **kargs):
        return TaggedDomain(self.a.cat(other.a, *args, **kargs), self.tag)

    def split(self, *args, **kargs):
        return [TaggedDomain(z, self.tag) for z in self.a.split(*args, **kargs)]

    def size(self):
        return self.a.size()

    def loss(self, *args, **kargs):
        return self.tag.loss(self.a, *args, **kargs)

    def deep_loss(self, *args, **kargs):
        return self.a.deep_loss(*args, **kargs)

    def checkSizes(self):
        self.a.checkSizes()
        return self

    def merge(self, other, ref=None):
        return TaggedDomain(self.a.merge(other.a, ref=None if ref is None else ref.a), self.tag)
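
# Rough end-to-end usage sketch (kept as a comment; assumes the helpers module h supplies the
# tensor patches used above, such as .to_dtype() and preDomRes, and W1/b1/W2/b2 are
# hypothetical weight and bias tensors of matching shapes):
#
#     x = torch.rand(1, 784)                                 # a concrete input
#     eps = 0.01
#     z = HybridZonotope(x, h.ones(x.shape) * eps, None)     # L-inf ball of radius eps
#     z = (z.matmul(W1) + b1).relu()                         # one affine + ReLU layer
#     z = z.matmul(W2) + b2
#     print(z.isSafe(torch.tensor([3])))                     # 1 iff class 3 provably dominates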