khulnasoft commited on
Commit
746c674
·
verified ·
1 Parent(s): cd6bb5d

Upload 16 files

Browse files
Files changed (16) hide show
  1. AllExperimentsSerial.sh +33 -0
  2. LICENSE +21 -0
  3. README.md +249 -0
  4. __init__.py +8 -0
  5. __main__.py +561 -0
  6. ai.py +1064 -0
  7. components.py +951 -0
  8. convert.py +144 -0
  9. goals.py +529 -0
  10. helpers.py +489 -0
  11. losses.py +60 -0
  12. media/overview.png +0 -0
  13. media/resnetTinyFewCombo.png +0 -0
  14. models.py +120 -0
  15. requirements.txt +6 -0
  16. scheduling.py +120 -0
AllExperimentsSerial.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# AllExperimentsSerial.sh
# Runs, one after another, the training experiments from the 2019 ArXiv paper
# (tables 4 and 5, figure 5).  "$1" is forwarded verbatim to every run and is
# expected to carry the shared testing/output flags, e.g.:
#   ./AllExperimentsSerial.sh "-t MI_FGSM(k=20,r=2) -t HBox --test-size 10000 --epochs 420 --out all_experiments"

# --- ResNetTiny ---
# Baseline
python . -D CIFAR10 -n ResNetTiny -d "LinMix(a=Point(), b=Box(w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# InSamp
python . -D CIFAR10 -n ResNetTiny -d "LinMix(a=Point(), b=InSamp(Lin(0,1,150,10)), bw=Lin(0,0.5, 150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# InSampLPA
python . -D CIFAR10 -n ResNetTiny -d "LinMix(a=Point(), b=InSamp(Lin(0,1,150,20), w=Lin(0,0.031373, 150, 20)), bw=Lin(0,0.5, 150, 20))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# Adv_{1}InSampLPA
python . -D CIFAR10 -n ResNetTiny -d "LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=1), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# Adv_{3}InSampLPA
python . -D CIFAR10 -n ResNetTiny -d "LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=3), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1


# --- ResNetTiny_FewCombo: same five schemes on the abstract-layer variant ---
# Baseline
python . -D CIFAR10 -n ResNetTiny_FewCombo -d "LinMix(a=Point(), b=Box(w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# InSamp
python . -D CIFAR10 -n ResNetTiny_FewCombo -d "LinMix(a=Point(), b=InSamp(Lin(0,1,150,10)), bw=Lin(0,0.5, 150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# InSampLPA
python . -D CIFAR10 -n ResNetTiny_FewCombo -d "LinMix(a=Point(), b=InSamp(Lin(0,1,150,20), w=Lin(0,0.031373, 150, 20)), bw=Lin(0,0.5, 150, 20))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# Adv_{1}InSampLPA
python . -D CIFAR10 -n ResNetTiny_FewCombo -d "LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=1), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1
# Adv_{3}InSampLPA
python . -D CIFAR10 -n ResNetTiny_FewCombo -d "LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=3), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1

# --- ResNetTiny_ManyFixed ---
# Adv_{1}InSampLPA
python . -D CIFAR10 -n ResNetTiny_ManyFixed -d "LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=1), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))" --batch-size 50 --width 0.031373 --lr 0.001 --normalize-layer True --clip-norm False --lr-multistep $1

# --- SkipNet18: larger nets use SGD with a custom multistep schedule ---
# InSamp_{18}
python . -D CIFAR10 -n SkipNet18 -d "LinMix(a=Point(), b=InSamp(Lin(0,1,200,40)), bw=Lin(0,0.5,200,40))" -t "MI_FGSM(k=20,r=2)" --batch-size 100 --save-freq 2 --width 0.031373 --lr 0.1 --normalize-layer True --clip-norm False --lr-multistep --sgd --custom-schedule "[10,20,250,300,350]" $1
# Adv_{5}InSamp_{18}
python . -D CIFAR10 -n SkipNet18 -d "LinMix(a=IFGSM(w=Lin(0,0.031373,20,20)), b=InSamp(Lin(0,1,200,40)), bw=Lin(0,0.5,200,40))" -t "MI_FGSM(k=20,r=2)" --batch-size 100 --width 0.031373 --lr 0.1 --normalize-layer True --clip-norm False --lr-multistep --sgd --custom-schedule "[10,20,250,300,350]" $1
# InSamp_{18} Combo
python . -D CIFAR10 -n SkipNet18_Combo -d "LinMix(b=InSamp(Lin(0,1,200,40)), bw=Lin(0,0.5, 200, 40))" --batch-size 100 --width 0.031373 --lr 0.1 --normalize-layer True --clip-norm False --sgd --lr-multistep --custom-schedule "[10,20,250,300,350]" $1
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2018 SRI Lab, ETH Zurich
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DiffAI v3 <a href="https://www.sri.inf.ethz.ch/"><img width="100" alt="portfolio_view" align="right" src="http://safeai.ethz.ch/img/sri-logo.svg"></a>
2
+ =============================================================================================================
3
+
4
+ ![High Level](https://raw.githubusercontent.com/eth-sri/diffai/v3/media/overview.png)
5
+
6
+
7
+
8
+ DiffAI is a system for training neural networks to be provably robust and for proving that they are robust.
9
+
10
+ Background
11
+ ----------
12
+
13
+ By now, it is well known that otherwise working networks can be tricked by clever attacks. For example [Goodfellow et al.](https://arxiv.org/abs/1412.6572) demonstrated a network with high classification accuracy which classified one image of a panda correctly, and a seemingly identical attack picture
14
+ incorrectly. Many defenses against this type of attack have been produced, but very few produce networks for which *provably* verifying the safety of a prediction is feasible.
15
+
16
+ Abstract Interpretation is a technique for verifying properties of programs by soundly overapproximating their behavior. When applied to neural networks, an infinite set (a ball) of possible inputs is passed to an approximating "abstract" network
17
+ to produce a superset of the possible outputs from the actual network. Provided an appropriate representation for these sets, demonstrating that the network classifies everything in the ball correctly becomes a simple task. The method used to represent these sets is the abstract domain, and the specific approximations are the abstract transformers.
18
+
19
+ In DiffAI, the entire abstract interpretation process is programmed using PyTorch so that it is differentiable and can be run on the GPU,
20
+ and a loss function is crafted so that low values correspond to inputs which can be proved safe (robust).
21
+
22
+ What's New In v3?
23
+ ----------------
24
+
25
+ * Abstract Networks: one can now customize the handling of the domains on a per-layer basis.
26
+ * Training DSL: A DSL has been exposed to allow for custom training regimens with complex parameter scheduling.
27
+ * Cross Loss: The box goal now uses the cross entropy style loss by default as suggested by [Gowal et al. 2019](https://arxiv.org/abs/1810.12715)
28
+ * Conversion to Onyx: We can now export to the onyx format, and can export the abstract network itself to onyx (so that one can run abstract analysis or training using tensorflow for example).
29
+
30
+ Requirements
31
+ ------------
32
+
33
+ python 3.6.7, and virtualenv, torch 0.4.1.
34
+
35
+ Recommended Setup
36
+ -----------------
37
+
38
+ ```
39
+ $ git clone https://github.com/eth-sri/DiffAI.git
40
+ $ cd DiffAI
41
+ $ virtualenv pytorch --python python3.6
42
+ $ source pytorch/bin/activate
43
+ (pytorch) $ pip install -r requirements.txt
44
+ ```
45
+
46
+ Note: you need to activate your virtualenv every time you start a new shell.
47
+
48
+ Getting Started
49
+ ---------------
50
+
51
+ DiffAI can be run as a standalone program. To see a list of arguments, type
52
+
53
+ ```
54
+ (pytorch) $ python . --help
55
+ ```
56
+
57
+ At the minimum, DiffAI expects at least one domain to train with and one domain to test with, and a network with which to test. For example, to train with the Box domain, baseline training (Point) and test against the FGSM attack and the ZSwitch domain with a simple feed forward network on the MNIST dataset (default, if none provided), you would type:
58
+
59
+ ```
60
+ (pytorch) $ python . -d "Point()" -d "Box()" -t "PGD()" -t "ZSwitch()" -n ffnn
61
+ ```
62
+
63
+ Unless otherwise specified by "--out", the output is logged to the folder "out/".
64
+ In the folder corresponding to the experiment that has been run, one can find the saved configuration options in
65
+ "config.txt", and a pickled net which is saved every 10 epochs (provided that testing is set to happen every 10th epoch).
66
+
67
+ To load a saved model, use "--test" as per the example:
68
+
69
+ ```
70
+ (pytorch) $ alias test-diffai="python . -d Point --epochs 1 --dont-write --test-freq 1"
71
+ (pytorch) $ test-diffai -t Box --update-test-net-name convBig --test PATHTOSAVED_CONVBIG.pynet --width 0.1 --test-size 500 --test-batch-size 500
72
+ ```
73
+
74
+ Note that "--update-test-net-name" will create a new model based on convBig and try to use the weights in the pickled PATHTOSAVED_CONVBIG.pynet to initialize that model's weights. This is not always necessary, but is useful when the code for a model changes (in components) but does not affect the number or usage of weights, or when loading a model pickled by a cuda process into a cpu process.
75
+
76
+ The default specification type is the L_infinity Ball specified explicitly by "--spec boxSpec",
77
+ which uses an epsilon specified by "--width"
78
+
79
+ The default specification type is the L_infinity Ball specified explicitly by "--spec boxSpec",
80
+ which uses an epsilon specified by "--width"
81
+
82
+ Abstract Networks
83
+ -----------------
84
+
85
+ ![Example Abstract Net](https://raw.githubusercontent.com/eth-sri/diffai/master/media/resnetTinyFewCombo.png)
86
+
87
+ A crucial point of DiffAI v3 is that how a network is trained and abstracted should be part of the network description itself. In this release, we provide layers that allow one to alter how the abstraction works,
88
+ in addition to providing a script for converting an abstract network to onyx so that the abstract analysis might be run in tensorflow.
89
+ Below is a list of the abstract layers that we have included.
90
+
91
+ * CorrMaxPool3D
92
+ * CorrMaxPool2D
93
+ * CorrFix
94
+ * CorrMaxK
95
+ * CorrRand
96
+ * DecorrRand
97
+ * DecorrMin
98
+ * DeepLoss
99
+ * ToZono
100
+ * ToHZono
101
+ * Concretize
102
+ * CorrelateAll
103
+
104
+ Training Domain DSL
105
+ -------------------
106
+
107
+ In DiffAI v3, a dsl has been provided to specify arbitrary training domains. In particular, it is now possible to train on combinations of attacks and abstract domains on specifications defined by attacks. Specifying training domains is possible in the command line using ```-d "DOMAIN_INITIALIZATION"```. The possible combinations are the classes listed in domains.py. The same syntax is also supported for testing domains, to allow for testing robustness with different epsilon-sized attacks and specifications.
108
+
109
+ Listed below are a few examples:
110
+
111
+ * ```-t "IFGSM(k=4, w=0.1)" -t "ZNIPS(w=0.3)" ``` Will first test with the PGD attack with an epsilon=w=0.1 and, the number of iterations k=4 and step size set to w/k. It will also test with the zonotope domain using the transformer specified in our [NIPS 2018 paper](https://www.sri.inf.ethz.ch/publications/singh2018effective) with an epsilon=w=0.3.
112
+
113
+ * ```-t "PGD(r=3,k=16,restart=2, w=0.1)"``` tests on points found using PGD with a step size of r*w/k and two restarts, and an attack-generated specification.
114
+
115
+ * ```-d Point()``` is standard non-defensive training.
116
+
117
+ * ```-d "LinMix(a=IFGSM(), b=Box(), aw=1, bw=0.1)"``` trains on points produced by pgd with the default parameters listed in domains.py, and points produced using the box domain. The loss is combined linearly using the weights aw and bw and scaled by 1/(aw + bw). The epsilon used for both is the ambient epsilon specified with "--width".
118
+
119
+ * ```-d "DList((IFGSM(w=0.1),1), (Box(w=0.01),0.1), (Box(w=0.1),0.01))"``` is a generalization of the Mix domain allowing for training with arbitrarily many domains at once weighted by the given values (the resulting loss is scaled by the inverse of the sum of weights).
120
+
121
+ * ```-d "AdvDom(a=IFGSM(), b=Box())"``` trains using the Box domain, but constructs specifications as L∞ balls containing the PGD attack image and the original image "o".
122
+
123
+ * ```-d "BiAdv(a=IFGSM(), b=Box())"``` is similar, but creates specifications between the pgd attack image "a" and "o - (a - o)".
124
+
125
+ One domain we have found particularly useful for training is ```Mix(a=PGD(r=3,k=16,restart=2, w=0.1), b=BiAdv(a=IFGSM(k=5, w=0.05)), bw=0.1)```.
126
+
127
+ While the above domains are all deterministic (up to gpu error and shuffling orders), we have also implemented nondeterministic training domains:
128
+
129
+ * ```-d "Coin(a=IFGSM(), b=Box(), aw=1, bw=0.1)"``` is like Mix, but chooses which domain to train a batch with by the probabilities determined by aw / (aw + bw) and bw / (aw + bw).
130
+
131
+ * ```-d "DProb((IFGSM(w=0.1),1), (Box(w=0.01),0.1), (Box(w=0.1),0.01))"``` is to Coin what DList is to Mix.
132
+
133
+ * ```-d AdvDom(a=IFGSM(), b=DList((PointB(),1), (PointA(), 1), (Box(), 0.2)))``` can be used to share attack images between multiple training types. Here an attack image "m" is found using PGD, then both the original image "o" and the attack image "m" are passed to DList which trains using three different ways: PointA trains with "o", PointB trains with "m", and Box trains on the box produced between them. This can also be used with Mix.
134
+
135
+ * ```-d Normal(w=0.3)``` trains using images sampled from a normal distribution around the provided image using standard deviation w.
136
+
137
+ * ```-d NormalAdv(a=IFGSM(), w=0.3)``` trains using PGD (but this could be an abstract domain) where perturbations are constrained to a box determined by a normal distribution around the original image with standard deviation w.
138
+
139
+ * ```-d InSamp(0.2, w=0.1)``` uses Inclusion sampling as defined in the ArXiv paper.
140
+
141
+ There are more domains implemented than listed here, and of course more interesting combinations are possible. Please look carefully at domains.py for default values and further options.
142
+
143
+
144
+ Parameter Scheduling DSL
145
+ ------------------------
146
+
147
+ In place of many constants, you can use the following scheduling devices.
148
+
149
+ * ```Lin(s,e,t,i)``` Linearly interpolates between s and e over t epochs, using s for the first i epochs.
150
+
151
+ * ```Until(t,a,b)``` Uses a for the first t epochs, then switches to using b (telling b the current epoch starting from 0 at epoch t).
152
+
153
+ Suggested Training
154
+ ------------------
155
+
156
+ ```LinMix(a=IFGSM(k=2), b=InSamp(Lin(0,1,150,10)), bw = Lin(0,0.5,150,10))``` is a training goal that appears to work particularly well for CIFAR10 networks.
157
+
158
+ Contents
159
+ --------
160
+
161
+ * components.py: A high level neural network library for composable layers and operations
162
+ * goals.py: The DSL for specifying training losses and domains, and attacks which can be used as a drop in replacement for pytorch tensors in any model built with components from components.py
163
+ * scheduling.py: The DSL for specifying parameter scheduling.
164
+ * models.py: A repository of models to train with which are used in the paper.
165
+ * convert.py: A utility for converting a model with a training or testing domain (goal) into an onyx network. This is useful for exporting DiffAI abstractions to tensorflow.
166
+ * \_\_main\_\_.py: The entry point to run the experiments.
167
+ * helpers.py: Assorted helper functions. Does some monkeypatching, so you might want to be careful importing our library into your project.
168
+ * AllExperimentsSerial.sh: A script which runs the training experiments from the 2019 ArXiv paper from table 4 and 5 and figure 5.
169
+
170
+ Notes
171
+ -----
172
+
173
+ Not all of the datasets listed in the help message are supported. Supported datasets are:
174
+
175
+ * CIFAR10
176
+ * CIFAR100
177
+ * MNIST
178
+ * SVHN
179
+ * FashionMNIST
180
+
181
+ Unsupported datasets will not necessarily throw errors.
182
+
183
+ Reproducing Results
184
+ -------------------
185
+
186
+ [Download Defended Networks](https://www.dropbox.com/sh/66obogmvih79e3k/AACe-tkKGvIK0Z--2tk2alZaa?dl=0)
187
+
188
+ All training runs from the paper can be reproduced by the following command, in the same order as Table 6 in the appendix.
189
+
190
+ ```
191
+ ./AllExperimentsSerial.sh "-t MI_FGSM(k=20,r=2) -t HBox --test-size 10000 --test-batch-size 200 --test-freq 400 --save-freq 1 --epochs 420 --out all_experiments --write-first True --test-first False"
192
+ ```
193
+
194
+ The training schemes can be written as follows (the names differ slightly from the presentation in the paper):
195
+
196
+ * Baseline: LinMix(a=Point(), b=Box(w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))
197
+ * InSamp: LinMix(a=Point(), b=InSamp(Lin(0,1,150,10)), bw=Lin(0,0.5, 150,10))
198
+ * InSampLPA: LinMix(a=Point(), b=InSamp(Lin(0,1,150,20), w=Lin(0,0.031373, 150, 20)), bw=Lin(0,0.5, 150, 20))
199
+ * Adv_{1}ISLPA: LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=1), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))
200
+ * Adv_{3}ISLPA: LinMix(a=IFGSM(w=Lin(0,0.031373,20,20), k=3), b=InSamp(Lin(0,1,150,10), w=Lin(0,0.031373,150,10)), bw=Lin(0,0.5,150,10))
201
+ * Baseline_{18}: LinMix(a=Point(), b=InSamp(Lin(0,1,200,40)), bw=Lin(0,0.5,200,40))
202
+ * InSamp_{18}: LinMix(a=IFGSM(w=Lin(0,0.031373,20,20)), b=InSamp(Lin(0,1,200,40)), bw=Lin(0,0.5,200,40))
203
+ * Adv_{5}IS_{18}: LinMix(b=InSamp(Lin(0,1,200,40)), bw=Lin(0,0.5, 200, 40))
204
+ * BiAdv_L: LinMix(a=IFGSM(k=2), b=BiAdv(a=IFGSM(k=3, w=Lin(0,0.031373, 150, 30)), b=Box()), bw=Lin(0,0.6, 200, 30))
205
+
206
+ To test a saved network as in the paper, use the following command:
207
+
208
+ ```
209
+ python . -D CIFAR10 -n ResNetLarge_LargeCombo -d Point --width 0.031373 --normalize-layer True --clip-norm False -t 'MI_FGSM(k=20,r=2)' -t HBox --test-size 10000 --test-batch-size 200 --epochs 1 --test NAMEOFSAVEDNET.pynet
210
+ ```
211
+
212
+ About
213
+ -----
214
+
215
+ * DiffAI is now on version 3.0.
216
+ * This repository contains the code used for the experiments in the [2019 ArXiV Paper](https://arxiv.org/abs/1903.12519).
217
+ * To reproduce the experiments from the 2018 ICML paper [Differentiable Abstract Interpretation for Provably Robust Neural Networks](https://files.sri.inf.ethz.ch/website/papers/icml18-diffai.pdf), one must download the [source code for Version 1.0](https://github.com/eth-sri/diffai/releases/tag/v1.0)
218
+ * Further information and related projects can be found at [the SafeAI Project](http://safeai.ethz.ch/)
219
+ * [High level slides](https://files.sri.inf.ethz.ch/website/slides/mirman2018differentiable.pdf)
220
+
221
+ Citing This Framework
222
+ ---------------------
223
+
224
+ ```
225
+ @inproceedings{
226
+ title={Differentiable Abstract Interpretation for Provably Robust Neural Networks},
227
+ author={Mirman, Matthew and Gehr, Timon and Vechev, Martin},
228
+ booktitle={International Conference on Machine Learning (ICML)},
229
+ year={2018},
230
+ url={https://www.icml.cc/Conferences/2018/Schedule?showEvent=2477},
231
+ }
232
+ ```
233
+
234
+ Contributors
235
+ ------------
236
+
237
+ * [Matthew Mirman](https://www.mirman.com) - [email protected]
238
+ * [Gagandeep Singh](https://www.sri.inf.ethz.ch/people/gagandeep) - [email protected]
239
+ * [Timon Gehr](https://www.sri.inf.ethz.ch/tg.php) - [email protected]
240
+ * Marc Fischer - [email protected]
241
+ * [Martin Vechev](https://www.sri.inf.ethz.ch/vechev.php) - [email protected]
242
+
243
+
244
+
245
+ License and Copyright
246
+ ---------------------
247
+
248
+ * Copyright (c) 2018 [Secure, Reliable, and Intelligent Systems Lab (SRI), ETH Zurich](https://www.sri.inf.ethz.ch/)
249
+ * Licensed under the [MIT License](https://opensource.org/licenses/MIT)
__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
import sys
import os

# Absolute directory containing this package, with "~" expanded and symlinks
# resolved, so sibling modules import cleanly regardless of the caller's cwd.
_this_file = os.path.expanduser(__file__)
_resolved = os.path.realpath(os.path.join(os.getcwd(), _this_file))
SCRIPT_DIR = os.path.dirname(_resolved)
print(SCRIPT_DIR)
sys.path.append(SCRIPT_DIR)
__main__.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import future
import builtins
import past
import six
import copy

from timeit import default_timer as timer
from datetime import datetime
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets
from torch.utils.data import Dataset
import decimal
import torch.onnx


import inspect
from inspect import getargspec
import os
import helpers as h
from helpers import Timer
import copy  # NOTE(review): duplicate of the `import copy` above; harmless.
import random

from components import *
import models

import goals
import scheduling

# Star imports bring the goal/domain classes (Point, Box, IFGSM, ...) and the
# scheduling DSL (Lin, Until, ...) into this module's namespace; eval'd -d/-t
# strings and the Top class below rely on them being visible here.
from goals import *
from scheduling import *

import math

import warnings
from torch.serialization import SourceChangeWarning

# Every goal class deriving from goals.Point, i.e. concrete-point training
# goals as opposed to true abstract domains.
POINT_DOMAINS = [m for m in h.getMethods(goals) if issubclass(m, goals.Point)]
# Domains for which a "curve" spec degenerates into a plain box spec
# (see Top.curveSpec).  NOTE(review): "SYMETRIC" is a misspelling of
# "SYMMETRIC", kept as-is since other code may reference this name.
SYMETRIC_DOMAINS = [goals.Box] + POINT_DOMAINS


# The -D/--dataset choices are built by introspecting torchvision.datasets
# for Dataset subclasses; overwriting Imagenet12 with None keeps it out of
# that list (None is not a class, so inspect.isclass filters it).
datasets.Imagenet12 = None
48
class Top(nn.Module):
    """Wrapper around a concrete network that builds the abstract/attack
    specifications ("specs") used for training and evaluation.

    Args:
        args: parsed command-line namespace; width, spec, sub_batch_size,
            curve_width and regularize are read here.
        net:  the underlying network (a components.py model).
        ty:   the training domain/goal used to build specs (default Point).
    """

    def __init__(self, args, net, ty = Point):
        super(Top, self).__init__()
        self.net = net
        self.ty = ty
        self.w = args.width                      # ambient epsilon for box specs
        self.global_num = 0                      # total examples seen so far
        self.getSpec = getattr(self, args.spec)  # bound spec builder, e.g. boxSpec
        self.sub_batch_size = args.sub_batch_size
        self.curve_width = args.curve_width
        self.regularize = args.regularize

        # Running average of training speed; updated via addSpeed.
        self.speedCount = 0
        self.speed = 0.0

    def addSpeed(self, s):
        """Fold the speed sample `s` into the running average."""
        self.speed = (s + self.speed * self.speedCount) / (self.speedCount + 1)
        self.speedCount += 1

    def forward(self, x):
        return self.net(x)

    def clip_norm(self):
        self.net.clip_norm()

    def boxSpec(self, x, target, **kargs):
        """Return [(spec, target)]: one epsilon-box (width self.w) around x
        in the training domain self.ty."""
        return [(self.ty.box(x, w = self.w, model=self, target=target, untargeted=True, **kargs).to_dtype(), target)]

    def curveSpec(self, x, target, **kargs):
        """Build line ("curve") specs between each example and the closest
        (L1 distance) other example in the batch with the same label,
        chunked into sub-batches of size self.sub_batch_size.

        Falls back to boxSpec for symmetric (point-like) domains, where a
        curve spec carries no extra information.
        """
        if self.ty.__class__ in SYMETRIC_DOMAINS:
            return self.boxSpec(x, target, **kargs)

        batch_size = x.size()[0]

        newTargs = [ None for i in range(batch_size) ]
        newSpecs = [ None for i in range(batch_size) ]
        bestSpecs = [ None for i in range(batch_size) ]

        for i in range(batch_size):
            newTarg = target[i]
            newTargs[i] = newTarg
            newSpec = x[i]

            best_x = newSpec
            best_dist = float("inf")
            for j in range(batch_size):
                potTarg = target[j]
                potSpec = x[j]
                # Only consider *other* examples with the same label.
                if (not newTarg.data.equal(potTarg.data)) or i == j:
                    continue
                curr_dist = (newSpec - potSpec).norm(1).item() # must experiment with the type of norm here
                if curr_dist <= best_dist:
                    best_x = potSpec
                    # BUGFIX: best_dist was never updated, so best_x ended up
                    # being the *last* same-label example instead of the nearest.
                    best_dist = curr_dist

            newSpecs[i] = newSpec
            bestSpecs[i] = best_x

        new_batch_size = self.sub_batch_size
        batchedTargs = h.chunks(newTargs, new_batch_size)
        batchedSpecs = h.chunks(newSpecs, new_batch_size)
        batchedBest = h.chunks(bestSpecs, new_batch_size)

        def batch(t, s, b):
            # One sub-batch: targets t, start points s, matched endpoints b.
            t = h.lten(t)
            s = torch.stack(s)
            b = torch.stack(b)

            if h.use_cuda:
                # BUGFIX: Tensor.cuda() is not in-place; it returns a copy
                # which must be rebound, otherwise the specs stay on the CPU.
                t = t.cuda()
                s = s.cuda()
                b = b.cuda()

            m = self.ty.line(s, b, w = self.curve_width, **kargs)
            return (m, t)

        return [batch(t, s, b) for t, s, b in zip(batchedTargs, batchedSpecs, batchedBest)]

    def regLoss(self):
        """L2 regularization term scaled by --regularize (0 when disabled)."""
        if self.regularize is None or self.regularize <= 0.0:
            return 0
        r = self.net.regularize(2)
        return self.regularize * r

    def aiLoss(self, dom, target, **args):
        """Abstract-interpretation loss: regularization plus the domain loss
        of the network's output on the abstract input `dom`."""
        r = self(dom)
        return self.regLoss() + r.loss(target = target, **args)

    def printNet(self, f):
        self.net.printNet(f)
141
+
142
+
143
# Training settings.
# NOTE: parse_args() runs at import time; this module is meant to be executed
# as a script ("python .").
parser = argparse.ArgumentParser(description='PyTorch DiffAI Example', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=10, metavar='N', help='input batch size for training')
parser.add_argument('--test-first', type=h.str2bool, nargs='?', const=True, default=True, help='test first')
parser.add_argument('--test-freq', type=int, default=1, metavar='N', help='number of epochs to skip before testing')
parser.add_argument('--test-batch-size', type=int, default=10, metavar='N', help='input batch size for testing')
parser.add_argument('--sub-batch-size', type=int, default=3, metavar='N', help='input batch size for curve specs')

parser.add_argument('--custom-schedule', type=str, default="", metavar='net', help='Learning rate scheduling for lr-multistep. Defaults to [200,250,300] for CIFAR10 and [15,25] for everything else.')

parser.add_argument('--test', type=str, default=None, metavar='net', help='Saved net to use, in addition to any other nets you specify with -n')
parser.add_argument('--update-test-net', type=h.str2bool, nargs='?', const=True, default=False, help="should update test net")

parser.add_argument('--sgd',type=h.str2bool, nargs='?', const=True, default=False, help="use sgd instead of adam")
parser.add_argument('--onyx', type=h.str2bool, nargs='?', const=True, default=False, help="should output onyx")
parser.add_argument('--save-dot-net', type=h.str2bool, nargs='?', const=True, default=False, help="should output in .net")
parser.add_argument('--update-test-net-name', type=str, choices = h.getMethodNames(models), default=None, help="update test net name")

parser.add_argument('--normalize-layer', type=h.str2bool, nargs='?', const=True, default=True, help="should include a training set specific normalization layer")
parser.add_argument('--clip-norm', type=h.str2bool, nargs='?', const=True, default=False, help="should clip the normal and use normal decomposition for weights")

parser.add_argument('--epochs', type=int, default=1000, metavar='N', help='number of epochs to train')
parser.add_argument('--log-freq', type=int, default=10, metavar='N', help='The frequency with which log statistics are printed')
parser.add_argument('--save-freq', type=int, default=1, metavar='N', help='The frequency with which nets and images are saved, in terms of number of test passes')
parser.add_argument('--number-save-images', type=int, default=0, metavar='N', help='The number of images to save. Should be smaller than test-size.')

parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
parser.add_argument('--lr-multistep', type=h.str2bool, nargs='?', const=True, default=False, help='learning rate multistep scheduling')

parser.add_argument('--threshold', type=float, default=-0.01, metavar='TH', help='threshold for lr schedule')
parser.add_argument('--patience', type=int, default=0, metavar='PT', help='patience for lr schedule')
parser.add_argument('--factor', type=float, default=0.5, metavar='R', help='reduction multiplier for lr schedule')
parser.add_argument('--max-norm', type=float, default=10000, metavar='MN', help='the maximum norm allowed in weight distribution')


parser.add_argument('--curve-width', type=float, default=None, metavar='CW', help='the width of the curve spec')

parser.add_argument('--width', type=float, default=0.01, metavar='CW', help='the width of either the line or box')
# --spec choices are discovered by reflection: any Top method named *Spec
# taking exactly (self, x, target) positionally.
parser.add_argument('--spec', choices = [ x for x in dir(Top) if x[-4:] == "Spec" and len(getargspec(getattr(Top, x)).args) == 3]
                    , default="boxSpec", help='picks which spec builder function to use for training')


parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument("--use-schedule", type=h.str2bool, nargs='?',
                    const=True, default=False,
                    help="activate learning rate schedule")

# -d/-t values are DSL strings (e.g. "LinMix(a=Point(), ...)") parsed by the
# custom h.SubAct action into goal objects; both may be given multiple times.
parser.add_argument('-d', '--domain', sub_choices = None, action = h.SubAct
                    , default=[], help='picks which abstract goals to use for training', required=True)

parser.add_argument('-t', '--test-domain', sub_choices = None, action = h.SubAct
                    , default=[], help='picks which abstract goals to use for testing. Examples include ' + str(goals), required=True)

parser.add_argument('-n', '--net', choices = h.getMethodNames(models), action = 'append'
                    , default=[], help='picks which net to use for training') # one net for now

parser.add_argument('-D', '--dataset', choices = [n for (n,k) in inspect.getmembers(datasets, inspect.isclass) if issubclass(k, Dataset)]
                    , default="MNIST", help='picks which dataset to use.')

parser.add_argument('-o', '--out', default="out", help='picks which net to use for training')
parser.add_argument('--dont-write', type=h.str2bool, nargs='?', const=True, default=False, help='dont write anywhere if this flag is on')
parser.add_argument('--write-first', type=h.str2bool, nargs='?', const=True, default=False, help='write the initial net. Useful for comparing algorithms, a pain for testing.')
parser.add_argument('--test-size', type=int, default=2000, help='number of examples to test with')

parser.add_argument('-r', '--regularize', type=float, default=None, help='use regularization')


args = parser.parse_args()

# Widths used for aligning log output columns.
largest_domain = max([len(h.catStrs(d)) for d in (args.domain)] )
largest_test_domain = max([len(h.catStrs(d)) for d in (args.test_domain)] )

# Print roughly --log-freq times per epoch (assumes ~50000 training examples).
args.log_interval = int(50000 / (args.batch_size * args.log_freq))

h.max_c_for_norm = args.max_norm

# NOTE(review): on CUDA the GPU RNG is seeded with seed+1 and the CPU RNG is
# left unseeded, so CPU and GPU runs are not reproductions of one another —
# confirm this asymmetry is intentional.
if h.use_cuda:
    torch.cuda.manual_seed(1 + args.seed)
else:
    torch.manual_seed(args.seed)

train_loader = h.loadDataset(args.dataset, args.batch_size, True, False)
test_loader = h.loadDataset(args.dataset, args.test_batch_size, False, False)

input_dims = train_loader.dataset[0][0].size()
# SVHN exposes its targets as `labels` rather than `train_labels`.
# NOTE(review): `train_labels` is a legacy torchvision attribute — verify it
# still exists on the installed torchvision version.
num_classes = int(max(getattr(train_loader.dataset, 'train_labels' if args.dataset != "SVHN" else 'labels'))) + 1

print("input_dims: ", input_dims)
print("Num classes: ", num_classes)

vargs = vars(args)

total_batches_seen = 0
236
+
237
def train(epoch, models):
    """Run one epoch of training over ``train_loader`` for every model.

    Each model trains on every batch; the abstract spec(s) for a batch come
    from ``model.getSpec`` and the loss from ``model.aiLoss``.  NaN gradients
    and NaN weights are scrubbed (replaced by small random values) so a single
    bad step does not poison the whole run.
    """
    global total_batches_seen

    for model in models:
        model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        total_batches_seen += 1
        # Fractional epoch counter consumed by the scheduled (Lin/LinMix) specs.
        cur_time = float(total_batches_seen) / len(train_loader)
        if h.use_cuda:
            data, target = data.cuda(), target.cuda()

        for model in models:
            model.global_num += data.size()[0]

            timer = Timer("train a sample from " + model.name + " with " + model.ty.name, data.size()[0], False)
            lossy = 0
            with timer:
                for s in model.getSpec(data.to_dtype(), target, time = cur_time):
                    model.optimizer.zero_grad()
                    loss = model.aiLoss(*s, time = cur_time, **vargs).mean(dim=0)
                    lossy += loss.detach().item()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

                    # Scrub NaN gradients before stepping.
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            print("Such nan in vals")
                        if p is not None and p.grad is not None and torch.isnan(p.grad).any():
                            print("Such nan in postmagic")
                            stdv = 1 / math.sqrt(h.product(p.data.shape))
                            p.grad = torch.where(torch.isnan(p.grad), torch.normal(mean=h.zeros(p.grad.shape), std=stdv), p.grad)

                    model.optimizer.step()

                    # Scrub NaN weights after stepping.
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            print("Such nan in vals after grad")
                            stdv = 1 / math.sqrt(h.product(p.data.shape))
                            p.data = torch.where(torch.isnan(p.data), torch.normal(mean=h.zeros(p.data.shape), std=stdv), p.data)

                    if args.clip_norm:
                        model.clip_norm()
                    for p in model.parameters():
                        if p is not None and torch.isnan(p).any():
                            raise Exception("Such nan in vals after clip")

            model.addSpeed(timer.getUnitTime())

            if batch_idx % args.log_interval == 0:
                print(('Train Epoch {:12} {:'+ str(largest_domain) +'}: {:3} [{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}').format(
                    model.name, model.ty.name,
                    epoch,
                    batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                    model.speed,
                    lossy))
292
+
293
+
294
num_tests = 0
def test(models, epoch, f = None):
    """Evaluate every model on ``test_loader``: concrete accuracy plus, for
    each requested test domain, the fraction of examples proved robust.

    Also steps the learning-rate schedulers, periodically checkpoints each
    network, and on the first invocation dumps a few test images to disk.
    """
    global num_tests
    num_tests += 1

    class MStat:
        """Per-model accumulator (accuracy + per-domain robustness stats)."""
        def __init__(self, model):
            model.eval()
            self.model = model
            self.correct = 0

            class Stat:
                """Per-test-domain accumulator."""
                def __init__(self, d, dnm):
                    self.domain = d
                    self.name = dnm
                    self.width = 0
                    self.max_eps = None
                    self.safe = 0       # correct AND proved
                    self.proved = 0     # proved consistent with the concrete prediction
                    self.time = 0

            self.domains = [ Stat(h.parseValues(d, goals), h.catStrs(d)) for d in args.test_domain ]

    model_stats = [ MStat(m) for m in models ]

    num_its = 0
    saved_data_target = []
    for data, target in test_loader:
        if num_its >= args.test_size:
            break

        if num_tests == 1:
            # Keep the raw examples around so we can dump them below.
            saved_data_target += list(zip(list(data), list(target)))

        num_its += data.size()[0]
        if h.use_cuda:
            data, target = data.cuda().to_dtype(), target.cuda()

        for m in model_stats:
            with torch.no_grad():
                # Index of the max log-probability = predicted class.
                pred = m.model(data).vanillaTensorPart().max(1, keepdim=True)[1]
                m.correct += pred.eq(target.data.view_as(pred)).sum()

            for stat in m.domains:
                timer = Timer(shouldPrint = False)
                with timer:
                    def calcData(data, target):
                        box = stat.domain.box(data, w = m.model.w, model=m.model, untargeted = True, target=target).to_dtype()
                        with torch.no_grad():
                            bs = m.model(box)
                            org = m.model(data).vanillaTensorPart().max(1,keepdim=True)[1]
                            stat.width += bs.diameter().sum().item() # sum up batch loss
                            stat.proved += bs.isSafe(org).sum().item()
                            stat.safe += bs.isSafe(target).sum().item()
                            # stat.max_eps += 0 # TODO: calculate max_eps

                    # Big nets in non-symmetric domains are verified one
                    # example at a time to bound memory use.
                    if m.model.net.neuronCount() < 5000 or stat.domain in SYMETRIC_DOMAINS:
                        calcData(data, target)
                    else:
                        for d, t in zip(data, target):
                            calcData(d.unsqueeze(0), t.unsqueeze(0))
                stat.time += timer.getUnitTime()

    l = num_its # len(test_loader.dataset)
    for m in model_stats:
        if args.lr_multistep:
            m.model.lrschedule.step()

        pr_corr = float(m.correct) / float(l)
        if args.use_schedule:
            m.model.lrschedule.step(1 - pr_corr)

        h.printBoth(('Test: {:12} trained with {:'+ str(largest_domain) +'} - Avg sec/ex {:1.12f}, Accuracy: {}/{} ({:3.1f}%)').format(
            m.model.name, m.model.ty.name,
            m.model.speed,
            m.correct, l, 100. * pr_corr), f = f)

        model_stat_rec = ""
        for stat in m.domains:
            pr_safe = stat.safe / l
            pr_proved = stat.proved / l
            pr_corr_given_proved = pr_safe / pr_proved if pr_proved > 0 else 0.0
            h.printBoth(("\t{:" + str(largest_test_domain)+"} - Width: {:<36.16f} Pr[Proved]={:<1.3f} Pr[Corr and Proved]={:<1.3f} Pr[Corr|Proved]={:<1.3f} {}Time = {:<7.5f}" ).format(
                stat.name,
                stat.width / l,
                pr_proved,
                pr_safe, pr_corr_given_proved,
                "AvgMaxEps: {:1.10f} ".format(stat.max_eps / l) if stat.max_eps is not None else "",
                stat.time), f = f)
            model_stat_rec += "{}_{:1.3f}_{:1.3f}_{:1.3f}__".format(stat.name, pr_proved, pr_safe, pr_corr_given_proved)

        # Checkpoint filename: sanitize the domain description.
        prepedname = m.model.ty.name.replace(" ", "_").replace(",", "").replace("(", "_").replace(")", "_").replace("=", "_")
        net_file = os.path.join(out_dir, m.model.name +"__" +prepedname + "_checkpoint_"+str(epoch)+"_with_{:1.3f}".format(pr_corr))

        h.printBoth("\tSaving netfile: {}\n".format(net_file + ".pynet"), f = f)

        if (num_tests % args.save_freq == 1 or args.save_freq == 1) and not args.dont_write and (num_tests > 1 or args.write_first):
            print("Actually Saving")
            torch.save(m.model.net, net_file + ".pynet")
            if args.save_dot_net:
                with h.mopen(args.dont_write, net_file + ".net", "w") as f2:
                    m.model.net.printNet(f2)
                    f2.close()
            if args.onyx:
                # Export a normalization-free copy in ONNX format.
                nn = copy.deepcopy(m.model.net)
                nn.remove_norm()
                torch.onnx.export(nn, h.zeros([1] + list(input_dims)), net_file + ".onyx",
                                  verbose=False, input_names=["actual_input"] + ["param"+str(i) for i in range(len(list(nn.parameters())))], output_names=["output"])

    if num_tests == 1 and not args.dont_write:
        img_dir = os.path.join(out_dir, "images")
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        for img_num, (img, target) in enumerate(saved_data_target[:args.number_save_images]):
            sz = "x".join(str(s) for s in img.size())

            img_file = os.path.join(img_dir, args.dataset + "_" + sz + "_" + str(img_num))
            if img_num == 0:
                print("Saving image to: ", img_file + ".img")
            with open(img_file + ".img", "w") as imgfile:
                flatimg = img.view(h.product(img.size()))
                for t in flatimg.cpu():
                    print(decimal.Decimal(float(t)).__format__("f"), file=imgfile)
            with open(img_file + ".class", "w") as imgfile:
                print(int(target.item()), file=imgfile)
419
+
420
def createModel(net, domain, domain_name):
    """Build a ``Top`` model around a fresh copy of ``net`` trained under ``domain``.

    ``net`` is a pair ``(prototype_with_weights, constructor)``; the prototype's
    weights are copied (cast to the working dtype) into a newly constructed net
    so each model trains independently.
    """
    net_weights, net_create = net
    domain.name = domain_name

    net = net_create()
    # Cast every saved tensor to the working dtype before loading.
    net.load_state_dict({k: v.to_dtype() for (k, v) in net_weights.state_dict().items()})

    model = Top(args, net, domain)
    if args.clip_norm:
        model.clip_norm()
    if h.use_cuda:
        model.cuda()

    if args.sgd:
        model.optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    else:
        model.optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.lr_multistep:
        # NOTE(review): eval() on --custom-schedule executes arbitrary code from
        # the command line; acceptable for a research CLI, not for untrusted input.
        model.lrschedule = optim.lr_scheduler.MultiStepLR(
            model.optimizer,
            gamma = 0.1,
            milestones = eval(args.custom_schedule) if args.custom_schedule != "" else ([200, 250, 300] if args.dataset == "CIFAR10" else [15, 25]))
    else:
        model.lrschedule = optim.lr_scheduler.ReduceLROnPlateau(
            model.optimizer,
            'min',
            patience=args.patience,
            threshold= args.threshold,
            min_lr=0.000001,
            factor=args.factor,
            verbose=True)

    net.name = net_create.__name__
    model.name = net_create.__name__

    return model
459
+
460
+ out_dir = os.path.join(args.out, args.dataset, str(args.net)[1:-1].replace(", ","_").replace("'",""),
461
+ args.spec, "width_"+str(args.width), h.file_timestamp() )
462
+
463
+ print("Saving to:", out_dir)
464
+
465
+ if not os.path.exists(out_dir) and not args.dont_write:
466
+ os.makedirs(out_dir)
467
+
468
+ print("Starting Training with:")
469
+ with h.mopen(args.dont_write, os.path.join(out_dir, "config.txt"), "w") as f:
470
+ for k in sorted(vars(args)):
471
+ h.printBoth("\t"+k+": "+str(getattr(args,k)), f = f)
472
+ print("")
473
+
474
def buildNet(n):
    """Instantiate net-constructor ``n`` for this dataset.

    Optionally prepends a per-dataset ``Normalize`` layer, infers layer shapes
    from ``input_dims``, and clips weight norms when requested.
    """
    n = n(num_classes)
    if args.normalize_layer:
        # Standard per-channel mean/std for each supported dataset.
        norm_stats = {
            "MNIST":      ([0.1307], [0.3081]),
            "CIFAR10":    ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            "CIFAR100":   ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            "SVHN":       ([0.5, 0.5, 0.5], [0.2, 0.2, 0.2]),
            "Imagenet12": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        }
        if args.dataset in norm_stats:
            mean, std = norm_stats[args.dataset]
            n = Seq(Normalize(mean, std), n)
    n = n.infer(input_dims)
    if args.clip_norm:
        n.clip_norm()
    return n
489
+
490
+ if not args.test is None:
491
+
492
+ test_name = None
493
+
494
+ def loadedNet():
495
+ if test_name is not None:
496
+ n = getattr(models,test_name)
497
+ n = buildNet(n)
498
+ if args.clip_norm:
499
+ n.clip_norm()
500
+ return n
501
+ else:
502
+ with warnings.catch_warnings():
503
+ warnings.simplefilter("ignore", SourceChangeWarning)
504
+ return torch.load(args.test)
505
+
506
+ net = loadedNet().double() if h.dtype == torch.float64 else loadedNet().float()
507
+
508
+
509
+ if args.update_test_net_name is not None:
510
+ test_name = args.update_test_net_name
511
+ elif args.update_test_net and '__name__' in dir(net):
512
+ test_name = net.__name__
513
+
514
+ if test_name is not None:
515
+ loadedNet.__name__ = test_name
516
+
517
+ nets = [ (net, loadedNet) ]
518
+
519
+ elif args.net == []:
520
+ raise Exception("Need to specify at least one net with either -n or --test")
521
+ else:
522
+ nets = []
523
+
524
+ for n in args.net:
525
+ m = getattr(models,n)
526
+ net_create = (lambda m: lambda: buildNet(m))(m) # why doesn't python do scoping right? This is a thunk. It is bad.
527
+ net_create.__name__ = n
528
+ net = buildNet(m)
529
+ net.__name__ = n
530
+ nets += [ (net, net_create) ]
531
+
532
+ print("Name: ", net_create.__name__)
533
+ print("Number of Neurons (relus): ", net.neuronCount())
534
+ print("Number of Parameters: ", sum([h.product(s.size()) for s in net.parameters()]))
535
+ print("Depth (relu layers): ", net.depth())
536
+ print()
537
+ net.showNet()
538
+ print()
539
+
540
+
541
+ if args.domain == []:
542
+ models = [ createModel(net, goals.Box(args.width), "Box") for net in nets]
543
+ else:
544
+ models = h.flat([[createModel(net, h.parseValues(d, goals, scheduling), h.catStrs(d)) for net in nets] for d in args.domain])
545
+
546
+
547
+ with h.mopen(args.dont_write, os.path.join(out_dir, "log.txt"), "w") as f:
548
+ startTime = timer()
549
+ for epoch in range(1, args.epochs + 1):
550
+ if f is not None:
551
+ f.flush()
552
+ if (epoch - 1) % args.test_freq == 0 and (epoch > 1 or args.test_first):
553
+ with Timer("test all models before epoch "+str(epoch), 1):
554
+ test(models, epoch, f)
555
+ if f is not None:
556
+ f.flush()
557
+ h.printBoth("Elapsed-Time: {:.2f}s\n".format(timer() - startTime), f = f)
558
+ if args.epochs <= args.test_freq:
559
+ break
560
+ with Timer("train all models in epoch", 1, f = f):
561
+ train(epoch, models)
ai.py ADDED
@@ -0,0 +1,1064 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import future
2
+ import builtins
3
+ import past
4
+ import six
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ import torch.autograd
11
+
12
+ from functools import reduce
13
+
14
+ try:
15
+ from . import helpers as h
16
+ except:
17
+ import helpers as h
18
+
19
+
20
+
21
def catNonNullErrors(op, ref_errs=None): # the way of things is ugly
    """Lift a binary ``op`` over two error-term tensors whose leading
    (error-term) dimensions may differ.

    The first ``sz`` terms are combined pairwise; each side's surplus terms
    are combined against zeros and all pieces concatenated along dim 0.
    ``ref_errs``, when given, fixes ``sz`` to its number of error terms
    (used for parallel computation with a reference element).
    """
    def doop(er1, er2):
        erS, erL = (er1, er2)
        sS, sL = (erS.size()[0], erL.size()[0])

        if sS == sL: # TODO: here we know we used transformers on either side which didnt introduce new error terms (this is a hack for hybrid zonotopes and doesn't work with adaptive error term adding).
            return op(erS, erL)

        sz = ref_errs.size()[0] if ref_errs is not None else min(sS, sL)

        p1 = op(erS[:sz], erL[:sz])
        erSrem = erS[sz:]
        # BUG FIX: the remainder of the *second* operand was previously taken
        # from erS (`erS[sz:]`), silently substituting the first operand's
        # surplus error terms for the second's.
        erLrem = erL[sz:]
        p2 = op(erSrem, h.zeros(erSrem.shape))
        p3 = op(h.zeros(erLrem.shape), erLrem)
        return torch.cat((p1, p2, p3), dim=0)
    return doop
41
+
42
def creluBoxy(dom):
    """Abstract ReLU for hybrid zonotopes ('boxy' variant).

    Neurons whose interval straddles zero are collapsed to a box (their error
    terms dropped into beta's place); strictly-positive neurons pass through,
    strictly-negative ones become zero.
    """
    # Interval-only fast path: no correlated error terms.
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        rad = dom.beta
        hi = F.relu(dom.head + rad)
        lo = F.relu(dom.head - rad)
        return dom.new((lo + hi) / 2, (hi - lo) / 2, None)

    # Concrete radius from the error terms (plus beta when present).
    rad = torch.sum(torch.abs(dom.errors), 0)
    if dom.beta is not None:
        rad += dom.beta

    hi = dom.head + rad
    lo = dom.head - rad

    crossing = lo.lt(0) * hi.gt(0)        # interval straddles zero
    pos = dom.head.gt(0).to_dtype()       # center is positive
    hi /= 2
    new_head = h.ifThenElse(crossing, hi, pos * dom.head)
    new_beta = h.ifThenElse(crossing, hi, pos * (dom.beta if dom.beta is not None else 0))
    new_errs = (1 - crossing.to_dtype()) * pos * dom.errors

    return dom.new(new_head, new_beta, new_errs)
69
+
70
+
71
+ def creluBoxySound(dom):
72
+ if dom.errors is None:
73
+ if dom.beta is None:
74
+ return dom.new(F.relu(dom.head), None, None)
75
+ er = dom.beta
76
+ mx = F.relu(dom.head + er)
77
+ mn = F.relu(dom.head - er)
78
+ return dom.new((mn + mx) / 2, (mx - mn) / 2 + 2e-6 , None)
79
+
80
+ aber = torch.abs(dom.errors)
81
+
82
+ sm = torch.sum(aber, 0)
83
+
84
+ if not dom.beta is None:
85
+ sm += dom.beta
86
+
87
+ mx = dom.head + sm
88
+ mn = dom.head - sm
89
+
90
+ should_box = mn.lt(0) * mx.gt(0)
91
+ gtz = dom.head.gt(0).to_dtype()
92
+ mx /= 2
93
+ newhead = h.ifThenElse(should_box, mx, gtz * dom.head)
94
+ newbeta = h.ifThenElse(should_box, mx + 2e-6, gtz * (dom.beta if not dom.beta is None else 0))
95
+ newerr = (1 - should_box.to_dtype()) * gtz * dom.errors
96
+
97
+ return dom.new(newhead, newbeta, newerr)
98
+
99
+
100
def creluSwitch(dom):
    """Abstract ReLU ('switch' variant).

    For crossing neurons, chooses per-neuron between a plain box and a
    shifted transformer that keeps the error terms, depending on whether more
    of the interval lies below or above zero.
    """
    # Interval-only fast path: no correlated error terms.
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        rad = dom.beta
        hi = F.relu(dom.head + rad)
        lo = F.relu(dom.head - rad)
        return dom.new((lo + hi) / 2, (hi - lo) / 2, None)

    # Concrete radius from the error terms (plus beta when present).
    rad = torch.sum(torch.abs(dom.errors), 0)
    if dom.beta is not None:
        rad += dom.beta

    lo = dom.head - rad
    hi = rad
    hi += dom.head

    crossing = lo.lt(0) * hi.gt(0)
    pos = dom.head.gt(0)

    lo.neg_()                    # lo now holds -lower = how far below zero
    prefer_box = lo.gt(hi)       # more mass below zero: the plain box is tighter

    lo /= 2
    new_head = h.ifThenElse(crossing, h.ifThenElse(prefer_box, hi / 2, dom.head + lo), pos.to_dtype() * dom.head)
    beta0 = dom.beta if dom.beta is not None else 0
    new_beta = h.ifThenElse(crossing, h.ifThenElse(prefer_box, hi / 2, lo + beta0), pos.to_dtype() * beta0)
    new_errs = h.ifThenElseL(crossing, 1 - prefer_box, pos).to_dtype() * dom.errors

    return dom.new(new_head, new_beta, new_errs)
133
+
134
def creluSmooth(dom):
    """Abstract ReLU that smoothly blends, per neuron, a shifted transformer
    (keeps error terms) with a plain box, weighted by how far the interval
    dips below zero; fully non-positive neurons are zeroed."""
    # Interval-only fast path: no correlated error terms.
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        rad = dom.beta
        hi = F.relu(dom.head + rad)
        lo = F.relu(dom.head - rad)
        return dom.new((lo + hi) / 2, (hi - lo) / 2, None)

    # Concrete radius from the error terms (plus beta when present).
    rad = torch.sum(torch.abs(dom.errors), 0)
    if dom.beta is not None:
        rad += dom.beta

    lo = dom.head - rad
    hi = rad
    hi += dom.head

    below = F.relu(-1 * lo)            # how far the lower bound dips under zero

    beta0 = (dom.beta if dom.beta is not None else 0)
    head_keep = dom.head + below / 2   # "shift" candidate: keeps error terms
    beta_keep = beta0 + below / 2
    errs_keep = dom.errors

    above = F.relu(hi)

    head_box = above / 2               # "box" candidate: interval only
    beta_box = head_box
    errs_box = 0

    eps = 0.0001
    t = below / (above + below + eps)  # mn.lt(0).to_dtype() * F.sigmoid(nmn - nmx)

    live = hi.gt(0).to_dtype()         # entirely non-positive neurons output 0

    new_head = live * ((1 - t) * head_keep + t * head_box)
    new_beta = live * ((1 - t) * beta_keep + t * beta_box)
    new_errs = live * ((1 - t) * errs_keep + t * errs_box)

    return dom.new(new_head, new_beta, new_errs)
178
+
179
+
180
def creluNIPS(dom):
    """Abstract ReLU using the zonotope relaxation of Singh et al. (NIPS'18):
    crossing neurons are scaled by ``lam = ub / (ub - lb)`` and shifted by
    ``mu = -lam * lb / 2``; non-crossing neurons pass through or zero out."""
    # Interval-only fast path: no correlated error terms.
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    sm = torch.sum(torch.abs(dom.errors), 0)

    if dom.beta is not None:
        sm += dom.beta

    mn = dom.head - sm
    mx = dom.head + sm

    mngz = mn >= 0.0                 # entirely non-negative: identity

    zs = h.zeros(dom.head.shape)

    diff = mx - mn

    lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs)
    mu = lam * mn * (-0.5)

    betaz = zs if dom.beta is None else dom.beta

    newhead = torch.where(mngz, dom.head, lam * dom.head + mu)
    # BUG FIX: this was `mngz += diff <= 0.0`, an in-place arithmetic add on a
    # bool tensor, which modern PyTorch rejects; logical-or is the intended
    # semantics (also treat degenerate zero-width intervals as pass-through).
    mngz = mngz | (diff <= 0.0)
    newbeta = torch.where(mngz, betaz, lam * betaz + mu)   # mu is always positive on this side
    newerr = torch.where(mngz, dom.errors, lam * dom.errors)
    return dom.new(newhead, newbeta, newerr)
213
+
214
+
215
+
216
+
217
class MaxTypes:
    """Strategies for deriving a concrete 'max' image from an abstract element
    (used e.g. to pick max-pool indices)."""

    @staticmethod
    def ub(x):
        # Concrete upper bound of the abstract element.
        return x.ub()

    @staticmethod
    def only_beta(x):
        # Box radius alone; zeros (same shape as head) when beta is absent.
        return x.head * 0 if x.beta is None else x.beta

    @staticmethod
    def head_beta(x):
        # Center plus box radius (ignores correlated error terms).
        return MaxTypes.only_beta(x) + x.head
230
+
231
+ class HybridZonotope:
232
+
233
+ def isSafe(self, target):
234
+ od,_ = torch.min(h.preDomRes(self,target).lb(), 1)
235
+ return od.gt(0.0).long()
236
+
237
+ def isPoint(self):
238
+ return False
239
+
240
+ def labels(self):
241
+ target = torch.max(self.ub(), 1)[1]
242
+ l = list(h.preDomRes(self,target).lb()[0])
243
+ return [target.item()] + [ i for i,v in zip(range(len(l)), l) if v <= 0]
244
+
245
+ def relu(self):
246
+ return self.customRelu(self)
247
+
248
+ def __init__(self, head, beta, errors, customRelu = creluBoxy, **kargs):
249
+ self.head = head
250
+ self.errors = errors
251
+ self.beta = beta
252
+ self.customRelu = creluBoxy if customRelu is None else customRelu
253
+
254
+ def new(self, *args, customRelu = None, **kargs):
255
+ return self.__class__(*args, **kargs, customRelu = self.customRelu if customRelu is None else customRelu).checkSizes()
256
+
257
+ def zono_to_hybrid(self, *args, **kargs): # we are already a hybrid zono.
258
+ return self.new(self.head, self.beta, self.errors, **kargs)
259
+
260
+ def hybrid_to_zono(self, *args, correlate=True, customRelu = None, **kargs):
261
+ beta = self.beta
262
+ errors = self.errors
263
+ if correlate and beta is not None:
264
+ batches = beta.shape[0]
265
+ num_elem = h.product(beta.shape[1:])
266
+ ei = h.getEi(batches, num_elem)
267
+
268
+ if len(beta.shape) > 2:
269
+ ei = ei.contiguous().view(num_elem, *beta.shape)
270
+ err = ei * beta
271
+ errors = torch.cat((err, errors), dim=0) if errors is not None else err
272
+ beta = None
273
+
274
+ return Zonotope(self.head, beta, errors if errors is not None else (self.beta * 0).unsqueeze(0) , customRelu = self.customRelu if customRelu is None else None)
275
+
276
+
277
+
278
+ def abstractApplyLeaf(self, foo, *args, **kargs):
279
+ return getattr(self, foo)(*args, **kargs)
280
+
281
+ def decorrelate(self, cc_indx_batch_err): # keep these errors
282
+ if self.errors is None:
283
+ return self
284
+
285
+ batch_size = self.head.shape[0]
286
+ num_error_terms = self.errors.shape[0]
287
+
288
+
289
+
290
+ beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
291
+ errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors
292
+
293
+ inds_i = torch.arange(self.head.shape[0], device=h.device).unsqueeze(1).long()
294
+ errors = errors.to_dtype().permute(1,0, *list(range(len(self.errors.shape)))[2:])
295
+
296
+ sm = errors.clone()
297
+ sm[inds_i, cc_indx_batch_err] = 0
298
+
299
+ beta = beta.to_dtype() + sm.abs().sum(dim=1)
300
+
301
+ errors = errors[inds_i, cc_indx_batch_err]
302
+ errors = errors.permute(1,0, *list(range(len(self.errors.shape)))[2:]).contiguous()
303
+ return self.new(self.head, beta, errors)
304
+
305
+ def dummyDecorrelate(self, num_decorrelate):
306
+ if num_decorrelate == 0 or self.errors is None:
307
+ return self
308
+ elif num_decorrelate >= self.errors.shape[0]:
309
+ beta = self.beta
310
+ if self.errors is not None:
311
+ errs = self.errors.abs().sum(dim=0)
312
+ if beta is None:
313
+ beta = errs
314
+ else:
315
+ beta += errs
316
+ return self.new(self.head, beta, None)
317
+ return None
318
+
319
+ def stochasticDecorrelate(self, num_decorrelate, choices = None, num_to_keep=False):
320
+ dummy = self.dummyDecorrelate(num_decorrelate)
321
+ if dummy is not None:
322
+ return dummy
323
+ num_error_terms = self.errors.shape[0]
324
+ batch_size = self.head.shape[0]
325
+
326
+ ucc_mask = h.ones([batch_size, self.errors.shape[0]]).long()
327
+ cc_indx_batch_err = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_decorrelate if num_to_keep else num_error_terms - num_decorrelate, replacement=False)) if choices is None else choices
328
+ return self.decorrelate(cc_indx_batch_err)
329
+
330
+ def decorrelateMin(self, num_decorrelate, num_to_keep=False):
331
+ dummy = self.dummyDecorrelate(num_decorrelate)
332
+ if dummy is not None:
333
+ return dummy
334
+
335
+ num_error_terms = self.errors.shape[0]
336
+ batch_size = self.head.shape[0]
337
+
338
+ error_sum_b_e = self.errors.abs().view(self.errors.shape[0], batch_size, -1).sum(dim=2).permute(1,0)
339
+ cc_indx_batch_err = error_sum_b_e.topk(num_decorrelate if num_to_keep else num_error_terms - num_decorrelate)[1]
340
+ return self.decorrelate(cc_indx_batch_err)
341
+
342
+ def correlate(self, cc_indx_batch_beta): # given in terms of the flattened matrix.
343
+ num_correlate = h.product(cc_indx_batch_beta.shape[1:])
344
+
345
+ beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
346
+ errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors
347
+
348
+ batch_size = beta.shape[0]
349
+ new_errors = h.zeros([num_correlate] + list(self.head.shape)).to_dtype()
350
+
351
+ inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()
352
+
353
+ nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()
354
+
355
+ new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(batch_size, num_correlate, -1)
356
+ new_errors[inds_i, nc.unsqueeze(0).expand([batch_size]+list(nc.shape)).squeeze(2), cc_indx_batch_beta] = beta.view(batch_size,-1)[inds_i, cc_indx_batch_beta]
357
+
358
+ new_errors = new_errors.permute(1,0, *list(range(len(new_errors.shape)))[2:]).contiguous().view(num_correlate, batch_size, *beta.shape[1:])
359
+ errors = torch.cat((errors, new_errors), dim=0)
360
+
361
+ beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0
362
+
363
+ return self.new(self.head, beta, errors)
364
+
365
+ def stochasticCorrelate(self, num_correlate, choices = None):
366
+ if num_correlate == 0:
367
+ return self
368
+
369
+ domshape = self.head.shape
370
+ batch_size = domshape[0]
371
+ num_pixs = h.product(domshape[1:])
372
+ num_correlate = min(num_correlate, num_pixs)
373
+ ucc_mask = h.ones([batch_size, num_pixs ]).long()
374
+
375
+ cc_indx_batch_beta = h.cudify(torch.multinomial(ucc_mask.to_dtype(), num_correlate, replacement=False)) if choices is None else choices
376
+ return self.correlate(cc_indx_batch_beta)
377
+
378
+
379
+ def correlateMaxK(self, num_correlate):
380
+ if num_correlate == 0:
381
+ return self
382
+
383
+ domshape = self.head.shape
384
+ batch_size = domshape[0]
385
+ num_pixs = h.product(domshape[1:])
386
+ num_correlate = min(num_correlate, num_pixs)
387
+
388
+ concrete_max_image = self.ub().view(batch_size, -1)
389
+
390
+ cc_indx_batch_beta = concrete_max_image.topk(num_correlate)[1]
391
+ return self.correlate(cc_indx_batch_beta)
392
+
393
+ def correlateMaxPool(self, *args, max_type = MaxTypes.ub , max_pool = F.max_pool2d, **kargs):
394
+ domshape = self.head.shape
395
+ batch_size = domshape[0]
396
+ num_pixs = h.product(domshape[1:])
397
+
398
+ concrete_max_image = max_type(self)
399
+
400
+ cc_indx_batch_beta = max_pool(concrete_max_image, *args, return_indices=True, **kargs)[1].view(batch_size, -1)
401
+
402
+ return self.correlate(cc_indx_batch_beta)
403
+
404
+ def checkSizes(self):
405
+ if not self.errors is None:
406
+ if not self.errors.size()[1:] == self.head.size():
407
+ raise Exception("Such bad sizes on error:", self.errors.shape, " head:", self.head.shape)
408
+ if torch.isnan(self.errors).any():
409
+ raise Exception("Such nan in errors")
410
+ if not self.beta is None:
411
+ if not self.beta.size() == self.head.size():
412
+ raise Exception("Such bad sizes on beta")
413
+
414
+ if torch.isnan(self.beta).any():
415
+ raise Exception("Such nan in errors")
416
+ if self.beta.lt(0.0).any():
417
+ self.beta = self.beta.abs()
418
+
419
+ return self
420
+
421
+ def __mul__(self, flt):
422
+ return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt)
423
+
424
+ def __truediv__(self, flt):
425
+ flt = 1. / flt
426
+ return self.new(self.head * flt, None if self.beta is None else self.beta * abs(flt), None if self.errors is None else self.errors * flt)
427
+
428
+ def __add__(self, other):
429
+ if isinstance(other, HybridZonotope):
430
+ return self.new(self.head + other.head, h.msum(self.beta, other.beta, lambda a,b: a + b), h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a + b)))
431
+ else:
432
+ # other has to be a standard variable or tensor
433
+ return self.new(self.head + other, self.beta, self.errors)
434
+
435
+ def addPar(self, a, b):
436
+ return self.new(a.head + b.head, h.msum(a.beta, b.beta, lambda a,b: a + b), h.msum(a.errors, b.errors, catNonNullErrors(lambda a,b: a + b, self.errors)))
437
+
438
+ def __sub__(self, other):
439
+ if isinstance(other, HybridZonotope):
440
+ return self.new(self.head - other.head
441
+ , h.msum(self.beta, other.beta, lambda a,b: a + b)
442
+ , h.msum(self.errors, None if other.errors is None else -other.errors, catNonNullErrors(lambda a,b: a + b)))
443
+ else:
444
+ # other has to be a standard variable or tensor
445
+ return self.new(self.head - other, self.beta, self.errors)
446
+
447
+ def bmm(self, other):
448
+ hd = self.head.bmm(other)
449
+ bet = None if self.beta is None else self.beta.bmm(other.abs())
450
+
451
+ if self.errors is None:
452
+ er = None
453
+ else:
454
+ er = self.errors.matmul(other)
455
+ return self.new(hd, bet, er)
456
+
457
+
458
+ def getBeta(self):
459
+ return self.head * 0 if self.beta is None else self.beta
460
+
461
+ def getErrors(self):
462
+ return (self.head * 0).unsqueeze(0) if self.beta is None else self.errors
463
+
464
def merge(self, other, ref = None): # the vast majority of the time ref should be none here. Not for parallel computation with powerset
    # Join (over-approximating union) of self and `other`.  Per dimension:
    # keep self where other is contained in self's box part, keep other
    # where self is contained in other's box, otherwise fall back to a
    # bounding interval around both.
    s_beta = self.getBeta() # so that beta is never none

    sbox_u = self.head + s_beta
    sbox_l = self.head - s_beta
    o_u = other.ub()
    o_l = other.lb()
    # per-dimension containment: other's concrete interval inside self's box part
    o_in_s = (o_u <= sbox_u) & (o_l >= sbox_l)

    # total magnitude of self's symbolic error terms, per dimension
    s_err_mx = self.errors.abs().sum(dim=0)

    if not isinstance(other, HybridZonotope):
        new_head = (self.head + other.center()) / 2
        new_beta = torch.max(sbox_u + s_err_mx,o_u) - new_head
        # NOTE(review): this branch uses self.beta directly inside torch.where;
        # self.beta may be None even though s_beta was computed above — confirm
        # callers never reach here with beta unset.
        return self.new(torch.where(o_in_s, self.head, new_head), torch.where(o_in_s, self.beta,new_beta), o_in_s.float() * self.errors)

    # TODO: could be more efficient if one of these doesn't have beta or errors but thats okay for now.
    s_u = sbox_u + s_err_mx
    s_l = sbox_l - s_err_mx

    # NOTE(review): obox_u/obox_l subtract/add other.head from absolute bounds,
    # producing head-relative quantities, yet s_u/s_l below are absolute.
    # Verify the intended test isn't simply (s_u <= o_u) & (s_l >= o_l).
    obox_u = o_u - other.head
    obox_l = o_l + other.head

    s_in_o = (s_u <= obox_u) & (s_l >= obox_l)

    # TODO: could theoretically still do something better when one is contained partially in the other
    new_head = (self.head + other.center()) / 2
    new_beta = torch.max(sbox_u + self.getErrors().abs().sum(dim=0),o_u) - new_head

    return self.new(torch.where(o_in_s, self.head, torch.where(s_in_o, other.head, new_head))
                    , torch.where(o_in_s, s_beta,torch.where(s_in_o, other.getBeta(), new_beta))
                    , h.msum(o_in_s.float() * self.errors, s_in_o.float() * other.errors, catNonNullErrors(lambda a,b: a + b, ref_errs = ref.errors if ref is not None else ref))) # these are both zero otherwise
496
+
497
+
498
+ def conv(self, conv, weight, bias = None, **kargs):
499
+ h = self.errors
500
+ inter = h if h is None else h.view(-1, *h.size()[2:])
501
+ hd = conv(self.head, weight, bias=bias, **kargs)
502
+ res = h if h is None else conv(inter, weight, bias=None, **kargs)
503
+
504
+ return self.new( hd
505
+ , None if self.beta is None else conv(self.beta, weight.abs(), bias = None, **kargs)
506
+ , h if h is None else res.view(h.size()[0], h.size()[1], *res.size()[1:]))
507
+
508
+
509
+ def conv1d(self, *args, **kargs):
510
+ return self.conv(lambda x, *args, **kargs: x.conv1d(*args,**kargs), *args, **kargs)
511
+
512
+ def conv2d(self, *args, **kargs):
513
+ return self.conv(lambda x, *args, **kargs: x.conv2d(*args,**kargs), *args, **kargs)
514
+
515
+ def conv3d(self, *args, **kargs):
516
+ return self.conv(lambda x, *args, **kargs: x.conv3d(*args,**kargs), *args, **kargs)
517
+
518
+ def conv_transpose1d(self, *args, **kargs):
519
+ return self.conv(lambda x, *args, **kargs: x.conv_transpose1d(*args,**kargs), *args, **kargs)
520
+
521
+ def conv_transpose2d(self, *args, **kargs):
522
+ return self.conv(lambda x, *args, **kargs: x.conv_transpose2d(*args,**kargs), *args, **kargs)
523
+
524
+ def conv_transpose3d(self, *args, **kargs):
525
+ return self.conv(lambda x, *args, **kargs: x.conv_transpose3d(*args,**kargs), *args, **kargs)
526
+
527
def matmul(self, other):
    """Right-multiply by a concrete matrix.  Beta is multiplied by |other|
    (radii are magnitudes), while head and error terms use `other` as-is."""
    new_head = self.head.matmul(other)
    new_beta = None if self.beta is None else self.beta.matmul(other.abs())
    new_errors = None if self.errors is None else self.errors.matmul(other)
    return self.new(new_head, new_beta, new_errors)
529
+
530
+ def unsqueeze(self, i):
531
+ return self.new(self.head.unsqueeze(i), None if self.beta is None else self.beta.unsqueeze(i), None if self.errors is None else self.errors.unsqueeze(i + 1))
532
+
533
+ def squeeze(self, dim):
534
+ return self.new(self.head.squeeze(dim),
535
+ None if self.beta is None else self.beta.squeeze(dim),
536
+ None if self.errors is None else self.errors.squeeze(dim + 1 if dim >= 0 else dim))
537
+
538
+ def double(self):
539
+ return self.new(self.head.double(), self.beta.double() if self.beta is not None else None, self.errors.double() if self.errors is not None else None)
540
+
541
+ def float(self):
542
+ return self.new(self.head.float(), self.beta.float() if self.beta is not None else None, self.errors.float() if self.errors is not None else None)
543
+
544
+ def to_dtype(self):
545
+ return self.new(self.head.to_dtype(), self.beta.to_dtype() if self.beta is not None else None, self.errors.to_dtype() if self.errors is not None else None)
546
+
547
+ def sum(self, dim=1):
548
+ return self.new(torch.sum(self.head,dim=dim), None if self.beta is None else torch.sum(self.beta,dim=dim), None if self.errors is None else torch.sum(self.errors, dim= dim + 1 if dim >= 0 else dim))
549
+
550
+ def view(self,*newshape):
551
+ return self.new(self.head.view(*newshape),
552
+ None if self.beta is None else self.beta.view(*newshape),
553
+ None if self.errors is None else self.errors.view(self.errors.size()[0], *newshape))
554
+
555
+ def gather(self,dim, index):
556
+ return self.new(self.head.gather(dim, index),
557
+ None if self.beta is None else self.beta.gather(dim, index),
558
+ None if self.errors is None else self.errors.gather(dim + 1, index.expand([self.errors.size()[0]] + list(index.size()))))
559
+
560
def concretize(self):
    # Collapse all symbolic error terms (plus beta) into a single interval
    # radius, dropping cross-dimension correlations.  No-op when there are
    # no symbolic error terms.
    if self.errors is None:
        return self

    return self.new(self.head, torch.sum(self.concreteErrors().abs(),0), None) # maybe make a box?
565
+
566
def cat(self, other, dim=0):
    """Concatenate two abstract elements along `dim` (dim+1 for the error
    tensors, which carry an extra leading error-term dimension).

    Bug fix: the original concatenated betas as ``other ++ self`` while the
    heads and error terms were concatenated ``self ++ other``, misaligning
    beta with the element it belongs to after the cat.  Betas now use the
    same self-first order.
    """
    return self.new(self.head.cat(other.head, dim = dim),
                    h.msum(self.beta, other.beta, lambda a,b: a.cat(b, dim = dim)),
                    h.msum(self.errors, other.errors, catNonNullErrors(lambda a,b: a.cat(b, dim+1))))
570
+
571
+
572
def split(self, split_size, dim = 0):
    """Split along `dim` into chunks of `split_size`, returning a tuple of
    new abstract elements.  Error tensors are split along dim+1 because of
    their extra leading error-term dimension."""
    head_parts = self.head.split(split_size, dim)
    beta_parts = None if self.beta is None else self.beta.split(split_size, dim)
    error_parts = None if self.errors is None else self.errors.split(split_size, dim + 1)

    pieces = []
    for i in range(len(head_parts)):
        pieces.append(self.new(head_parts[i],
                               None if beta_parts is None else beta_parts[i],
                               None if error_parts is None else error_parts[i]))
    return tuple(pieces)
582
+
583
+
584
+
585
def concreteErrors(self):
    """Return beta and the symbolic error terms stacked into one tensor:
    beta becomes one extra leading error row when both are present."""
    if self.beta is None and self.errors is None:
        raise Exception("shouldn't have both beta and errors be none")
    if self.beta is None:
        return self.errors
    if self.errors is None:
        return self.beta.unsqueeze(0)
    return torch.cat([self.beta.unsqueeze(0), self.errors], dim=0)
593
+
594
+
595
+ def applyMonotone(self, foo, *args, **kargs):
596
+ if self.beta is None and self.errors is None:
597
+ return self.new(foo(self.head), None , None)
598
+
599
+ beta = self.concreteErrors().abs().sum(dim=0)
600
+
601
+ tp = foo(self.head + beta, *args, **kargs)
602
+ bt = foo(self.head - beta, *args, **kargs)
603
+
604
+ new_hybrid = self.new((tp + bt) / 2, (tp - bt) / 2 , None)
605
+
606
+
607
+ if self.errors is not None:
608
+ return new_hybrid.correlateMaxK(self.errors.shape[0])
609
+ return new_hybrid
610
+
611
+ def avg_pool2d(self, *args, **kargs):
612
+ nhead = F.avg_pool2d(self.head, *args, **kargs)
613
+ return self.new(nhead,
614
+ None if self.beta is None else F.avg_pool2d(self.beta, *args, **kargs),
615
+ None if self.errors is None else F.avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape))
616
+
617
+ def adaptive_avg_pool2d(self, *args, **kargs):
618
+ nhead = F.adaptive_avg_pool2d(self.head, *args, **kargs)
619
+ return self.new(nhead,
620
+ None if self.beta is None else F.adaptive_avg_pool2d(self.beta, *args, **kargs),
621
+ None if self.errors is None else F.adaptive_avg_pool2d(self.errors.view(-1, *self.head.shape[1:]), *args, **kargs).view(-1,*nhead.shape))
622
+
623
+ def elu(self):
624
+ return self.applyMonotone(F.elu)
625
+
626
+ def selu(self):
627
+ return self.applyMonotone(F.selu)
628
+
629
+ def sigm(self):
630
+ return self.applyMonotone(F.sigmoid)
631
+
632
+ def softplus(self):
633
+ if self.errors is None:
634
+ if self.beta is None:
635
+ return self.new(F.softplus(self.head), None , None)
636
+ tp = F.softplus(self.head + self.beta)
637
+ bt = F.softplus(self.head - self.beta)
638
+ return self.new((tp + bt) / 2, (tp - bt) / 2 , None)
639
+
640
+ errors = self.concreteErrors()
641
+ o = h.ones(self.head.size())
642
+
643
+ def sp(hd):
644
+ return F.softplus(hd) # torch.log(o + torch.exp(hd)) # not very stable
645
+ def spp(hd):
646
+ ehd = torch.exp(hd)
647
+ return ehd.div(ehd + o)
648
+ def sppp(hd):
649
+ ehd = torch.exp(hd)
650
+ md = ehd + o
651
+ return ehd.div(md.mul(md))
652
+
653
+ fa = sp(self.head)
654
+ fpa = spp(self.head)
655
+
656
+ a = self.head
657
+
658
+ k = torch.sum(errors.abs(), 0)
659
+
660
+ def evalG(r):
661
+ return r.mul(r).mul(sppp(a + r))
662
+
663
+ m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
664
+ m = h.ifThenElse( a.abs().lt(k), torch.max(m, torch.max(evalG(a), evalG(-a))), m)
665
+ m /= 2
666
+
667
+ return self.new(fa, m if self.beta is None else m + self.beta.mul(fpa), None if self.errors is None else self.errors.mul(fpa))
668
+
669
+ def center(self):
670
+ return self.head
671
+
672
+ def vanillaTensorPart(self):
673
+ return self.head
674
+
675
+ def lb(self):
676
+ return self.head - self.concreteErrors().abs().sum(dim=0)
677
+
678
+ def ub(self):
679
+ return self.head + self.concreteErrors().abs().sum(dim=0)
680
+
681
+ def size(self):
682
+ return self.head.size()
683
+
684
+ def diameter(self):
685
+ abal = torch.abs(self.concreteErrors()).transpose(0,1)
686
+ return abal.sum(1).sum(1) # perimeter
687
+
688
+ def loss(self, target, **args):
689
+ r = -h.preDomRes(self, target).lb()
690
+ return F.softplus(r.max(1)[0])
691
+
692
+ def deep_loss(self, act = F.relu, *args, **kargs):
693
+ batch_size = self.head.shape[0]
694
+ inds = torch.arange(batch_size, device=h.device).unsqueeze(1).long()
695
+
696
+ def dl(l,u):
697
+ ls, lsi = torch.sort(l, dim=1)
698
+ ls_u = u[inds, lsi]
699
+
700
+ def slidingMax(a): # using maxpool
701
+ k = a.shape[1]
702
+ ml = a.min(dim=1)[0].unsqueeze(1)
703
+
704
+ inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)
705
+ mpl = F.max_pool1d(inp.unsqueeze(1) , kernel_size = k, stride=1, padding = 0, return_indices=False).squeeze(1)
706
+ return mpl[:,:-1] + ml
707
+
708
+ return act(slidingMax(ls_u) - ls).sum(dim=1)
709
+
710
+ l = self.lb().view(batch_size, -1)
711
+ u = self.ub().view(batch_size, -1)
712
+ return ( dl(l,u) + dl(-u,-l) ) / (2 * l.shape[1]) # make it easier to regularize against
713
+
714
+
715
+
716
class Zonotope(HybridZonotope):
    """Pure zonotope domain: after each transformer that produces a beta
    (box) component, applySuper immediately folds beta into fresh symbolic
    error terms so the element never carries a box part."""

    def applySuper(self, ret):
        # Convert ret's beta into explicit error generators: one scaled
        # identity direction per concrete element, appended to the existing
        # error terms; then clear beta.  Mutates and returns `ret`.
        batches = ret.head.size()[0]
        num_elem = h.product(ret.head.size()[1:])
        ei = h.getEi(batches, num_elem)

        if len(ret.head.size()) > 2:
            ei = ei.contiguous().view(num_elem, *ret.head.size())

        ret.errors = torch.cat( (ret.errors, ei * ret.beta) ) if not ret.beta is None else ret.errors
        ret.beta = None
        return ret.checkSizes()

    def zono_to_hybrid(self, *args, customRelu = None, **kargs): # we are already a hybrid zono.
        # Reinterpret this element as a HybridZonotope (same tensors).
        return HybridZonotope(self.head, self.beta, self.errors, customRelu = self.customRelu if customRelu is None else customRelu)

    def hybrid_to_zono(self, *args, **kargs):
        # Already a Zonotope; rebuild through new() to apply any kargs.
        return self.new(self.head, self.beta, self.errors, **kargs)

    def applyMonotone(self, *args, **kargs):
        return self.applySuper(super(Zonotope,self).applyMonotone(*args, **kargs))

    def softplus(self):
        return self.applySuper(super(Zonotope,self).softplus())

    def relu(self):
        return self.applySuper(super(Zonotope,self).relu())

    def splitRelu(self, *args, **kargs):
        return [self.applySuper(a) for a in super(Zonotope, self).splitRelu(*args, **kargs)]
746
+
747
+
748
def mysign(x):
    """Like sign(), except zeros map to +1 instead of 0 (sign(0) == 0)."""
    zero_mask = x.eq(0).to_dtype()
    signs = x.sign().to_dtype()
    return signs + zero_mask
752
+
753
def mulIfEq(grad,out,target):
    # Build a 0/1 mask shaped like `grad`: 1 where the prediction (argmax of
    # `out` over dim 1) matches `target`, 0 elsewhere.  `grad` contributes
    # only its shape.  NOTE(review): despite the name no multiplication
    # happens here — the caller is expected to multiply by the mask.
    pred = out.max(1, keepdim=True)[1]
    is_eq = pred.eq(target.view_as(pred)).to_dtype()
    is_eq = is_eq.view([-1] + [1 for _ in grad.size()[1:]]).expand_as(grad)
    return is_eq
758
+
759
+
760
def stdLoss(out, target):
    """Per-example (unreduced) cross-entropy, compatible with both the old
    0.x torch API (`reduce=False`) and the modern one (`reduction='none')."""
    on_legacy_torch = torch.__version__[0] == "0"
    if on_legacy_torch:
        return F.cross_entropy(out, target, reduce = False)
    return F.cross_entropy(out, target, reduction='none')
765
+
766
+
767
+
768
class ListDomain(object):
    """Cartesian product of abstract domains: holds a list of sub-domain
    elements in `self.al` and broadcasts every abstract operation to each.

    Bug fixes relative to the original:
    * ``raise "string"`` is a TypeError in Python 3; isSafe/labels now raise
      NotImplementedError with the same message.
    * ``split`` zipped a single generator (yielding 1-tuples of each
      sub-domain's whole split result) instead of transposing; it now uses
      ``zip(*...)`` so the i-th result combines the i-th piece of every
      sub-domain.
    """

    def __init__(self, al, *args, **kargs):
        self.al = list(al)

    def new(self, *args, **kargs):
        return self.__class__(*args, **kargs)

    def isSafe(self,*args,**kargs):
        # fixed: raising a str is a TypeError in Python 3
        raise NotImplementedError("Domain Not Suitable For Testing")

    def labels(self):
        # fixed: raising a str is a TypeError in Python 3
        raise NotImplementedError("Domain Not Suitable For Testing")

    def isPoint(self):
        return all(a.isPoint() for a in self.al)

    def __mul__(self, flt):
        return self.new(a.__mul__(flt) for a in self.al)

    def __truediv__(self, flt):
        return self.new(a.__truediv__(flt) for a in self.al)

    def __add__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__add__(o) for a,o in zip(self.al, other.al))
        else:
            return self.new(a.__add__(other) for a in self.al)

    def merge(self, other, ref = None):
        if ref is None:
            return self.new(a.merge(o) for a,o in zip(self.al,other.al) )
        return self.new(a.merge(o, ref = r) for a,o,r in zip(self.al,other.al, ref.al))

    def addPar(self, a, b):
        return self.new(s.addPar(av,bv) for s,av,bv in zip(self.al, a.al, b.al))

    def __sub__(self, other):
        if isinstance(other, ListDomain):
            return self.new(a.__sub__(o) for a,o in zip(self.al, other.al))
        else:
            return self.new(a.__sub__(other) for a in self.al)

    def abstractApplyLeaf(self, *args, **kargs):
        return self.new(a.abstractApplyLeaf(*args, **kargs) for a in self.al)

    def bmm(self, other):
        return self.new(a.bmm(other) for a in self.al)

    def matmul(self, other):
        return self.new(a.matmul(other) for a in self.al)

    def conv(self, *args, **kargs):
        return self.new(a.conv(*args, **kargs) for a in self.al)

    def conv1d(self, *args, **kargs):
        return self.new(a.conv1d(*args, **kargs) for a in self.al)

    def conv2d(self, *args, **kargs):
        return self.new(a.conv2d(*args, **kargs) for a in self.al)

    def conv3d(self, *args, **kargs):
        return self.new(a.conv3d(*args, **kargs) for a in self.al)

    def max_pool2d(self, *args, **kargs):
        return self.new(a.max_pool2d(*args, **kargs) for a in self.al)

    def avg_pool2d(self, *args, **kargs):
        return self.new(a.avg_pool2d(*args, **kargs) for a in self.al)

    def adaptive_avg_pool2d(self, *args, **kargs):
        return self.new(a.adaptive_avg_pool2d(*args, **kargs) for a in self.al)

    def unsqueeze(self, *args, **kargs):
        return self.new(a.unsqueeze(*args, **kargs) for a in self.al)

    def squeeze(self, *args, **kargs):
        return self.new(a.squeeze(*args, **kargs) for a in self.al)

    def view(self, *args, **kargs):
        return self.new(a.view(*args, **kargs) for a in self.al)

    def gather(self, *args, **kargs):
        return self.new(a.gather(*args, **kargs) for a in self.al)

    def sum(self, *args, **kargs):
        return self.new(a.sum(*args,**kargs) for a in self.al)

    def double(self):
        return self.new(a.double() for a in self.al)

    def float(self):
        return self.new(a.float() for a in self.al)

    def to_dtype(self):
        return self.new(a.to_dtype() for a in self.al)

    def vanillaTensorPart(self):
        # all sub-domains share the same underlying concrete part
        return self.al[0].vanillaTensorPart()

    def center(self):
        return self.new(a.center() for a in self.al)

    def ub(self):
        return self.new(a.ub() for a in self.al)

    def lb(self):
        return self.new(a.lb() for a in self.al)

    def relu(self):
        return self.new(a.relu() for a in self.al)

    def splitRelu(self, *args, **kargs):
        return self.new(a.splitRelu(*args, **kargs) for a in self.al)

    def softplus(self):
        return self.new(a.softplus() for a in self.al)

    def elu(self):
        return self.new(a.elu() for a in self.al)

    def selu(self):
        return self.new(a.selu() for a in self.al)

    def sigm(self):
        return self.new(a.sigm() for a in self.al)

    def cat(self, other, *args, **kargs):
        return self.new(a.cat(o, *args, **kargs) for a,o in zip(self.al, other.al))

    def split(self, *args, **kargs):
        # fixed: transpose the per-subdomain split results so each output
        # ListDomain holds one piece from every sub-domain
        return [self.new(z) for z in zip(*(a.split(*args, **kargs) for a in self.al))]

    def size(self):
        return self.al[0].size()

    def loss(self, *args, **kargs):
        return sum(a.loss(*args, **kargs) for a in self.al)

    def deep_loss(self, *args, **kargs):
        return sum(a.deep_loss(*args, **kargs) for a in self.al)

    def checkSizes(self):
        for a in self.al:
            a.checkSizes()
        return self
915
+
916
+
917
+ class TaggedDomain(object):
918
+
919
+
920
+ def __init__(self, a, tag = None):
921
+ self.tag = tag
922
+ self.a = a
923
+
924
+ def isSafe(self,*args,**kargs):
925
+ return self.a.isSafe(*args, **kargs)
926
+
927
+ def isPoint(self):
928
+ return self.a.isPoint()
929
+
930
+ def labels(self):
931
+ raise "Domain Not Suitable For Testing"
932
+
933
+ def __mul__(self, flt):
934
+ return TaggedDomain(self.a.__mul__(flt), self.tag)
935
+
936
+ def __truediv__(self, flt):
937
+ return TaggedDomain(self.a.__truediv__(flt), self.tag)
938
+
939
+ def __add__(self, other):
940
+ if isinstance(other, TaggedDomain):
941
+ return TaggedDomain(self.a.__add__(other.a), self.tag)
942
+ else:
943
+ return TaggedDomain(self.a.__add__(other), self.tag)
944
+
945
+ def addPar(self, a,b):
946
+ return TaggedDomain(self.a.addPar(a.a, b.a), self.tag)
947
+
948
+ def __sub__(self, other):
949
+ if isinstance(other, TaggedDomain):
950
+ return TaggedDomain(self.a.__sub__(other.a), self.tag)
951
+ else:
952
+ return TaggedDomain(self.a.__sub__(other), self.tag)
953
+
954
+ def bmm(self, other):
955
+ return TaggedDomain(self.a.bmm(other), self.tag)
956
+
957
+ def matmul(self, other):
958
+ return TaggedDomain(self.a.matmul(other), self.tag)
959
+
960
+ def conv(self, *args, **kargs):
961
+ return TaggedDomain(self.a.conv(*args, **kargs) , self.tag)
962
+
963
+ def conv1d(self, *args, **kargs):
964
+ return TaggedDomain(self.a.conv1d(*args, **kargs), self.tag)
965
+
966
+ def conv2d(self, *args, **kargs):
967
+ return TaggedDomain(self.a.conv2d(*args, **kargs), self.tag)
968
+
969
+ def conv3d(self, *args, **kargs):
970
+ return TaggedDomain(self.a.conv3d(*args, **kargs), self.tag)
971
+
972
+ def max_pool2d(self, *args, **kargs):
973
+ return TaggedDomain(self.a.max_pool2d(*args, **kargs), self.tag)
974
+
975
+ def avg_pool2d(self, *args, **kargs):
976
+ return TaggedDomain(self.a.avg_pool2d(*args, **kargs), self.tag)
977
+
978
+ def adaptive_avg_pool2d(self, *args, **kargs):
979
+ return TaggedDomain(self.a.adaptive_avg_pool2d(*args, **kargs), self.tag)
980
+
981
+
982
+ def unsqueeze(self, *args, **kargs):
983
+ return TaggedDomain(self.a.unsqueeze(*args, **kargs), self.tag)
984
+
985
+ def squeeze(self, *args, **kargs):
986
+ return TaggedDomain(self.a.squeeze(*args, **kargs), self.tag)
987
+
988
+ def abstractApplyLeaf(self, *args, **kargs):
989
+ return TaggedDomain(self.a.abstractApplyLeaf(*args, **kargs), self.tag)
990
+
991
+ def view(self, *args, **kargs):
992
+ return TaggedDomain(self.a.view(*args, **kargs), self.tag)
993
+
994
+ def gather(self, *args, **kargs):
995
+ return TaggedDomain(self.a.gather(*args, **kargs), self.tag)
996
+
997
+ def sum(self, *args, **kargs):
998
+ return TaggedDomain(self.a.sum(*args,**kargs), self.tag)
999
+
1000
+ def double(self):
1001
+ return TaggedDomain(self.a.double(), self.tag)
1002
+
1003
+ def float(self):
1004
+ return TaggedDomain(self.a.float(), self.tag)
1005
+
1006
+ def to_dtype(self):
1007
+ return TaggedDomain(self.a.to_dtype(), self.tag)
1008
+
1009
+ def vanillaTensorPart(self):
1010
+ return self.a.vanillaTensorPart()
1011
+
1012
+ def center(self):
1013
+ return TaggedDomain(self.a.center(), self.tag)
1014
+
1015
+ def ub(self):
1016
+ return TaggedDomain(self.a.ub(), self.tag)
1017
+
1018
+ def lb(self):
1019
+ return TaggedDomain(self.a.lb(), self.tag)
1020
+
1021
+ def relu(self):
1022
+ return TaggedDomain(self.a.relu(), self.tag)
1023
+
1024
+ def splitRelu(self, *args, **kargs):
1025
+ return TaggedDomain(self.a.splitRelu(*args, **kargs), self.tag)
1026
+
1027
+ def diameter(self):
1028
+ return self.a.diameter()
1029
+
1030
+ def softplus(self):
1031
+ return TaggedDomain(self.a.softplus(), self.tag)
1032
+
1033
+ def elu(self):
1034
+ return TaggedDomain(self.a.elu(), self.tag)
1035
+
1036
+ def selu(self):
1037
+ return TaggedDomain(self.a.selu(), self.tag)
1038
+
1039
+ def sigm(self):
1040
+ return TaggedDomain(self.a.sigm(), self.tag)
1041
+
1042
+
1043
+ def cat(self, other, *args, **kargs):
1044
+ return TaggedDomain(self.a.cat(other.a, *args, **kargs), self.tag)
1045
+
1046
+ def split(self, *args, **kargs):
1047
+ return [TaggedDomain(z, self.tag) for z in self.a.split(*args, **kargs)]
1048
+
1049
+ def size(self):
1050
+
1051
+ return self.a.size()
1052
+
1053
+ def loss(self, *args, **kargs):
1054
+ return self.tag.loss(self.a, *args, **kargs)
1055
+
1056
+ def deep_loss(self, *args, **kargs):
1057
+ return self.a.deep_loss(*args, **kargs)
1058
+
1059
+ def checkSizes(self):
1060
+ self.a.checkSizes()
1061
+ return self
1062
+
1063
+ def merge(self, other, ref = None):
1064
+ return TaggedDomain(self.a.merge(other.a, ref = None if ref is None else ref.a), self.tag)
components.py ADDED
@@ -0,0 +1,951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.distributions import multinomial, categorical
5
+ import torch.optim as optim
6
+
7
+ import math
8
+
9
+ try:
10
+ from . import helpers as h
11
+ from . import ai
12
+ from . import scheduling as S
13
+ except:
14
+ import helpers as h
15
+ import ai
16
+ import scheduling as S
17
+
18
+ import math
19
+ import abc
20
+
21
+ from torch.nn.modules.conv import _ConvNd
22
+ from enum import Enum
23
+
24
+
25
+ class InferModule(nn.Module):
26
+ def __init__(self, *args, normal = False, ibp_init = False, **kwargs):
27
+ self.args = args
28
+ self.kwargs = kwargs
29
+ self.infered = False
30
+ self.normal = normal
31
+ self.ibp_init = ibp_init
32
+
33
+ def infer(self, in_shape, global_args = None):
34
+ """ this is really actually stateful. """
35
+
36
+ if self.infered:
37
+ return self
38
+ self.infered = True
39
+
40
+ super(InferModule, self).__init__()
41
+ self.inShape = list(in_shape)
42
+ self.outShape = list(self.init(list(in_shape), *self.args, global_args = global_args, **self.kwargs))
43
+ if self.outShape is None:
44
+ raise "init should set the out_shape"
45
+
46
+ self.reset_parameters()
47
+ return self
48
+
49
+ def reset_parameters(self):
50
+ if not hasattr(self,'weight') or self.weight is None:
51
+ return
52
+ n = h.product(self.weight.size()) / self.outShape[0]
53
+ stdv = 1 / math.sqrt(n)
54
+
55
+ if self.ibp_init:
56
+ torch.nn.init.orthogonal_(self.weight.data)
57
+ elif self.normal:
58
+ self.weight.data.normal_(0, stdv)
59
+ self.weight.data.clamp_(-1, 1)
60
+ else:
61
+ self.weight.data.uniform_(-stdv, stdv)
62
+
63
+ if self.bias is not None:
64
+ if self.ibp_init:
65
+ self.bias.data.zero_()
66
+ elif self.normal:
67
+ self.bias.data.normal_(0, stdv)
68
+ self.bias.data.clamp_(-1, 1)
69
+ else:
70
+ self.bias.data.uniform_(-stdv, stdv)
71
+
72
+ def clip_norm(self):
73
+ if not hasattr(self, "weight"):
74
+ return
75
+ if not hasattr(self,"weight_g"):
76
+ if torch.__version__[0] == "0":
77
+ nn.utils.weight_norm(self, dim=None)
78
+ else:
79
+ nn.utils.weight_norm(self)
80
+
81
+ self.weight_g.data.clamp_(-h.max_c_for_norm, h.max_c_for_norm)
82
+
83
+ if torch.__version__[0] != "0":
84
+ self.weight_v.data.clamp_(-h.max_c_for_norm * 10000,h.max_c_for_norm * 10000)
85
+ if hasattr(self, "bias"):
86
+ self.bias.data.clamp_(-h.max_c_for_norm * 10000, h.max_c_for_norm * 10000)
87
+
88
+ def regularize(self, p):
89
+ reg = 0
90
+ if torch.__version__[0] == "0":
91
+ for param in self.parameters():
92
+ reg += param.norm(p)
93
+ else:
94
+ if hasattr(self, "weight_g"):
95
+ reg += self.weight_g.norm().sum()
96
+ reg += self.weight_v.norm().sum()
97
+ elif hasattr(self, "weight"):
98
+ reg += self.weight.norm().sum()
99
+
100
+ if hasattr(self, "bias"):
101
+ reg += self.bias.view(-1).norm(p=p).sum()
102
+
103
+ return reg
104
+
105
+ def remove_norm(self):
106
+ if hasattr(self,"weight_g"):
107
+ torch.nn.utils.remove_weight_norm(self)
108
+
109
+ def showNet(self, t = ""):
110
+ print(t + self.__class__.__name__)
111
+
112
+ def printNet(self, f):
113
+ print(self.__class__.__name__, file=f)
114
+
115
+ @abc.abstractmethod
116
+ def forward(self, *args, **kargs):
117
+ pass
118
+
119
+ def __call__(self, *args, onyx=False, **kargs):
120
+ if onyx:
121
+ return self.forward(*args, onyx=onyx, **kargs)
122
+ else:
123
+ return super(InferModule, self).__call__(*args, **kargs)
124
+
125
+ @abc.abstractmethod
126
+ def neuronCount(self):
127
+ pass
128
+
129
+ def depth(self):
130
+ return 0
131
+
132
def getShapeConv(in_shape, conv_shape, stride = 1, padding = 0):
    """Output shape (C, H, W) of a square-kernel 2D convolution applied to
    `in_shape` = (channels, height, width); `conv_shape` gives
    (out_channels, kernel_h, kernel_w)."""
    _, in_h, in_w = in_shape
    out_chan, k_h, k_w = conv_shape[:3]

    out_h = 1 + int((2 * padding + in_h - k_h) / stride)
    out_w = 1 + int((2 * padding + in_w - k_w) / stride)
    return (out_chan, out_h, out_w)
139
+
140
def getShapeConvTranspose(in_shape, conv_shape, stride = 1, padding = 0, out_padding=0):
    """Output shape (C, H, W) of a square-kernel 2D transposed convolution,
    the inverse relation of getShapeConv."""
    _, in_h, in_w = in_shape
    out_chan, k_h, k_w = conv_shape[:3]

    out_h = (in_h - 1) * stride - 2 * padding + k_h + out_padding
    out_w = (in_w - 1) * stride - 2 * padding + k_w + out_padding
    return (out_chan, out_h, out_w)
147
+
148
+
149
+
150
class Linear(InferModule):
    """Fully connected layer: flattens the input and applies x @ W + b,
    reshaping the result to the configured output shape.

    Bug fix: printNet's header line previously went to stdout instead of the
    target file `f`, unlike the weight/bias lines that follow it (and unlike
    the sibling layers' printNet implementations).
    """

    def init(self, in_shape, out_shape, **kargs):
        # Record flattened sizes and allocate (uninitialized) parameters;
        # InferModule.reset_parameters fills them in afterwards.
        self.in_neurons = h.product(in_shape)
        if isinstance(out_shape, int):
            out_shape = [out_shape]
        self.out_neurons = h.product(out_shape)

        self.weight = torch.nn.Parameter(torch.Tensor(self.in_neurons, self.out_neurons))
        self.bias = torch.nn.Parameter(torch.Tensor(self.out_neurons))

        return out_shape

    def forward(self, x, **kargs):
        s = x.size()
        x = x.view(s[0], h.product(s[1:]))
        return (x.matmul(self.weight) + self.bias).view(s[0], *self.outShape)

    def neuronCount(self):
        # linear layers contribute no (ReLU) neurons to the count
        return 0

    def showNet(self, t = ""):
        print(t + "Linear out=" + str(self.out_neurons))

    def printNet(self, f):
        # fixed: header now written to `f` like the weight/bias lines below
        print("Linear(" + str(self.out_neurons) + ")", file = f)

        print(h.printListsNumpy(list(self.weight.transpose(1,0).data)), file= f)
        print(h.printNumpy(self.bias), file= f)
178
+
179
+ class Activation(InferModule):
180
+ def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
181
+ self.activation = [ "ReLU","Sigmoid", "Tanh", "Softplus", "ELU", "SELU"].index(activation)
182
+ self.activation_name = activation
183
+ return in_shape
184
+
185
+ def regularize(self, p):
186
+ return 0
187
+
188
+ def forward(self, x, **kargs):
189
+ return [lambda x:x.relu(), lambda x:x.sigmoid(), lambda x:x.tanh(), lambda x:x.softplus(), lambda x:x.elu(), lambda x:x.selu()][self.activation](x)
190
+
191
+ def neuronCount(self):
192
+ return h.product(self.outShape)
193
+
194
+ def depth(self):
195
+ return 1
196
+
197
+ def showNet(self, t = ""):
198
+ print(t + self.activation_name)
199
+
200
+ def printNet(self, f):
201
+ pass
202
+
203
+ class ReLU(Activation):
204
+ pass
205
+
206
+ def activation(*args, batch_norm = False, **kargs):
207
+ a = Activation(*args, **kargs)
208
+ return Seq(BatchNorm(), a) if batch_norm else a
209
+
210
+ class Identity(InferModule): # for feigning model equivelence when removing an op
211
+ def init(self, in_shape, global_args = None, **kargs):
212
+ return in_shape
213
+
214
+ def forward(self, x, **kargs):
215
+ return x
216
+
217
+ def neuronCount(self):
218
+ return 0
219
+
220
+ def printNet(self, f):
221
+ pass
222
+
223
+ def regularize(self, p):
224
+ return 0
225
+
226
+ def showNet(self, *args, **kargs):
227
+ pass
228
+
229
class Dropout(InferModule):
    """Dropout layer with a (possibly scheduled) drop probability `p`,
    optional 2D (channel-wise) dropout, and optional alpha-dropout scaling
    for SELU networks.

    Bug fix: printNet previously wrote its line to stdout instead of the
    target file `f`.
    """

    def init(self, in_shape, p=0.5, use_2d = False, alpha_dropout = False, **kargs):
        self.p = S.Const.initConst(p)  # scheduled constant; read via getVal(time)
        self.use_2d = use_2d
        self.alpha_dropout = alpha_dropout
        return in_shape

    def forward(self, x, time = 0, **kargs):
        if self.training:
            # Sample the mask without tracking gradients; gradient flows
            # through x only.
            with torch.no_grad():
                p = self.p.getVal(time = time)
                mask = (F.dropout2d if self.use_2d else F.dropout)(h.ones(x.size()),p=p, training=True)
            if self.alpha_dropout:
                # SELU-style affine correction keeping mean/variance stable.
                # NOTE(review): reconstructed nesting from a diff-mangled
                # source — confirm the return sits outside the no_grad block.
                with torch.no_grad():
                    keep_prob = 1 - p
                    alpha = -1.7580993408473766
                    a = math.pow(keep_prob + alpha * alpha * keep_prob * (1 - keep_prob), -0.5)
                    b = -a * alpha * (1 - keep_prob)
                    mask = mask * a
                return x * mask + b
            else:
                return x * mask
        else:
            # evaluation mode: identity
            return x

    def neuronCount(self):
        return 0

    def showNet(self, t = ""):
        print(t + "Dropout p=" + str(self.p))

    def printNet(self, f):
        # fixed: write to `f` rather than stdout
        print("Dropout(" + str(self.p) + ")", file = f)
262
+
263
+ class PrintActivation(Identity):
264
+ def init(self, in_shape, global_args = None, activation = "ReLU", **kargs):
265
+ self.activation = activation
266
+ return in_shape
267
+
268
+ def printNet(self, f):
269
+ print(self.activation, file = f)
270
+
271
+ class PrintReLU(PrintActivation):
272
+ pass
273
+
274
class Conv2D(InferModule):
    """2-D convolution over (possibly abstract) inputs.

    Weight layout is (out_channels, in_channels, k, k). Weight/bias values
    are presumably initialized by the shared infer machinery — TODO confirm.
    """

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, activation = "ReLU", **kargs):
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.activation = activation  # label only; no activation applied here
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        weights_shape = (self.out_channels, self.in_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[0]))
        else:
            self.bias = None # h.zeros(weights_shape[0])

        outshape = getShapeConv(in_shape, (out_channels, kernel_size, kernel_size), stride, padding)
        return outshape

    def forward(self, input, **kargs):
        # Dispatch through the input so abstract domains supply their conv2d.
        return input.conv2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding )

    def printNet(self, f): # only complete if we've forwardt stride=1
        print("Conv2D", file = f)
        sz = list(self.prev)
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # Bias-free layers are serialized with an explicit zero vector.
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)

    def showNet(self, t = ""):
        sz = list(self.prev)
        print(t + "Conv2D, filters={}, kernel_size={}, input_shape={}, stride={}, padding={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(reversed(sz)), [self.stride, self.stride], self.padding ))

    def neuronCount(self):
        return 0
312
+
313
+
314
class ConvTranspose2D(InferModule):
    """Transposed 2-D convolution over (possibly abstract) inputs.

    Weight layout is (in_channels, out_channels, k, k), as conv_transpose2d
    expects; the bias therefore has out_channels entries.
    """

    def init(self, in_shape, out_channels, kernel_size, stride = 1, global_args = None, bias=True, padding = 0, out_padding=0, activation = "ReLU", **kargs):
        self.prev = in_shape
        self.in_channels = in_shape[0]
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_padding = out_padding
        self.activation = activation  # label only; no activation applied here
        self.use_softplus = h.default(global_args, 'use_softplus', False)

        weights_shape = (self.in_channels, self.out_channels, kernel_size, kernel_size)
        self.weight = torch.nn.Parameter(torch.Tensor(*weights_shape))
        if bias:
            # BUG FIX: a transposed convolution's bias has one entry per
            # *output* channel (weights_shape[1]); weights_shape[0] is the
            # input channel count and broke whenever the two differed.
            self.bias = torch.nn.Parameter(torch.Tensor(weights_shape[1]))
        else:
            self.bias = None # h.zeros(weights_shape[1])

        outshape = getShapeConvTranspose(in_shape, (out_channels, kernel_size, kernel_size), stride, padding, out_padding)
        return outshape

    def forward(self, input, **kargs):
        return input.conv_transpose2d(self.weight, bias=self.bias, stride=self.stride, padding = self.padding, output_padding=self.out_padding)

    def printNet(self, f): # only complete if we've forwardt stride=1
        print("ConvTranspose2D", file = f)
        # BUG FIX: kernel_size is an int (list() raised); print it as a pair,
        # matching Conv2D's serialization.
        print(self.activation + ", filters={}, kernel_size={}, input_shape={}".format(self.out_channels, [self.kernel_size, self.kernel_size], list(self.prev) ), file = f)
        print(h.printListsNumpy([[list(p) for p in l ] for l in self.weight.permute(2,3,1,0).data]) , file= f)
        # BUG FIX: guard against bias=None, as Conv2D.printNet does.
        print(h.printNumpy(self.bias if self.bias is not None else h.dten(self.out_channels)), file= f)

    def neuronCount(self):
        return 0
348
+
349
+
350
+
351
class MaxPool2D(InferModule):
    """2-D max pooling; stride defaults to the kernel size."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        # BUG FIX: shape inference must use the resolved stride
        # (self.stride); passing the raw argument broke the default case
        # where stride is None.
        return getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride)

    def forward(self, x, **kargs):
        return x.max_pool2d(self.kernel_size, self.stride)

    def printNet(self, f):
        # BUG FIX: stride/kernel_size are ints (list() raised) and self.shape
        # does not exist; print the stored configuration instead.
        print("MaxPool2D stride={}, kernel_size={}, input_shape={}".format([self.stride, self.stride], [self.kernel_size, self.kernel_size], list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        # Each output of a max is a nonlinear neuron for verification purposes.
        return h.product(self.outShape)
366
+
367
class AvgPool2D(InferModule):
    """2-D average pooling with a fixed padding of 1 (matching forward)."""

    def init(self, in_shape, kernel_size, stride = None, **kargs):
        self.prev = in_shape
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        out_size = getShapeConv(in_shape, (in_shape[0], kernel_size, kernel_size), self.stride, padding = 1)
        return out_size

    def forward(self, x, **kargs):
        # A 1x1 spatial input is already its own average.
        if h.product(x.size()[2:]) == 1:
            return x
        return x.avg_pool2d(kernel_size = self.kernel_size, stride = self.stride, padding = 1)

    def printNet(self, f):
        # BUG FIX: stride/kernel_size are ints (list() raised) and self.shape
        # does not exist; print the stored configuration instead.
        print("AvgPool2D stride={}, kernel_size={}, input_shape={}".format([self.stride, self.stride], [self.kernel_size, self.kernel_size], list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
385
+
386
class AdaptiveAvgPool2D(InferModule):
    """Average pooling to a caller-specified output spatial shape."""

    def init(self, in_shape, out_shape, **kargs):
        self.prev = in_shape
        self.out_shape = list(out_shape)
        # Channel count is preserved; only the spatial dims change.
        return [in_shape[0]] + self.out_shape

    def forward(self, x, **kargs):
        return x.adaptive_avg_pool2d(self.out_shape)

    def printNet(self, f):
        print("AdaptiveAvgPool2D out_Shape={} input_shape={}".format(list(self.out_shape), list(self.prev[1:]+self.prev[:1]) ), file = f)

    def neuronCount(self):
        return h.product(self.outShape)
400
+
401
class Normalize(InferModule):
    """Per-channel input normalization: (x - mean) / std."""

    def init(self, in_shape, mean, std, **kargs):
        # Keep the raw Python values around for printing.
        self.mean_v = mean
        self.std_v = std
        self.mean = h.dten(mean)
        # Store the reciprocal so forward only needs a multiply, which is
        # friendlier to abstract domains than division.
        self.std = 1 / h.dten(std)
        return in_shape

    def forward(self, x, **kargs):
        per_item = x.size()[1:]
        channels = self.mean.shape[0]
        shifted_mean = self.mean.view(channels, 1, 1).expand(*per_item)
        inv_std = self.std.view(self.std.shape[0], 1, 1).expand(*per_item)
        return (x - shifted_mean) * inv_std

    def neuronCount(self):
        # Affine transform only; nothing nonlinear to count.
        return 0

    def printNet(self, f):
        print("Normalize mean={} std={}".format(self.mean_v, self.std_v), file = f)

    def showNet(self, t = ""):
        print(t + "Normalize mean={} std={}".format(self.mean_v, self.std_v))
422
+
423
class Flatten(InferModule):
    """Collapse all non-batch dimensions into one feature dimension."""

    def init(self, in_shape, **kargs):
        # Output is a flat vector with one entry per input element.
        return h.product(in_shape)

    def forward(self, x, **kargs):
        dims = x.size()
        batch = dims[0]
        features = h.product(dims[1:])
        return x.view(batch, features)

    def neuronCount(self):
        # Pure reshape; nothing to verify.
        return 0
433
+
434
class BatchNorm(InferModule):
    """Batch normalization usable on abstract inputs.

    Statistics are computed on the concrete ("vanilla") part of the input
    and are never backpropagated through.
    NOTE(review): gamma/beta are allocated with the full in_shape rather than
    one value per channel — confirm this elementwise affine is intended.
    """

    def init(self, in_shape, track_running_stats = True, momentum = 0.1, eps=1e-5, **kargs):
        self.gamma = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.beta = torch.nn.Parameter(torch.Tensor(*in_shape))
        self.eps = eps
        self.track_running_stats = track_running_stats
        self.momentum = momentum

        # Running statistics are created lazily on the first training batch.
        self.running_mean = None
        self.running_var = None

        self.num_batches_tracked = 0
        return in_shape

    def reset_parameters(self):
        # Standard batch-norm init: identity scale, zero shift.
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, x, **kargs):
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None: # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else: # use exponential moving average
                    exponential_average_factor = self.momentum

        # Per-batch statistics over the concrete tensor part, detached.
        new_mean = x.vanillaTensorPart().detach().mean(dim=0)
        new_var = x.vanillaTensorPart().detach().var(dim=0, unbiased=False)
        if torch.isnan(new_var * 0).any():
            # Degenerate batch (nan/inf statistics): pass through unchanged.
            return x
        if self.training:
            self.running_mean = (1 - exponential_average_factor) * self.running_mean + exponential_average_factor * new_mean if self.running_mean is not None else new_mean
            if self.running_var is None:
                self.running_var = new_var
            else:
                q = (1 - exponential_average_factor) * self.running_var
                r = exponential_average_factor * new_var
                self.running_var = q + r

        # Prefer the running statistics whenever they exist and tracking is on.
        if self.track_running_stats and self.running_mean is not None and self.running_var is not None:
            new_mean = self.running_mean
            new_var = self.running_var

        diver = 1 / (new_var + self.eps).sqrt()

        if torch.isnan(diver).any():
            print("Really shouldn't happen ever")
            return x
        else:
            out = (x - new_mean) * diver * self.gamma + self.beta
            return out

    def neuronCount(self):
        return 0
491
+
492
class Unflatten2d(InferModule):
    """Reshape a flat feature vector into a square (channels, w, w) image."""

    def init(self, in_shape, w, **kargs):
        self.w = w
        # Infer the channel count from the total element count.
        total = h.product(in_shape)
        self.outChan = int(total / (w * w))
        return (self.outChan, self.w, self.w)

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, self.outChan, self.w, self.w)

    def neuronCount(self):
        # Pure reshape; nothing to verify.
        return 0
505
+
506
+
507
class View(InferModule):
    """Reshape the non-batch dimensions to a caller-specified shape."""

    def init(self, in_shape, out_shape, **kargs):
        # A pure reshape requires matching element counts.
        assert(h.product(in_shape) == h.product(out_shape))
        return out_shape

    def forward(self, x, **kargs):
        batch = x.size()[0]
        return x.view(batch, *self.outShape)

    def neuronCount(self):
        # Pure reshape; nothing to verify.
        return 0
518
+
519
class Seq(InferModule):
    """Sequential container: composes layers and threads shape inference."""

    def init(self, in_shape, *layers, **kargs):
        self.layers = layers
        # nn.Sequential registers the children so their parameters are found.
        self.net = nn.Sequential(*layers)
        self.prev = in_shape
        # Thread the shape through each child's infer().
        for s in layers:
            in_shape = s.infer(in_shape, **kargs).outShape
        return in_shape

    def forward(self, x, **kargs):
        for l in self.layers:
            x = l(x, **kargs)
        return x

    def clip_norm(self):
        for l in self.layers:
            l.clip_norm()

    def regularize(self, p):
        # Sum of the children's p-norm regularizers.
        return sum(n.regularize(p) for n in self.layers)

    def remove_norm(self):
        for l in self.layers:
            l.remove_norm()

    def printNet(self, f):
        for l in self.layers:
            l.printNet(f)

    def showNet(self, *args, **kargs):
        for l in self.layers:
            l.showNet(*args, **kargs)

    def neuronCount(self):
        return sum([l.neuronCount() for l in self.layers ])

    def depth(self):
        return sum([l.depth() for l in self.layers ])
558
+
559
def FFNN(layers, last_lin = False, last_zono = False, **kargs):
    """Build a fully connected network from a list of layer widths.

    last_lin: leave the final layer affine (no activation);
    last_zono: additionally correlate all error terms before the last layer.
    """
    starts = layers
    ends = []
    if last_lin:
        ends = ([CorrelateAll(only_train=False)] if last_zono else []) + [PrintActivation(activation = "Affine"), Linear(layers[-1],**kargs)]
        starts = layers[:-1]

    return Seq(*([ Seq(PrintActivation(**kargs), Linear(s, **kargs), activation(**kargs)) for s in starts] + ends))
567
+
568
def Conv(*args, **kargs):
    """Conv2D followed by the globally configured activation."""
    return Seq(Conv2D(*args, **kargs), activation(**kargs))
570
+
571
def ConvTranspose(*args, **kargs):
    """ConvTranspose2D followed by the globally configured activation."""
    return Seq(ConvTranspose2D(*args, **kargs), activation(**kargs))
573
+
574
MP = MaxPool2D  # short alias used in network specifications
575
+
576
def LeNet(conv_layers, ly = [], bias = True, normal=False, **kargs):
    """LeNet-style conv stack with an optional fully connected head.

    conv_layers entries are either InferModules (used as-is), tuples whose
    first element is a string (treated as max-pool args), or
    (channels, kernel[, padding][, stride]) tuples for conv layers.
    """
    def transfer(tp):
        if isinstance(tp, InferModule):
            return tp
        if isinstance(tp[0], str):
            return MaxPool2D(*tp[1:])
        # 4-tuples carry the stride as their last element.
        return Conv(out_channels = tp[0], kernel_size = tp[1], stride = tp[-1] if len(tp) == 4 else 1, bias=bias, normal=normal, **kargs)
    conv = [transfer(s) for s in conv_layers]
    return Seq(*conv, FFNN(ly, **kargs, bias=bias)) if len(ly) > 0 else Seq(*conv)
585
+
586
def InvLeNet(ly, w, conv_layers, bias = True, normal=False, **kargs):
    """Decoder: FFNN, reshape to (c, w, w), then transposed convolutions.

    Each conv_layers entry is (channels, kernel, stride, padding, out_padding).
    """
    def transfer(tp):
        return ConvTranspose(out_channels = tp[0], kernel_size = tp[1], stride = tp[2], padding = tp[3], out_padding = tp[4], bias=False, normal=normal)

    return Seq(FFNN(ly, bias=bias), Unflatten2d(w), *[transfer(s) for s in conv_layers])
591
+
592
class FromByteImg(InferModule):
    """Convert byte-valued images (0..255) into floats in [0, 1)."""

    def init(self, in_shape, **kargs):
        # Scaling does not change the shape.
        return in_shape

    def forward(self, x, **kargs):
        # Divide by 256 (not 255), matching the original scaling exactly.
        as_float = x.to_dtype()
        return as_float / 256.

    def neuronCount(self):
        return 0
601
+
602
class Skip(InferModule):
    """Run two subnets on the same input and concatenate along channels."""

    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        # Spatial dims must agree; only the channel counts may differ.
        assert(net1.outShape[1:] == net2.outShape[1:])
        return [ net1.outShape[0] + net2.outShape[0] ] + net1.outShape[1:]

    def forward(self, x, **kargs):
        r1 = self.net1(x, **kargs)
        r2 = self.net2(x, **kargs)
        return r1.cat(r2, dim=1)

    def regularize(self, p):
        return self.net1.regularize(p) + self.net2.regularize(p)

    def clip_norm(self):
        self.net1.clip_norm()
        self.net2.clip_norm()

    def remove_norm(self):
        self.net1.remove_norm()
        self.net2.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def printNet(self, f):
        print("SkipNet1", file=f)
        self.net1.printNet(f)
        print("SkipNet2", file=f)
        self.net2.printNet(f)
        print("SkipCat dim=1", file=f)

    def showNet(self, t = ""):
        print(t+"SkipNet1")
        self.net1.showNet(" "+t)
        print(t+"SkipNet2")
        self.net2.showNet(" "+t)
        print(t+"SkipCat dim=1")
641
+
642
class ParSum(InferModule):
    """Run two subnets on the same input and sum their outputs (residual add)."""

    def init(self, in_shape, net1, net2, **kargs):
        self.net1 = net1.infer(in_shape, **kargs)
        self.net2 = net2.infer(in_shape, **kargs)
        # Elementwise sum requires identical output shapes.
        assert(net1.outShape == net2.outShape)
        return net1.outShape

    def forward(self, x, **kargs):
        r1 = self.net1(x, **kargs)
        r2 = self.net2(x, **kargs)
        # addPar lets abstract domains implement a sound parallel sum.
        return x.addPar(r1,r2)

    def clip_norm(self):
        self.net1.clip_norm()
        self.net2.clip_norm()

    def remove_norm(self):
        self.net1.remove_norm()
        self.net2.remove_norm()

    def neuronCount(self):
        return self.net1.neuronCount() + self.net2.neuronCount()

    def depth(self):
        # Branches run in parallel, so depth is the deeper branch.
        return max(self.net1.depth(), self.net2.depth())

    def printNet(self, f):
        print("ParNet1", file=f)
        self.net1.printNet(f)
        print("ParNet2", file=f)
        self.net2.printNet(f)
        print("ParCat dim=1", file=f)

    def showNet(self, t = ""):
        print(t + "ParNet1")
        self.net1.showNet(" "+t)
        print(t + "ParNet2")
        self.net2.showNet(" "+t)
        print(t + "ParSum")
684
+
685
class ToZono(Identity):
    """Convert a hybrid abstract value to a zonotope at this point in the net."""

    def init(self, in_shape, customRelu = None, only_train = False, **kargs):
        self.customRelu = customRelu
        # only_train: apply the conversion during training only.
        self.only_train = only_train
        return in_shape

    def forward(self, x, **kargs):
        return self.abstract_forward(x, **kargs) if self.training or not self.only_train else x

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono', customRelu = self.customRelu)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train))
699
+
700
class CorrelateAll(ToZono):
    # Like ToZono, but correlates every dimension during the conversion.
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('hybrid_to_zono',correlate=True, customRelu = self.customRelu)
703
+
704
class ToHZono(ToZono):
    # Inverse direction: convert a zonotope back to a hybrid domain.
    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('zono_to_hybrid',customRelu = self.customRelu)
707
+
708
class Concretize(ToZono):
    """Collapse the abstract value to its concretization (training only by default)."""

    def init(self, in_shape, only_train = True, **kargs):
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x, **kargs):
        return x.abstractApplyLeaf('concretize')
715
+
716
# stochastic correlation
class CorrRand(Concretize):
    """Correlate a random subset of dimensions (training only by default)."""

    def init(self, in_shape, num_correlate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_correlate = num_correlate
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("stochasticCorrelate", self.num_correlate)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " num_correlate="+ str(self.num_correlate))
728
+
729
class CorrMaxK(CorrRand):
    # Correlate the k most significant dimensions instead of random ones.
    def abstract_forward(self, x):
        return x.abstractApplyLeaf("correlateMaxK", self.num_correlate)
732
+
733
+
734
class CorrMaxPool2D(Concretize):
    """Correlate error terms selected by 2-D max pooling windows."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.head_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type  # strategy used to pick the max per window
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" +str(self.max_type))
746
+
747
class CorrMaxPool3D(Concretize):
    """Correlate error terms selected by 3-D max pooling windows."""

    def init(self,in_shape, kernel_size, only_train = True, max_type = ai.MaxTypes.only_beta, **kargs):
        self.only_train = only_train
        self.kernel_size = kernel_size
        self.max_type = max_type  # strategy used to pick the max per window
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("correlateMaxPool", kernel_size = self.kernel_size, stride = self.kernel_size, max_type = self.max_type, max_pool = F.max_pool3d)

    def showNet(self, t = ""):
        # BUG FIX: max_type is not a string, so concatenating it raised
        # TypeError; wrap in str() exactly as CorrMaxPool2D.showNet does.
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " kernel_size="+ str(self.kernel_size) + " max_type=" + str(self.max_type))
759
+
760
class CorrFix(Concretize):
    """Correlate k evenly spaced dimensions of the flattened activation."""

    def init(self,in_shape, k, only_train = True, **kargs):
        self.k = k
        self.only_train = only_train
        return in_shape

    def abstract_forward(self, x):
        sz = x.size()
        """
        # for more control in the future
        indxs_1 = torch.arange(start = 0, end = sz[1], step = math.ceil(sz[1] / self.dims[1]) )
        indxs_2 = torch.arange(start = 0, end = sz[2], step = math.ceil(sz[2] / self.dims[2]) )
        indxs_3 = torch.arange(start = 0, end = sz[3], step = math.ceil(sz[3] / self.dims[3]) )

        indxs = torch.stack(torch.meshgrid((indxs_1,indxs_2,indxs_3)), dim=3).view(-1,3)
        """
        # Evenly spaced flat indices, broadcast over the batch.
        szm = h.product(sz[1:])
        indxs = torch.arange(start = 0, end = szm, step = math.ceil(szm / self.k))
        indxs = indxs.unsqueeze(0).expand(sz[0], indxs.size()[0])

        return x.abstractApplyLeaf("correlate", indxs)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.k))
785
+
786
+
787
class DecorrRand(Concretize):
    """Drop (decorrelate) a random subset of error terms."""

    def init(self, in_shape, num_decorrelate, only_train = True, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("stochasticDecorrelate", self.num_decorrelate)
795
+
796
class DecorrMin(Concretize):
    """Drop the least significant error terms."""

    def init(self, in_shape, num_decorrelate, only_train = True, num_to_keep = False, **kargs):
        self.only_train = only_train
        self.num_decorrelate = num_decorrelate
        # num_to_keep: interpret the count as "keep this many" instead.
        self.num_to_keep = num_to_keep
        return in_shape

    def abstract_forward(self, x):
        return x.abstractApplyLeaf("decorrelateMin", self.num_decorrelate, num_to_keep = self.num_to_keep)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " k="+ str(self.num_decorrelate) + " num_to_keep=" + str(self.num_to_keep) )
809
+
810
class DeepLoss(ToZono):
    """Attach an auxiliary deep loss to the abstract value at this layer."""

    def init(self, in_shape, bw = 0.01, act = F.relu, **kargs): # weight must be between 0 and 1
        self.only_train = True
        self.bw = S.Const.initConst(bw)  # mixing weight, possibly scheduled
        self.act = act
        return in_shape

    def abstract_forward(self, x, **kargs):
        if x.isPoint():
            # Concrete points carry no abstract loss to mix in.
            return x
        # Tag the domain so the trainer later mixes in x.deep_loss.
        return ai.TaggedDomain(x, self.MLoss(self, x))

    class MLoss():
        # Loss combinator: (1-bw) * downstream loss + bw * deep loss at x.
        def __init__(self, obj, x):
            self.obj = obj
            self.x = x

        def loss(self, a, *args, lr = 1, time = 0, **kargs):
            bw = self.obj.bw.getVal(time = time)
            pre_loss = a.loss(*args, time = time, **kargs, lr = lr * (1 - bw))
            if bw <= 0.0:
                return pre_loss
            return (1 - bw) * pre_loss + bw * self.x.deep_loss(act = self.obj.act)

    def showNet(self, t = ""):
        print(t + self.__class__.__name__ + " only_train=" + str(self.only_train) + " bw="+ str(self.bw) + " act=" + str(self.act) )
836
+
837
class IdentLoss(DeepLoss):
    # DeepLoss variant with the auxiliary loss disabled: passes x through.
    def abstract_forward(self, x, **kargs):
        return x
840
+
841
def SkipNet(net1, net2, ffnn, **kargs):
    """Concatenating skip connection followed by a fully connected head."""
    merged = Skip(net1, net2)
    head = FFNN(ffnn, **kargs)
    return Seq(merged, head)
843
+
844
def WideBlock(out_filters, downsample=False, k=3, bias=False, **kargs):
    """Wide-ResNet block: a conv shortcut summed with a two-conv path."""
    if not downsample:
        k_first = 3
        skip_stride = 1
        k_skip = 1
    else:
        # Downsampling halves the resolution on both branches.
        k_first = 4
        skip_stride = 2
        k_skip = 2

    # conv2d280(input)
    blockA = Conv2D(out_filters, kernel_size=k_skip, stride=skip_stride, padding=0, bias=bias, normal=True, **kargs)

    # conv2d282(relu(conv2d278(input)))
    blockB = Seq( Conv(out_filters, kernel_size = k_first, stride = skip_stride, padding = 1, bias=bias, normal=True, **kargs)
                , Conv2D(out_filters, kernel_size = k, stride = 1, padding = 1, bias=bias, normal=True, **kargs))
    return Seq(ParSum(blockA, blockB), activation(**kargs))
861
+
862
+
863
+
864
def BasicBlock(in_planes, planes, stride=1, bias = False, skip_net = False, **kargs):
    """ResNet basic block; adds an identity or 1x1-conv shortcut as needed."""
    block = Seq( Conv(planes, kernel_size = 3, stride = stride, padding = 1, bias=bias, normal=True, **kargs)
               , Conv2D(planes, kernel_size = 3, stride = 1, padding = 1, bias=bias, normal=True, **kargs))

    if stride != 1 or in_planes != planes:
        # Shapes change: project the shortcut through a strided 1x1 conv.
        block = ParSum(block, Conv2D(planes, kernel_size=1, stride=stride, bias=bias, normal=True, **kargs))
    elif not skip_net:
        block = ParSum(block, Identity())
    return Seq(block, activation(**kargs))
873
+
874
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
def ResNet(blocksList, extra = None, bias = False, **kargs):
    """ResNet-style stack of BasicBlocks.

    blocksList gives the number of blocks per stage; extra is an optional
    list of (module, index) pairs spliced in after the given block index.

    BUG FIX: extra previously defaulted to the mutable [] (shared across
    calls); it now defaults to None and is normalized inside.
    NOTE(review): the inner ``for stride in strides`` rebinds the outer
    ``stride``, so later stages see the last block's stride — confirm this
    staging is intended.
    """
    if extra is None:
        extra = []

    layers = []
    in_planes = 64
    planes = 64
    stride = 0
    for num_blocks in blocksList:
        # First block of a stage may downsample (stride ramps 1 -> 2).
        if stride < 2:
            stride += 1

        strides = [stride] + [1]*(num_blocks-1)
        for stride in strides:
            layers.append(BasicBlock(in_planes, planes, stride, bias = bias, **kargs))
            in_planes = planes
        planes *= 2

    print("RESlayers: ", len(layers))
    for e,l in extra:
        layers[l] = Seq(layers[l], e)

    return Seq(Conv(64, kernel_size=3, stride=1, padding = 1, bias=bias, normal=True, printShape=True),
               *layers)
897
+
898
+
899
+
900
def DenseNet(growthRate, depth, reduction, num_classes, bottleneck = True):
    """DenseNet-BC style classifier built from DiffAI components.

    growthRate: channels added per dense layer; depth: overall conv depth;
    reduction: channel compression factor at transition layers.
    """

    def Bottleneck(growthRate):
        # 1x1 bottleneck then 3x3 conv; concatenated with its input via Skip.
        interChannels = 4*growthRate

        n = Seq( ReLU(),
                 Conv2D(interChannels, kernel_size=1, bias=True, ibp_init = True),
                 ReLU(),
                 Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True)
               )

        return Skip(Identity(), n)

    def SingleLayer(growthRate):
        # Non-bottleneck variant: a single 3x3 conv per dense layer.
        n = Seq( ReLU(),
                 Conv2D(growthRate, kernel_size=3, padding=1, bias=True, ibp_init = True))
        return Skip(Identity(), n)

    def Transition(nOutChannels):
        # Compress channels and halve the spatial resolution.
        return Seq( ReLU(),
                    Conv2D(nOutChannels, kernel_size = 1, bias = True, ibp_init = True),
                    AvgPool2D(kernel_size=2))

    def make_dense(growthRate, nDenseBlocks, bottleneck):
        return Seq(*[Bottleneck(growthRate) if bottleneck else SingleLayer(growthRate) for i in range(nDenseBlocks)])

    nDenseBlocks = (depth-4) // 3
    if bottleneck:
        # Bottleneck layers use two convs each, so halve the block count.
        nDenseBlocks //= 2

    nChannels = 2*growthRate
    conv1 = Conv2D(nChannels, kernel_size=3, padding=1, bias=True, ibp_init = True)
    dense1 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks * growthRate
    nOutChannels = int(math.floor(nChannels*reduction))
    trans1 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense2 = make_dense(growthRate, nDenseBlocks, bottleneck)
    nChannels += nDenseBlocks*growthRate
    nOutChannels = int(math.floor(nChannels*reduction))
    trans2 = Transition(nOutChannels)

    nChannels = nOutChannels
    dense3 = make_dense(growthRate, nDenseBlocks, bottleneck)

    # Correlate all error terms just before the classifier so the zonotope
    # bounds on the logits are as tight as possible.
    return Seq(conv1, dense1, trans1, dense2, trans2, dense3,
               ReLU(),
               AvgPool2D(kernel_size=8),
               CorrelateAll(only_train=False, ignore_point = True),
               Linear(num_classes, ibp_init = True))
951
+
convert.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import future
2
+ import builtins
3
+ import past
4
+ import six
5
+
6
+ from timeit import default_timer as timer
7
+ from datetime import datetime
8
+ import argparse
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ from torchvision import datasets, transforms, utils
14
+ from torch.utils.data import Dataset
15
+
16
+ import inspect
17
+ from inspect import getargspec
18
+ import os
19
+ import helpers as h
20
+ from helpers import Timer
21
+ import copy
22
+ import random
23
+ from itertools import count
24
+
25
+ from components import *
26
+ import models
27
+
28
+ import goals
29
+ from goals import *
30
+ import math
31
+
32
+ from torch.serialization import SourceChangeWarning
33
+ import warnings
34
+
35
+
36
# ---- Command line interface and net loading --------------------------------
parser = argparse.ArgumentParser(description='Convert a pickled PyTorch DiffAI net to an abstract onyx net which returns the interval concretization around the final logits. The first dimension of the output is the natural center, the second dimension is the lb, the third is the ub', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-n', '--net', type=str, default=None, metavar='N', help='Saved and pickled net to use, in pynet format', required=True)
parser.add_argument('-d', '--domain', type=str, default="Point()", help='picks which abstract goals to use for testing. Uses box. Doesn\'t use time, so don\'t use Lin. Unless point, should specify a width w.')
parser.add_argument('-b', '--batch-size', type=int, default=1, help='The batch size to export. Not sure this matters.')

parser.add_argument('-o', '--out', type=str, default="convert_out/", metavar='F', help='Where to save the net.')

parser.add_argument('--update-net', type=h.str2bool, nargs='?', const=True, default=False, help="should update test net")
parser.add_argument('--net-name', type=str, choices = h.getMethodNames(models), default=None, help="update test net name")

parser.add_argument('--save-name', type=str, default=None, help="name to save the net with. Defaults to <domain>___<netfile-.pynet>.onyx")

parser.add_argument('-D', '--dataset', choices = [n for (n,k) in inspect.getmembers(datasets, inspect.isclass) if issubclass(k, Dataset)]
                    , default="MNIST", help='picks which dataset to use.')

parser.add_argument('--map-to-cpu', type=h.str2bool, nargs='?', const=True, default=False, help="map cuda operations in save back to cpu; enables to run on a computer without a GPU")

parser.add_argument('--tf-input', type=h.str2bool, nargs='?', const=True, default=False, help="change the shape of the input data from batch-channels-height-width (standard in pytroch) to batch-height-width-channels (standard in tf)")

args = parser.parse_args()

out_dir = args.out

if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# Loading a pickled net may emit SourceChangeWarning when class sources have
# changed since it was saved; collect those warnings quietly.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always", SourceChangeWarning)
    if args.map_to_cpu:
        net = torch.load(args.net, map_location='cpu')
    else:
        net = torch.load(args.net)

net_name = None

# An explicit --net-name wins; otherwise --update-net reuses the name stored
# on the pickled net, if present.
if args.net_name is not None:
    net_name = args.net_name
elif args.update_net and 'name' in dir(net):
    net_name = net.name
75
+
76
+
77
def buildNet(n, input_dims, num_classes):
    """Instantiate model constructor n, prepend the dataset's normalization
    layer, run shape inference, and clip norms.

    BUG FIX: the SVHN and Imagenet12 branches referenced an undefined name
    ``dataset`` (NameError when reached); they now read args.dataset like
    the other branches.
    """
    n = n(num_classes)
    if args.dataset in ["MNIST"]:
        n = Seq(Normalize([0.1307], [0.3081] ), n)
    elif args.dataset in ["CIFAR10", "CIFAR100"]:
        n = Seq(Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), n)
    elif args.dataset in ["SVHN"]:
        n = Seq(Normalize([0.5,0.5,0.5], [0.2, 0.2, 0.2]), n)
    elif args.dataset in ["Imagenet12"]:
        n = Seq(Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]), n)

    n = n.infer(input_dims)
    n.clip_norm()
    return n
91
+
92
+
93
# Rebuild the net from the models module when a name is known, transferring
# the saved weights onto the freshly constructed graph.
if net_name is not None:
    n = getattr(models,net_name)
    n = buildNet(n, net.inShape, net.outShape)
    n.load_state_dict(net.state_dict())
    net = n

net = net.to(h.device)
net.remove_norm()

# The domain spec is a Python expression over the goals namespace.
domain = eval(args.domain)

if args.save_name is None:
    save_name = h.prepareDomainNameForFile(args.domain) + "___" + os.path.basename(args.net)[:-6] + ".onyx"
else:
    save_name = args.save_name
108
+
109
def abstractNet(inpt):
    """Run the net on the abstract box around inpt.

    Returns a tensor whose second dimension stacks (center, lower bound,
    upper bound) of the final logits.
    """
    if args.tf_input:
        # TF layout (NHWC) -> PyTorch layout (NCHW).
        inpt = inpt.permute(0, 3, 1, 2)
    dom = domain.box(inpt, w = None)
    o = net(dom, onyx=True).unsqueeze(1)

    out = torch.cat([o.vanillaTensorPart(), o.lb().vanillaTensorPart(), o.ub().vanillaTensorPart()], dim=1)
    return out
117
+
118
input_shape = [args.batch_size] + list(net.inShape)
if args.tf_input:
    # NHWC dummy input for TF-style consumers.
    input_shape = [args.batch_size] + list(net.inShape)[1:] + [net.inShape[0]]
dummy = h.zeros(input_shape)

# Run once before export so any lazily created state exists.
abstractNet(dummy)
124
+
125
class AbstractNet(nn.Module):
    """nn.Module wrapper so torch.onnx.export can trace abstractNet."""

    def __init__(self, domain, net, abstractNet):
        super(AbstractNet, self).__init__()
        self.net = net
        self.abstractNet = abstractNet
        # Keep a reference to the domain's own net (if any) so its parameters
        # are registered with this module for export.
        if hasattr(domain, "net") and domain.net is not None:
            self.netDom = domain.net

    def forward(self, inpt):
        return self.abstractNet(inpt)
135
+
136
absNet = AbstractNet(domain, net, abstractNet)

out_path = os.path.join(out_dir, save_name)
print("Saving:", out_path)

# Give every parameter a stable name so the exported graph is readable.
param_list = ["param"+str(i) for i in range(len(list(absNet.parameters())))]

torch.onnx.export(absNet, dummy, out_path, verbose=False, input_names=["actual_input"] + param_list, output_names=["output"])
144
+
goals.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import future
2
+ import builtins
3
+ import past
4
+ import six
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ import torch.autograd
11
+ import components as comp
12
+ from torch.distributions import multinomial, categorical
13
+
14
+ import math
15
+ import numpy as np
16
+
17
+ try:
18
+ from . import helpers as h
19
+ from . import ai
20
+ from . import scheduling as S
21
+ except:
22
+ import helpers as h
23
+ import ai
24
+ import scheduling as S
25
+
26
+
27
+
28
class WrapDom(object):
    """Base class for domain wrappers.

    Holds an inner domain ``a`` and, for each construction method, delegates
    to it and re-wraps the result via ``self.Domain`` (supplied by subclasses).
    """

    def __init__(self, a):
        # Domain specs can arrive as command-line strings; evaluate those.
        if type(a) is str:
            a = eval(a)
        self.a = a

    def box(self, *args, **kargs):
        inner = self.a.box(*args, **kargs)
        return self.Domain(inner)

    def boxBetween(self, *args, **kargs):
        inner = self.a.boxBetween(*args, **kargs)
        return self.Domain(inner)

    def line(self, *args, **kargs):
        inner = self.a.line(*args, **kargs)
        return self.Domain(inner)
40
+
41
class DList(object):
    """A weighted list of abstract domains trained jointly.

    Each constituent domain builds its own abstract element; every element is
    wrapped in an ai.TaggedDomain carrying an MLoss tag with that domain's
    normalized weight, and the collection becomes an ai.ListDomain.
    """
    Domain = ai.ListDomain

    class MLoss():
        # Loss tag attached to each sub-domain: scales its loss by weight aw.
        def __init__(self, aw):
            self.aw = aw
        def loss(self, dom, *args, lr = 1, **kargs):
            if self.aw <= 0.0:
                # Zero-weight members contribute nothing (and skip the work).
                return 0
            return self.aw * dom.loss(*args, lr = lr * self.aw, **kargs)

    def __init__(self, *al):
        # al: (domain-or-spec-string, weight-or-schedule) pairs.
        if len(al) == 0:
            al = [("Point()", 1.0), ("Box()", 0.1)]

        # Strings are eval'd into domain objects; weights become S.Const
        # schedules so they can vary over training.
        self.al = [(eval(a) if type(a) is str else a, S.Const.initConst(aw)) for a,aw in al]

    def getDiv(self, **kargs):
        # Normalizer so the (possibly scheduled) weights sum to 1.
        return 1.0 / sum(aw.getVal(**kargs) for _,aw in self.al)

    def box(self, *args, **kargs):
        m = self.getDiv(**kargs)
        return self.Domain(ai.TaggedDomain(a.box(*args, **kargs), DList.MLoss(aw.getVal(**kargs) * m)) for a,aw in self.al)

    def boxBetween(self, *args, **kargs):

        m = self.getDiv(**kargs)
        return self.Domain(ai.TaggedDomain(a.boxBetween(*args, **kargs), DList.MLoss(aw.getVal(**kargs) * m)) for a,aw in self.al)

    def line(self, *args, **kargs):
        m = self.getDiv(**kargs)
        return self.Domain(ai.TaggedDomain(a.line(*args, **kargs), DList.MLoss(aw.getVal(**kargs) * m)) for a,aw in self.al)

    def __str__(self):
        return "DList(%s)" % h.sumStr("("+str(a)+","+str(w)+")" for a,w in self.al)
75
+
76
class Mix(DList):
    """Convenience DList of exactly two domains with constant weights."""
    def __init__(self, a="Point()", b="Box()", aw = 1.0, bw = 0.1):
        super(Mix, self).__init__((a, aw), (b, bw))
79
+
80
class LinMix(DList):
    """Two-domain DList whose first weight is the complement of the
    (possibly scheduled) second weight, so the pair always sums to 1."""
    def __init__(self, a="Point()", b="Box()", bw = 0.1):
        super(LinMix, self).__init__((a, S.Complement(bw)), (b, bw))
83
+
84
class DProb(object):
    """Stochastic mixture of domains.

    Every call picks a single domain according to fixed probabilities and
    delegates to it alone (unlike DList, which uses all domains at once).
    """

    def __init__(self, *doms):
        if len(doms) == 0:
            doms = [("Point()", 0.8), ("Box()", 0.2)]
        # Normalize weights into probabilities (keep the div-then-multiply
        # form so float results match exactly).
        div = 1.0 / sum(float(aw) for _, aw in doms)
        self.domains = []
        self.probs = []
        for a, aw in doms:
            self.domains.append(eval(a) if type(a) is str else a)
            self.probs.append(div * float(aw))

    def chooseDom(self):
        if len(self.domains) > 1:
            idx = np.random.choice(len(self.domains), p = self.probs)
            return self.domains[idx]
        return self.domains[0]

    def box(self, *args, **kargs):
        return self.chooseDom().box(*args, **kargs)

    def line(self, *args, **kargs):
        return self.chooseDom().line(*args, **kargs)

    def __str__(self):
        return "DProb(%s)" % h.sumStr("("+str(a)+","+str(w)+")" for a,w in zip(self.domains, self.probs))
105
+
106
class Coin(DProb):
    """Coin flip between two domains with probabilities ap / bp."""
    def __init__(self, a="Point()", b="Box()", ap = 0.8, bp = 0.2):
        super(Coin, self).__init__((a, ap), (b, bp))
109
+
110
class Point(object):
    """Degenerate 'abstract domain': the concrete input itself.

    Using Point reduces abstract training to standard training.
    """
    Domain = h.dten

    def __init__(self, **kargs):
        pass

    def box(self, original, *args, **kargs):
        # A width-0 region around the input is just the input.
        return original

    def line(self, original, other, *args, **kargs):
        # Collapse a segment to its midpoint.
        return (original + other) / 2

    def boxBetween(self, o1, o2, *args, **kargs):
        # Collapse a box between two corners to its midpoint.
        return (o1 + o2) / 2

    def __str__(self):
        return "Point()"
126
+
127
class PointA(Point):
    """Point that keeps the first endpoint instead of the midpoint."""
    def boxBetween(self, o1, o2, *args, **kargs):
        return o1

    def __str__(self):
        return "PointA()"
133
+
134
class PointB(Point):
    """Point that keeps the second endpoint instead of the midpoint."""
    def boxBetween(self, o1, o2, *args, **kargs):
        return o2

    def __str__(self):
        return "PointB()"
140
+
141
+
142
class NormalPoint(Point):
    """Point jittered by Gaussian noise of scale w (or a fixed epsilon)."""

    def __init__(self, w = None, **kargs):
        # Fixed noise scale; None means use the training-time width.
        self.epsilon = w

    def box(self, original, w, *args, **kargs):
        """ original = mu = mean, epsilon = variance"""
        scale = w if self.epsilon is None else self.epsilon
        noise = torch.randn_like(original, device = h.device) * scale
        return original + noise

    def __str__(self):
        return "NormalPoint(%s)" % ("" if self.epsilon is None else str(self.epsilon))
156
+
157
+
158
+
159
class MI_FGSM(Point):
    """Momentum Iterative FGSM attack (MI-FGSM).

    Produces a concrete adversarial point, so it plugs into the abstract
    domain interface exactly like Point.

    Fixes vs. the original:
    * the class defined ``boxBetween`` twice — the first definition (which
      also passed arguments to ``attack`` in the wrong order) was dead code,
      silently shadowed by the second; only one definition remains.
    * the surviving ``boxBetween`` raised a *string*, which is a TypeError
      in Python 3; it now raises NotImplementedError.
    * ``__str__`` labeled the ``mu`` value as ``r=``.
    """

    def __init__(self, w = None, r = 20.0, k = 100, mu = 0.8, should_end = True, restart = None, searchable=False,**kargs):
        self.epsilon = S.Const.initConst(w)  # attack radius override (None = training width)
        self.k = k                 # number of gradient steps
        self.mu = mu               # momentum decay factor
        self.r = float(r)          # total step multiplier (per-step size = r*w/k)
        self.should_end = should_end  # freeze examples that are already misclassified
        self.restart = restart        # number of random restarts, or None
        self.searchable = searchable

    def box(self, original, model, target = None, untargeted = False, **kargs):
        if target is None:
            # No label supplied: attack away from the model's own prediction.
            untargeted = True
            with torch.no_grad():
                target = model(original).max(1)[1]
        return self.attack(model, original, untargeted, target, **kargs)

    def attack(self, model, xo, untargeted, target, w, loss_function=ai.stdLoss, **kargs):
        """Run k momentum-FGSM steps from xo within the L-inf ball of radius w."""
        w = self.epsilon.getVal(c = w, **kargs)

        x = nn.Parameter(xo.clone(), requires_grad=True)
        gradorg = h.zeros(x.shape)   # momentum accumulator
        is_eq = 1                    # mask of examples still being attacked

        w = h.ones(x.shape) * w
        for i in range(self.k):
            if self.restart is not None and i % int(self.k / self.restart) == 0:
                # Random restart for examples that are still correctly classified.
                x = is_eq * (torch.rand_like(xo) * w + xo) + (1 - is_eq) * x
                x = nn.Parameter(x, requires_grad = True)

            model.optimizer.zero_grad()

            out = model(x).vanillaTensorPart()
            loss = loss_function(out, target)

            loss.sum().backward(retain_graph=True)
            with torch.no_grad():
                # L1-normalize the gradient, then fold it into the momentum.
                oth = x.grad / torch.norm(x.grad, p=1)
                gradorg *= self.mu
                gradorg += oth
                grad = (self.r * w / self.k) * ai.mysign(gradorg)
                if self.should_end:
                    is_eq = ai.mulIfEq(grad, out, target)
                x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)

            # Project back into the L-inf ball around xo.
            x = xo + torch.min(torch.max(x - xo, -w),w)
            x.requires_grad_()

            model.optimizer.zero_grad()

        return x

    def boxBetween(self, o1, o2, model, target = None, *args, **kargs):
        raise NotImplementedError("boxBetween is not yet supported by MI_FGSM")

    def __str__(self):
        return "MI_FGSM(%s)" % (("" if self.epsilon is None else "w="+str(self.epsilon)+",")
                                + ("" if self.k == 5 else "k="+str(self.k)+",")
                                + ("" if self.r == 5.0 else "r="+str(self.r)+",")
                                + ("" if self.mu == 0.8 else "mu="+str(self.mu)+",")
                                + ("" if self.should_end else "should_end=False"))
225
+
226
+
227
class PGD(MI_FGSM):
    """Projected gradient descent: MI_FGSM with momentum disabled."""

    def __init__(self, r = 5.0, k = 5, **kargs):
        super(PGD,self).__init__(r=r, k = k, mu = 0, **kargs)

    def __str__(self):
        pieces = []
        if self.epsilon is not None:
            pieces.append("w=" + str(self.epsilon) + ",")
        if self.k != 5:
            pieces.append("k=" + str(self.k) + ",")
        if self.r != 5.0:
            pieces.append("r=" + str(self.r) + ",")
        if not self.should_end:
            pieces.append("should_end=False")
        return "PGD(%s)" % "".join(pieces)
236
+
237
class IFGSM(PGD):
    """Iterative FGSM: PGD with total step multiplier r = 1."""

    def __init__(self, k = 5, **kargs):
        super(IFGSM, self).__init__(r = 1, k=k, **kargs)

    def __str__(self):
        pieces = []
        if self.epsilon is not None:
            pieces.append("w=" + str(self.epsilon) + ",")
        if self.k != 5:
            pieces.append("k=" + str(self.k) + ",")
        if not self.should_end:
            pieces.append("should_end=False")
        return "IFGSM(%s)" % "".join(pieces)
246
+
247
class NormalAdv(Point):
    """Attack (IFGSM by default) run with a randomly rescaled radius:
    the width is multiplied by a single scalar draw from N(0,1)."""

    def __init__(self, a="IFGSM()", w = None):
        self.a = eval(a) if type(a) is str else a
        self.epsilon = S.Const.initConst(w)

    def box(self, original, w, *args, **kargs):
        eps = self.epsilon.getVal(c = w, shape = original.shape[:1], **kargs)
        assert (0 <= h.dten(eps)).all()
        # NOTE(review): the normal draw is not clamped, so the effective
        # radius can be negative — confirm downstream attacks tolerate that.
        eps = torch.randn(original.size()[0:1], device = h.device)[0] * eps
        return self.a.box(original, w = eps, *args, **kargs)

    def __str__(self):
        return "NormalAdv(%s)" % (str(self.a) + ("" if self.epsilon is None else ",w="+str(self.epsilon)))
260
+
261
+
262
class InclusionSample(Point):
    """Trains on a randomly shifted, shrunken region included in the w-ball.

    A fraction ``sub`` of the width is given to the abstract domain ``a``;
    the remaining (1 - sub) share jitters the center (uniformly, or with
    Gaussian noise when ``normal=True``).
    """

    def __init__(self, sub, a="Box()", normal = False, w = None, **kargs):
        self.sub = S.Const.initConst(sub) # sub is the fraction of w to use.
        self.w = S.Const.initConst(w)
        self.normal = normal
        self.a = eval(a) if type(a) is str else a

    def box(self, original, w, *args, **kargs):
        w = self.w.getVal(c = w, shape = original.shape[:1], **kargs)
        sub = self.sub.getVal(c = 1, shape = original.shape[:1], **kargs)

        assert (0 <= h.dten(w)).all()
        assert (h.dten(sub) <= 1).all()
        assert (0 <= h.dten(sub)).all()
        if self.normal:
            inter = torch.randn_like(original, device = h.device)
        else:
            inter = (torch.rand_like(original, device = h.device) * 2 - 1)

        # Shift the center by the unused (1 - sub) share of the width.
        inter = inter * w * (1 - sub)

        return self.a.box(original + inter, w = w * sub, *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        w = (o2 - o1).abs()
        return self.box( (o2 + o1)/2 , w = w, *args, **kargs)

    def __str__(self):
        # BUG FIX: previously read self.epsilon, an attribute this class never
        # sets (the width constant lives in self.w), so printing always raised
        # AttributeError.
        return "InclusionSample(%s, %s)" % (str(self.sub), str(self.a) + ("" if self.w is None else ",w="+str(self.w)))

InSamp = InclusionSample
293
+
294
+
295
class AdvInclusion(InclusionSample):
    """InclusionSample whose center shift comes from an attack ``a`` instead
    of random noise, before propagating the abstract domain ``b``."""

    def __init__(self, sub, a="IFGSM()", b="Box()", w = None, **kargs):
        self.sub = S.Const.initConst(sub) # sub is the fraction of w to use.
        self.w = S.Const.initConst(w)
        self.a = eval(a) if type(a) is str else a
        self.b = eval(b) if type(b) is str else b

    def box(self, original, w, *args, **kargs):
        w = self.w.getVal(c = w, shape = original.shape, **kargs)
        sub = self.sub.getVal(c = 1, shape = original.shape, **kargs)

        assert (0 <= h.dten(w)).all()
        assert (h.dten(sub) <= 1).all()
        assert (0 <= h.dten(sub)).all()

        if h.dten(w).sum().item() <= 0.0:
            # Zero total width: nothing to attack, keep the original center.
            inter = original
        else:
            inter = self.a.box(original, w = w * (1 - sub), *args, **kargs)

        return self.b.box(inter, w = w * sub, *args, **kargs)

    def __str__(self):
        # BUG FIX: previously referenced the nonexistent self.epsilon
        # (AttributeError on print); the width constant lives in self.w.
        return "AdvInclusion(%s, %s, %s)" % (str(self.sub), str(self.a), str(self.b) + ("" if self.w is None else ",w="+str(self.w)))
319
+
320
+
321
class AdvDom(Point):
    """Attack-then-propagate: run attack ``a``, then build domain ``b``
    spanning from the original input to the adversarial example found."""

    def __init__(self, a="IFGSM()", b="Box()"):
        self.a = eval(a) if type(a) is str else a
        self.b = eval(b) if type(b) is str else b

    def box(self, original,*args, **kargs):
        adv = self.a.box(original, *args, **kargs)
        return self.b.boxBetween(original, adv.ub(), *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        original = (o1 + o2) / 2
        adv = self.a.boxBetween(o1, o2, *args, **kargs)
        return self.b.boxBetween(original, adv.ub(), *args, **kargs)

    def __str__(self):
        # BUG FIX: previously read self.width, which is never assigned in this
        # class, so __str__ always raised AttributeError. (BiAdv.__str__ slices
        # off the first 6 characters, "AdvDom", which still works.)
        return "AdvDom(%s)" % (str(self.a) + "," + str(self.b))
338
+
339
+
340
+
341
class BiAdv(AdvDom):
    """Like AdvDom, but symmetrizes the adversarial offset around the center
    before building the propagated region."""

    def box(self, original, **kargs):
        adv = self.a.box(original, **kargs)
        offset = (adv.ub() - original).abs()
        return self.b.boxBetween(original - offset, original + offset, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        center = (o1 + o2) / 2
        adv = self.a.boxBetween(o1, o2, *args, **kargs)
        offset = (adv.ub() - center).abs()
        return self.b.boxBetween(center - offset, center + offset, *args, **kargs)

    def __str__(self):
        return "BiAdv" + AdvDom.__str__(self)[6:]
355
+
356
+
357
class HBox(object):
    """Hybrid zonotope domain with one error term per input coordinate.

    Error terms are kept correlated through the network (expensive but
    precise).  Also owns the loss shaping used during training.
    """
    Domain = ai.HybridZonotope

    def domain(self, *args, **kargs):
        # Wrap the constructed element so its loss is routed back to self.loss.
        return ai.TaggedDomain(self.Domain(*args, **kargs), self)

    def __init__(self, w = None, tot_weight = 1, width_weight = 0, pow_loss = None, log_loss = False, searchable = True, cross_loss = True, **kargs):
        self.w = S.Const.initConst(w)                      # radius override (None = training width)
        self.tot_weight = S.Const.initConst(tot_weight)    # weight of the classification loss
        self.width_weight = S.Const.initConst(width_weight)  # weight of the diameter regularizer
        self.pow_loss = pow_loss       # optional power applied to the loss
        self.searchable = searchable
        self.log_loss = log_loss       # optional log(1 + loss) transform
        self.cross_loss = cross_loss   # use worst-case (ub vs lb-at-target) logits

    def __str__(self):
        return "HBox(%s)" % ("" if self.w is None else "w="+str(self.w))

    def boxBetween(self, o1, o2, *args, **kargs):
        batches = o1.size()[0]
        num_elem = h.product(o1.size()[1:])
        # One basis error direction per coordinate.
        ei = h.getEi(batches, num_elem)

        if len(o1.size()) > 2:
            ei = ei.contiguous().view(num_elem, *o1.size())

        # Center at the midpoint, errors spanning half the gap per coordinate.
        return self.domain((o1 + o2) / 2, None, ei * (o2 - o1).abs() / 2).checkSizes()

    def box(self, original, w, **kargs):
        """
        This version of it is slow, but keeps correlation down the line.
        """
        radius = self.w.getVal(c = w, **kargs)

        batches = original.size()[0]
        num_elem = h.product(original.size()[1:])
        ei = h.getEi(batches,num_elem)

        if len(original.size()) > 2:
            ei = ei.contiguous().view(num_elem, *original.size())

        return self.domain(original, None, ei * radius).checkSizes()

    def line(self, o1, o2, **kargs):
        w = self.w.getVal(c = 0, **kargs)

        # First error term spans the segment; optional extra per-coordinate
        # error terms of radius w widen it into a band.
        ln = ((o2 - o1) / 2).unsqueeze(0)
        if not w is None and w > 0.0:
            batches = o1.size()[0]
            num_elem = h.product(o1.size()[1:])
            ei = h.getEi(batches,num_elem)
            if len(o1.size()) > 2:
                ei = ei.contiguous().view(num_elem, *o1.size())
            ln = torch.cat([ln, ei * w])
        return self.domain((o1 + o2) / 2, None, ln ).checkSizes()

    def loss(self, dom, target, *args, **kargs):
        width_weight = self.width_weight.getVal(**kargs)
        tot_weight = self.tot_weight.getVal(**kargs)

        if self.cross_loss:
            # Worst-case logits: upper bounds everywhere except the true
            # class, which gets its lower bound.
            r = dom.ub()
            inds = torch.arange(r.shape[0], device=h.device, dtype=h.ltype)
            r[inds,target] = dom.lb()[inds,target]
            tot = r.loss(target, *args, **kargs)
        else:
            tot = dom.loss(target, *args, **kargs)

        if self.log_loss:
            tot = (tot + 1).log()
        if self.pow_loss is not None and self.pow_loss > 0 and self.pow_loss != 1:
            tot = tot.pow(self.pow_loss)

        # Weighted combination of classification loss and diameter penalty,
        # renormalized by the total weight.
        ls = tot * tot_weight
        if width_weight > 0:
            ls += dom.diameter() * width_weight

        return ls / (width_weight + tot_weight)
435
+
436
class Box(HBox):
    """Interval domain: independent per-coordinate radii (betas).

    The per-coordinate errors stay uncorrelated forever, which is cheaper
    than HBox; the original author notes this counterintuitively tested as
    more accurate, because correlated zero errors hurt the ReLU transformer.
    """

    def __str__(self):
        return "Box(%s)" % ("" if self.w is None else "w="+str(self.w))

    def box(self, original, w, **kargs):
        # Store the radius in the uncorrelated beta component.
        radius = self.w.getVal(c = w, **kargs)
        beta = h.ones(original.size()) * radius
        return self.domain(original, beta, None).checkSizes()

    def line(self, o1, o2, **kargs):
        pad = self.w.getVal(c = 0, **kargs)
        center = (o1 + o2) / 2
        beta = ((o2 - o1) / 2).abs() + h.ones(o2.size()) * pad
        return self.domain(center, beta, None).checkSizes()

    def boxBetween(self, o1, o2, *args, **kargs):
        return self.line(o1, o2, **kargs)
457
+
458
class ZBox(HBox):
    """Pure zonotope variant of HBox (no interval beta component)."""

    def __str__(self):
        return "ZBox(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.Zonotope(*args, **kargs)
465
+
466
class HSwitch(HBox):
    """HybridZonotope using the 'switch' custom ReLU transformer."""

    def __str__(self):
        return "HSwitch(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.HybridZonotope(*args, customRelu = ai.creluSwitch, **kargs)
472
+
473
class ZSwitch(ZBox):
    """Zonotope using the 'switch' custom ReLU transformer."""

    def __str__(self):
        return "ZSwitch(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.Zonotope(*args, customRelu = ai.creluSwitch, **kargs)
479
+
480
+
481
class ZNIPS(ZBox):
    """Zonotope using the NIPS custom ReLU transformer."""

    def __str__(self):
        # BUG FIX: copy-paste from ZSwitch made this print "ZSwitch(...)",
        # producing misleading logs/filenames.
        return "ZNIPS(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.Zonotope(*args, customRelu = ai.creluNIPS, **kargs)
488
+
489
class HSmooth(HBox):
    """HybridZonotope using the smooth custom ReLU transformer."""

    def __str__(self):
        return "HSmooth(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.HybridZonotope(*args, customRelu = ai.creluSmooth, **kargs)
495
+
496
class HNIPS(HBox):
    """HybridZonotope using the NIPS custom ReLU transformer."""

    def __str__(self):
        # BUG FIX: copy-paste from HSmooth made this print "HSmooth(...)",
        # producing misleading logs/filenames.
        return "HNIPS(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.HybridZonotope(*args, customRelu = ai.creluNIPS, **kargs)
502
+
503
class ZSmooth(ZBox):
    """Zonotope using the smooth custom ReLU transformer."""

    def __str__(self):
        return "ZSmooth(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        return ai.Zonotope(*args, customRelu = ai.creluSmooth, **kargs)
509
+
510
+
511
+
512
+
513
+
514
# stochastic correlation
class HRand(WrapDom):
    """Box propagation whose error terms are stochastically correlated into a
    richer domain (HSwitch by default)."""

    # domain must be an ai style domain like hybrid zonotope.
    def __init__(self, num_correlated, a = "HSwitch()", **kargs):
        # The underlying construction always starts from a plain Box.
        super(HRand, self).__init__(Box())
        self.num_correlated = num_correlated
        self.dom = eval(a) if type(a) is str else a

    def Domain(self, d):
        with torch.no_grad():
            out = d.abstractApplyLeaf('stochasticCorrelate', self.num_correlated)
            out = self.dom.Domain(out.head, out.beta, out.errors)
            return out

    def __str__(self):
        # BUG FIX: previously printed self.a, which WrapDom.__init__ sets to
        # the wrapped Box() — so the configured target domain never appeared.
        # The configurable domain is stored in self.dom.
        return "HRand(%s, domain = %s)" % (str(self.num_correlated), str(self.dom))
helpers.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import future
2
+ import builtins
3
+ import past
4
+ import six
5
+ import inspect
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ import numpy as np
12
+ import argparse
13
+ import decimal
14
+ import PIL
15
+ from torchvision import datasets, transforms
16
+ from datetime import datetime
17
+
18
+ from forbiddenfruit import curse
19
+ #from torch.autograd import Variable
20
+
21
+ from timeit import default_timer as timer
22
+
23
class Timer:
    """Context manager that measures wall-clock time for an activity.

    ``units`` divides the elapsed time (e.g. seconds per batch); on exit the
    average is printed (and echoed to file ``f``) unless shouldPrint is False.
    """

    def __init__(self, activity = None, units = 1, shouldPrint = True, f = None):
        self.activity = activity
        self.units = units
        self.shouldPrint = shouldPrint
        self.f = f

    def __enter__(self):
        self.start = timer()
        return self

    def getUnitTime(self):
        elapsed = self.end - self.start
        return elapsed / self.units

    def __str__(self):
        return "Avg time to " + self.activity + ": " + str(self.getUnitTime())

    def __exit__(self, *args):
        self.end = timer()
        if self.shouldPrint:
            printBoth(self, f = self.f)
42
+
43
def cudify(x):
    """Move ``x`` to the GPU (non-blocking) when CUDA is in use; no-op otherwise."""
    # BUG FIX: the original called x.cuda(async=True), which is a SyntaxError
    # on Python >= 3.7 where `async` became a reserved keyword; PyTorch
    # renamed the parameter to `non_blocking` in 0.4.
    if use_cuda:
        return x.cuda(non_blocking=True)
    return x
47
+
48
+ def pyval(a, **kargs):
49
+ return dten([a], **kargs)
50
+
51
+ def ifThenElse(cond, a, b):
52
+ cond = cond.to_dtype()
53
+ return cond * a + (1 - cond) * b
54
+
55
def ifThenElseL(cond, a, b):
    """Differentiable select: a where cond == 1, b where cond == 0."""
    keep = cond
    drop = 1 - cond
    return keep * a + drop * b
57
+
58
def product(it):
    """Product of the non-negative entries of ``it``; an int is returned as-is.

    Negative entries (e.g. a -1 wildcard dimension in a tensor shape) are
    skipped rather than multiplied in.
    """
    if isinstance(it, int):
        return it
    acc = 1
    for dim in it:
        if dim >= 0:
            acc *= dim
    return acc
66
+
67
+ def getEi(batches, num_elem):
68
+ return eye(num_elem).expand(batches, num_elem,num_elem).permute(1,0,2)
69
+
70
+ def one_hot(batch,d):
71
+ bs = batch.size()[0]
72
+ indexes = [ list(range(bs)), batch]
73
+ values = [ 1 for _ in range(bs) ]
74
+ return cudify(torch.sparse.FloatTensor(ltenCPU(indexes), ftenCPU(values), torch.Size([bs,d])))
75
+
76
+ def seye(n, m = None):
77
+ if m is None:
78
+ m = n
79
+ mn = n if n < m else m
80
+ indexes = [[ i for i in range(mn) ], [ i for i in range(mn) ] ]
81
+ values = [1 for i in range(mn) ]
82
+ return cudify(torch.sparse.ByteTensor(ltenCPU(indexes), dtenCPU(values), torch.Size([n,m])))
83
+
84
+ dtype = torch.float32
85
+ ftype = torch.float32
86
+ ltype = torch.int64
87
+ btype = torch.uint8
88
+
89
+ torch.set_default_dtype(dtype)
90
+
91
+ cpu = torch.device("cpu")
92
+
93
+ cuda_async = True
94
+
95
+ ftenCPU = lambda *args, **kargs: torch.tensor(*args, dtype=ftype, device=cpu, **kargs)
96
+ dtenCPU = lambda *args, **kargs: torch.tensor(*args, dtype=dtype, device=cpu, **kargs)
97
+ ltenCPU = lambda *args, **kargs: torch.tensor(*args, dtype=ltype, device=cpu, **kargs)
98
+ btenCPU = lambda *args, **kargs: torch.tensor(*args, dtype=btype, device=cpu, **kargs)
99
+
100
+ if torch.cuda.is_available() and not 'NOCUDA' in os.environ:
101
+ print("using cuda")
102
+ device = torch.device("cuda")
103
+ ften = lambda *args, **kargs: torch.tensor(*args, dtype=ftype, device=device, **kargs).cuda(non_blocking=cuda_async)
104
+ dten = lambda *args, **kargs: torch.tensor(*args, dtype=dtype, device=device, **kargs).cuda(non_blocking=cuda_async)
105
+ lten = lambda *args, **kargs: torch.tensor(*args, dtype=ltype, device=device, **kargs).cuda(non_blocking=cuda_async)
106
+ bten = lambda *args, **kargs: torch.tensor(*args, dtype=btype, device=device, **kargs).cuda(non_blocking=cuda_async)
107
+ ones = lambda *args, **cargs: torch.ones(*args, **cargs).cuda(non_blocking=cuda_async)
108
+ zeros = lambda *args, **cargs: torch.zeros(*args, **cargs).cuda(non_blocking=cuda_async)
109
+ eye = lambda *args, **cargs: torch.eye(*args, **cargs).cuda(non_blocking=cuda_async)
110
+ use_cuda = True
111
+ print("set up cuda")
112
+ else:
113
+ print("not using cuda")
114
+ ften = ftenCPU
115
+ dten = dtenCPU
116
+ lten = ltenCPU
117
+ bten = btenCPU
118
+ ones = torch.ones
119
+ zeros = torch.zeros
120
+ eye = torch.eye
121
+ use_cuda = False
122
+ device = cpu
123
+
124
def smoothmax(x, alpha, dim = 0):
    # Differentiable soft maximum: weights x by softmax(alpha * x) over `dim`.
    # NOTE(review): the softmax runs over `dim` but the reduction is over
    # `dim + 1` — confirm this offset matches the callers' tensor layout; a
    # plain smooth-max would sum over the same `dim`.
    return x.mul(F.softmax(x * alpha, dim)).sum(dim + 1)
126
+
127
+
128
def str2bool(v):
    """Parse a human-friendly boolean CLI flag; raise on anything else."""
    s = v.lower()
    if s in ('yes', 'true', 't', 'y', '1'):
        return True
    if s in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
135
+
136
+
137
+
138
def flat(lst):
    """Concatenate one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    out = []
    for sub in lst:
        out.extend(sub)
    return out
143
+
144
+
145
+ def printBoth(*st, f = None):
146
+ print(*st)
147
+ if not f is None:
148
+ print(*st, file=f)
149
+
150
+
151
def hasMethod(cl, mt):
    """True when ``cl`` exposes a callable attribute named ``mt``."""
    attr = getattr(cl, mt, None)
    return callable(attr)
153
+
154
def getMethodNames(Foo):
    """Names of Foo's callable attributes, excluding dunders."""
    names = []
    for attr in dir(Foo):
        if attr.startswith("__"):
            continue
        if callable(getattr(Foo, attr)):
            names.append(attr)
    return names
156
+
157
+ def getMethods(Foo):
158
+ return [getattr(Foo, m) for m in getMethodNames(Foo)]
159
+
160
+ max_c_for_norm = 10000
161
+
162
+ def numel(arr):
163
+ return product(arr.size())
164
+
165
def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for start in range(0, len(l), n):
        yield l[start:start + n]
169
+
170
+
171
def loadDataset(dataset, batch_size, train, transform = True):
    """Build a shuffled DataLoader for the named torchvision dataset.

    `train` selects the split; `transform` toggles per-dataset normalization.
    Training-time augmentation is applied for CIFAR and SVHN.
    """
    # Different torchvision datasets spell the train/test switch differently.
    oargs = {}
    if dataset in ["MNIST", "CIFAR10", "CIFAR100", "FashionMNIST", "PhotoTour"]:
        oargs['train'] = train
    elif dataset in ["STL10", "SVHN"] :
        oargs['split'] = 'train' if train else 'test'
    elif dataset in ["LSUN"]:
        oargs['classes'] = 'train' if train else 'test'
    elif dataset in ["Imagenet12"]:
        pass
    else:
        raise Exception(dataset + " is not yet supported")

    if dataset in ["MNIST"]:
        transformer = transforms.Compose([ transforms.ToTensor()]
                                         + ([transforms.Normalize((0.1307,), (0.3081,))] if transform else []))
    elif dataset in ["CIFAR10", "CIFAR100"]:
        # Random shift + horizontal flip augmentation at train time only.
        transformer = transforms.Compose(([ #transforms.RandomCrop(32, padding=4),
            transforms.RandomAffine(0, (0.125, 0.125), resample=PIL.Image.BICUBIC) ,
            transforms.RandomHorizontalFlip(),
            #transforms.RandomRotation(15, resample = PIL.Image.BILINEAR)
        ] if train else [])
            + [transforms.ToTensor()]
            + ([transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))] if transform else []))
    elif dataset in ["SVHN"]:
        transformer = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5), (0.2,0.2,0.2))])
    else:
        transformer = transforms.ToTensor()

    if dataset in ["Imagenet12"]:
        # https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset
        # NOTE(review): `normalize` is not defined anywhere in this module,
        # so this branch raises NameError if taken (and it also omits
        # ToTensor) — confirm before relying on Imagenet12 support.
        train_set = datasets.ImageFolder(
            '../data/Imagenet12/train' if train else '../data/Imagenet12/val',
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                normalize,
            ]))
    else:
        train_set = getattr(datasets, dataset)('../data', download=True, transform=transformer, **oargs)
    return torch.utils.data.DataLoader(
        train_set
        , batch_size=batch_size
        , shuffle=True,
        **({'num_workers': 1, 'pin_memory': True} if use_cuda else {}))
219
+
220
+
221
def variable(Pt):
    """Monkey-patch tensor class ``Pt`` (via forbiddenfruit.curse) so concrete
    tensors answer the same interface as the abstract domains: lb/ub/center
    return the tensor itself, diameter is 0, loss is plain cross-entropy,
    plus thin wrappers over torch.nn.functional ops.
    """
    class Point:
        def isSafe(self,target):
            pred = self.max(1, keepdim=True)[1] # get the index of the max log-probability
            return pred.eq(target.data.view_as(pred))

        def isPoint(self):
            return True

        def labels(self):
            return [self[0].max(1)[1]] # get the index of the max log-probability

        def softplus(self):
            return F.softplus(self)

        def elu(self):
            return F.elu(self)

        def selu(self):
            return F.selu(self)

        def sigm(self):
            return F.sigmoid(self)

        def conv3d(self, *args, **kargs):
            return F.conv3d(self, *args, **kargs)
        def conv2d(self, *args, **kargs):
            return F.conv2d(self, *args, **kargs)
        def conv1d(self, *args, **kargs):
            return F.conv1d(self, *args, **kargs)

        def conv_transpose3d(self, *args, **kargs):
            return F.conv_transpose3d(self, *args, **kargs)
        def conv_transpose2d(self, *args, **kargs):
            return F.conv_transpose2d(self, *args, **kargs)
        def conv_transpose1d(self, *args, **kargs):
            return F.conv_transpose1d(self, *args, **kargs)

        def max_pool2d(self, *args, **kargs):
            return F.max_pool2d(self, *args, **kargs)

        def avg_pool2d(self, *args, **kargs):
            return F.avg_pool2d(self, *args, **kargs)

        def adaptive_avg_pool2d(self, *args, **kargs):
            return F.adaptive_avg_pool2d(self, *args, **kargs)

        def cat(self, other, dim = 0, **kargs):
            return torch.cat((self, other), dim = dim, **kargs)

        def addPar(self, a, b):
            return a + b

        def abstractApplyLeaf(self, foo, *args, **kargs):
            # Abstract-domain tree operations are no-ops on concrete points.
            return self

        def diameter(self):
            # A point has zero width.
            return pyval(0)

        def to_dtype(self):
            return self.type(dtype=dtype, non_blocking=cuda_async)

        def loss(self, target, **kargs):
            # torch < 1.0 spelled the no-reduction option differently.
            if torch.__version__[0] == "0":
                return F.cross_entropy(self, target, reduce = False)
            else:
                return F.cross_entropy(self, target, reduction='none')

        def deep_loss(self, *args, **kargs):
            return 0

        def merge(self, *args, **kargs):
            return self

        def splitRelu(self, *args, **kargs):
            return self

        def lb(self):
            return self
        def vanillaTensorPart(self):
            return self
        def center(self):
            return self
        def ub(self):
            return self

        def cudify(self, cuda_async = True):
            return self.cuda(non_blocking=cuda_async) if use_cuda else self

    def log_softmax(self, *args, dim = 1, **kargs):
        return F.log_softmax(self, *args,dim = dim, **kargs)

    # Old 0.x torch (except 0.4.1) lacks Tensor.log_softmax; expose it so the
    # curse loop below copies it onto Pt.
    # BUG FIX: this def and its version check appeared twice verbatim in the
    # original; the exact duplicate has been removed (no behavioral change).
    if torch.__version__[0] == "0" and torch.__version__ != "0.4.1":
        Point.log_softmax = log_softmax

    for nm in getMethodNames(Point):
        curse(Pt, nm, getattr(Point, nm))
326
+
327
+ variable(torch.autograd.Variable)
328
+ variable(torch.cuda.DoubleTensor)
329
+ variable(torch.DoubleTensor)
330
+ variable(torch.cuda.FloatTensor)
331
+ variable(torch.FloatTensor)
332
+ variable(torch.ByteTensor)
333
+ variable(torch.Tensor)
334
+
335
+
336
def default(dic, nm, d):
    """dic[nm] when dic is present and contains nm; otherwise the fallback d."""
    if dic is None or nm not in dic:
        return d
    return dic[nm]
340
+
341
+
342
+
343
+
344
def softmaxBatchNP(x, epsilon, subtract = False):
    """Compute softmax values for each sets of scores in x."""
    # epsilon acts as a temperature; None means a plain (temperature-1) softmax.
    x = x.astype(np.float64)
    ex = x / epsilon if epsilon is not None else x
    if subtract:
        # Standard max-subtraction trick for numerical stability.
        ex -= ex.max(axis=1)[:,np.newaxis]
    e_x = np.exp(ex)
    sm = (e_x / e_x.sum(axis=1)[:,np.newaxis])
    am = np.argmax(x, axis=1)
    # Rows whose exp over/underflowed to a non-finite sum get a fallback below.
    bads = np.logical_not(np.isfinite(sm.sum(axis = 1)))

    if epsilon is None:
        # Fallback: one-hot on the argmax.
        sm[bads] = 0
        sm[bads, am[bads]] = 1
    else:
        # Fallback: epsilon-smoothed one-hot. Note this rebinds the *local*
        # epsilon only; the caller's value is unaffected.
        epsilon *= (x.shape[1] - 1) / x.shape[1]
        sm[bads] = epsilon / (x.shape[1] - 1)
        sm[bads, am[bads]] = 1 - epsilon

    # Renormalize so every row sums to exactly 1.
    sm /= sm.sum(axis=1)[:,np.newaxis]
    return sm
365
+
366
+
367
+ def cadd(a,b):
368
+ both = a.cat(b)
369
+ a, b = both.split(a.size()[0])
370
+ return a + b
371
+
372
def msum(a, b, l):
    """Merge a and b with l, where None acts as the identity/absent value."""
    if a is None or b is None:
        return b if a is None else a
    return l(a, b)
378
+
379
+ class SubAct(argparse.Action):
380
+ def __init__(self, sub_choices, *args, **kargs):
381
+ super(SubAct,self).__init__(*args, nargs='+', **kargs)
382
+ self.sub_choices = sub_choices
383
+ self.sub_choices_names = None if sub_choices is None else getMethodNames(sub_choices)
384
+
385
+ def __call__(self, parser, namespace, values, option_string=None):
386
+ if self.sub_choices_names is not None and not values[0] in self.sub_choices_names:
387
+ msg = 'invalid choice: %r (choose from %s)' % (values[0], self.sub_choices_names)
388
+ raise argparse.ArgumentError(self, msg)
389
+
390
+ prev = getattr(namespace, self.dest)
391
+ setattr(namespace, self.dest, prev + [values])
392
+
393
def catLists(val):
    """Recursively flatten arbitrarily nested lists into one flat list;
    a non-list value becomes a singleton list."""
    if not isinstance(val, list):
        return [val]
    flattened = []
    for item in val:
        flattened += catLists(item)
    return flattened
400
+
401
def sumStr(val):
    """Concatenate an iterable of strings.

    Replaces the original repeated ``s += v`` loop with str.join, which is
    linear instead of worst-case quadratic; output is identical.
    """
    return "".join(val)
406
+
407
def catStrs(val):
    """Render [head, a1, a2, ...] as "head(a1, a2, ...)"; a lone head as itself."""
    head = val[0]
    args = val[1:]
    if not args:
        return head
    return head + "(" + ", ".join(args) + ")"
418
+
419
+ def printNumpy(x):
420
+ return "[" + sumStr([decimal.Decimal(float(v)).__format__("f") + ", " for v in x.data.cpu().numpy()])[:-2]+"]"
421
+
422
+ def printStrList(x):
423
+ return "[" + sumStr(v + ", " for v in x)[:-2]+"]"
424
+
425
+ def printListsNumpy(val):
426
+ if isinstance(val, list):
427
+ return printStrList(printListsNumpy(v) for v in val)
428
+ return printNumpy(val)
429
+
430
+ def parseValues(values, methods, *others):
431
+ if len(values) == 1 and values[0]:
432
+ x = eval(values[0], dict(pair for l in ([methods] + list(others)) for pair in l.__dict__.items()) )
433
+
434
+ return x() if inspect.isclass(x) else x
435
+ args = []
436
+ kargs = {}
437
+ for arg in values[1:]:
438
+ if '=' in arg:
439
+ k = arg.split('=')[0]
440
+ v = arg[len(k)+1:]
441
+ try:
442
+ kargs[k] = eval(v)
443
+ except:
444
+ kargs[k] = v
445
+ else:
446
+ args += [eval(arg)]
447
+ return getattr(methods, values[0])(*args, **kargs)
448
+
449
+ def preDomRes(outDom, target): # TODO: make faster again by keeping sparse tensors sparse
450
+ t = one_hot(target.long(), outDom.size()[1]).to_dense().to_dtype()
451
+ tmat = t.unsqueeze(2).matmul(t.unsqueeze(1))
452
+
453
+ tl = t.unsqueeze(2).expand(-1, -1, tmat.size()[1])
454
+
455
+ inv_t = eye(tmat.size()[1]).expand(tmat.size()[0], -1, -1)
456
+ inv_t = inv_t - tmat
457
+
458
+ tl = tl.bmm(inv_t)
459
+
460
+ fst = outDom.unsqueeze(1).matmul(tl).squeeze(1)
461
+ snd = outDom.unsqueeze(1).matmul(inv_t).squeeze(1)
462
+
463
+ return (fst - snd) + t
464
+
465
def mopen(shouldnt, *args, **kargs):
    """open(*args, **kargs), unless ``shouldnt`` — then a no-op context manager."""
    if shouldnt:
        # Local import keeps contextlib out of the module namespace.
        import contextlib
        return contextlib.suppress()
    return open(*args, **kargs)
470
+
471
def file_timestamp():
    """Current timestamp with ':' and ' ' stripped, safe for filenames."""
    now = str(datetime.now())
    return now.replace(":", "").replace(" ", "")
473
+
474
# Single-pass translation table replacing five chained .replace() calls:
# space/parens/equals -> '_', commas dropped. Output is identical.
_DOMAIN_NAME_TRANS = str.maketrans({" ": "_", ",": None, "(": "_", ")": "_", "=": "_"})

def prepareDomainNameForFile(s):
    """Sanitize a domain spec string for use as part of a filename."""
    return s.translate(_DOMAIN_NAME_TRANS)
476
+
477
+ # delimited only
478
# delimited only
def callCC(foo):
    """Delimited call/cc: ``foo`` receives an escape continuation ``cc``;
    calling cc(v) aborts foo and makes callCC return v, otherwise callCC
    returns foo's normal result."""
    class _Escape(BaseException):
        def __init__(self, v):
            self.v = v

    def cc(value):
        raise _Escape(value)

    try:
        return foo(cc)
    except _Escape as escape:
        return escape.v
losses.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source file is part of DiffAI
2
+ # Copyright (c) 2018 Secure, Reliable, and Intelligent Systems Lab (SRI), ETH Zurich
3
+ # This software is distributed under the MIT License: https://opensource.org/licenses/MIT
4
+ # SPDX-License-Identifier: MIT
5
+ # For more information see https://github.com/eth-sri/diffai
6
+
7
+ # THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT ANY WARRANTY OF ANY KIND, EITHER
8
+ # EXPRESS, IMPLIED OR STATUTORY, INCLUDING BUT NOT LIMITED TO ANY WARRANTY
9
+ # THAT THE SOFTWARE WILL CONFORM TO SPECIFICATIONS OR BE ERROR-FREE AND ANY
10
+ # IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,
11
+ # TITLE, OR NON-INFRINGEMENT. IN NO EVENT SHALL ETH ZURICH BE LIABLE FOR ANY
12
+ # DAMAGES, INCLUDING BUT NOT LIMITED TO DIRECT, INDIRECT,
13
+ # SPECIAL OR CONSEQUENTIAL DAMAGES, ARISING OUT OF, RESULTING FROM, OR IN
14
+ # ANY WAY CONNECTED WITH THIS SOFTWARE (WHETHER OR NOT BASED UPON WARRANTY,
15
+ # CONTRACT, TORT OR OTHERWISE).
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.optim as optim
21
+
22
+ import helpers as h
23
+ import domains
24
+ from domains import *
25
+ import math
26
+
27
+
28
+ POINT_DOMAINS = [m for m in h.getMethods(domains) if h.hasMethod(m, "attack")] + [ torch.FloatTensor, torch.Tensor, torch.cuda.FloatTensor ]
29
+ SYMETRIC_DOMAINS = [domains.Box] + POINT_DOMAINS
30
+
31
def domRes(outDom, target, **args): # TODO: make faster again by keeping sparse tensors sparse
    """Per-class robustness margins of abstract output `outDom` w.r.t. `target`.

    Returns lb(target logit - other logits) per class, with the target's own
    column forced to 1 by adding `t`; all entries positive means provably safe.
    """
    # t: one-hot target, made dense; assumed shape (batch, classes).
    t = h.one_hot(target.data.long(), outDom.size()[1]).to_dense()
    # tmat: outer product t t^T, selecting the target row/column.
    tmat = t.unsqueeze(2).matmul(t.unsqueeze(1))

    tl = t.unsqueeze(2).expand(-1, -1, tmat.size()[1])

    # inv_t: identity minus target outer product (zeroes the target column).
    inv_t = h.eye(tmat.size()[1]).expand(tmat.size()[0], -1, -1)
    inv_t = inv_t - tmat

    tl = tl.bmm(inv_t)

    # fst - snd encodes (target logit - other logit) inside the domain,
    # so the subtraction happens before concretization (tighter bound).
    fst = outDom.bmm(tl)
    snd = outDom.bmm(inv_t)
    diff = fst - snd
    # lb(): sound lower bound provided by the abstract-domain object.
    return diff.lb() + t
46
+
47
def isSafeDom(outDom, target, **args):
    """Return 1 when the abstract output provably classifies `target`
    (all margins positive), else 0."""
    margins = domRes(outDom, target, **args)
    worst, _ = margins.min(1)
    return worst.gt(0.0).long().item()
50
+
51
+
52
def isSafeBox(target, net, inp, eps, dom):
    # Check robustness of `net` at `inp` within radius `eps` under `dom`.
    # `target` is one-hot-like; argmax(1)[0] extracts the label of the
    # first (assumed only) example in the batch.
    atarg = target.argmax(1)[0].unsqueeze(0)
    if hasattr(dom, "attack"):
        # Point domain: run the empirical attack and compare the prediction.
        x = dom.attack(net, eps, inp, target)
        pred = net(x).argmax(1)[0].unsqueeze(0) # get the index of the max log-probability
        return pred.item() == atarg.item()
    else:
        # Abstract domain: propagate the eps-box and verify soundly.
        outDom = net(dom.box(inp, eps))
        return isSafeDom(outDom, atarg)
media/overview.png ADDED
media/resnetTinyFewCombo.png ADDED
models.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from . import components as n
3
+ from . import ai
4
+ from . import scheduling as S
5
+ except:
6
+ import components as n
7
+ import scheduling as S
8
+ import ai
9
+
10
+ ############# Previously Known Models. Not guaranteed to have the same performance as previous papers.
11
+
12
def FFNN(c, **kargs):
    # Fully-connected network: five hidden layers of 100 units, `c` outputs.
    return n.FFNN([100, 100, 100, 100, 100,c], last_lin = True, last_zono = True, **kargs)
14
+
15
def ConvSmall(c, **kargs):
    # Small LeNet-style conv net: two conv layers then a 100-unit FC head.
    return n.LeNet([ (16,4,4,2), (32,4,4,2) ], [100,c], last_lin = True, last_zono = True, **kargs)
17
+
18
def ConvMed(c, **kargs):
    # Same shape as ConvSmall but with padding=1 on the conv layers.
    return n.LeNet([ (16,4,4,2), (32,4,4,2) ], [100,c], padding = 1, last_lin = True, last_zono = True, **kargs)
20
+
21
def ConvBig(c, **kargs):
    # Larger conv net: four conv layers and two 512-unit FC layers.
    return n.LeNet([ (32,3,3,1), (32,4,4,2) , (64,3,3,1), (64,4,4,2)], [512, 512,c], padding = 1, last_lin = True, last_zono = True, **kargs)
23
+
24
def ConvLargeIBP(c, **kargs):
    # Large conv net with IBP-style initialization and biases enabled.
    return n.LeNet([ (64, 3, 3, 1), (64,3,3,1), (128,3,3,2), (128,3,3,1), (128,3,3,1)], [200,c], padding=1, ibp_init = True, bias = True, last_lin = True, last_zono = True, **kargs)
26
+
27
def ResNetWong(c, **kargs):
    # Residual net in the style of Wong et al.: conv stem + 4 wide blocks
    # (last two downsampling) + 1000-unit FC head.
    return n.Seq(n.Conv(16, 3, padding=1, bias=False), n.WideBlock(16), n.WideBlock(16), n.WideBlock(32, True), n.WideBlock(64, True), n.FFNN([1000, c], ibp_init = True, bias=True, last_lin=True, last_zono = True, **kargs))
29
+
30
def TruncatedVGG(c, **kargs):
    # VGG-like prefix (4 conv layers) with a single 512-unit FC layer.
    return n.LeNet([ (64, 3, 3, 1), (64,3,3,1), (128,3,3,2), (128,3,3,1)], [512,c], padding=1, ibp_init = True, bias = True, last_lin = True, last_zono = True, **kargs)
32
+
33
+
34
+ ############# New Models
35
+
36
def ResNetTiny(c, **kargs): # resnetWide also used by mixtrain and scaling provable adversarial defenses
    """Tiny wide-residual net: conv stem, five WideBlocks (16 then 32
    channels, no batch norm, IBP init), and a 500-unit FC head."""
    def residual(chans, bias=True, **kw):
        # WideBlock without downsampling or batch norm, IBP-friendly init.
        return n.WideBlock(chans, False, bias=bias, ibp_init=True, batch_norm=False, **kw)
    layers = [n.Conv(16, 3, padding=1, bias=True, ibp_init=True)]
    layers += [residual(w) for w in (16, 32, 32, 32, 32)]
    layers.append(n.FFNN([500, c], bias=True, last_lin=True, ibp_init=True, last_zono=True, **kargs))
    return n.Seq(*layers)
46
+
47
def ResNetTiny_FewCombo(c, **kargs): # resnetWide also used by mixtrain and scaling provable adversarial defenses
    """ResNetTiny augmented with a few correlation-tracking layers:
    CorrMaxK / DecorrMin manage zonotope error terms, and Concretize()
    mid-network caps the abstraction cost."""
    def wb(c, bias = True, **kargs):
        # Wide residual block, no batch norm, IBP-friendly init.
        return n.WideBlock(c, False, bias=bias, ibp_init=True, batch_norm = False, **kargs)
    cmk = n.CorrMaxK
    dec = lambda x: n.DecorrMin(x, num_to_keep = True)
    # Removed unused aliases (DeepLoss, CorrMaxPool2D, CorrMaxPool3D) that
    # were assigned but never referenced.
    return n.Seq(cmk(32),
                 n.Conv(16, 3, padding=1, bias=True, ibp_init = True), dec(8),
                 wb(16), dec(4),
                 wb(32), n.Concretize(),
                 wb(32),
                 wb(32),
                 wb(32), cmk(10),
                 n.FFNN([500, c], bias=True, last_lin=True, ibp_init = True, last_zono = True, **kargs))
63
+
64
+
65
def ResNetTiny_ManyFixed(c, **kargs): # resnetWide also used by mixtrain and scaling provable adversarial defenses
    """ResNetTiny variant interleaving many fixed correlation layers
    (CorrFix + DecorrMin) before concretizing mid-network."""
    def residual(chans, bias=True, **kw):
        # Wide residual block, no batch norm, IBP-friendly init.
        return n.WideBlock(chans, False, bias=bias, ibp_init=True, batch_norm=False, **kw)
    fix = n.CorrFix
    decorr = lambda k: n.DecorrMin(k, num_to_keep=True)
    return n.Seq(n.CorrMaxK(32),
                 n.Conv(16, 3, padding=1, bias=True, ibp_init=True), fix(16), decorr(16),
                 residual(16), fix(8), decorr(8),
                 residual(32), fix(8), decorr(8),
                 residual(32), fix(4), decorr(4),
                 residual(32), n.Concretize(),
                 residual(32),
                 n.FFNN([500, c], bias=True, last_lin=True, ibp_init=True, last_zono=True, **kargs))
78
+
79
def SkipNet18(c, **kargs):
    # ResNet-18 layout ([2,2,2,2] blocks) built as a skip net, with a
    # two-layer 512-unit FC head.
    return n.Seq(n.ResNet([2,2,2,2], bias = True, ibp_init = True, skip_net = True), n.FFNN([512, 512, c], bias=True, last_lin=True, last_zono = True, ibp_init = True, **kargs))
81
+
82
def SkipNet18_Combo(c, **kargs):
    # SkipNet18 with correlation layers injected into specific ResNet stages
    # via `extra`: (layer, stage-index) pairs. Stage 3 also gets a DeepLoss
    # whose weight follows the Until/Lin schedule below.
    dl = n.DeepLoss
    cmk = n.CorrFix
    dec = lambda x: n.DecorrMin(x, num_to_keep = True)
    return n.Seq(n.ResNet([2,2,2,2], extra = [ (cmk(20),2),(dec(10),2)
                                              ,(cmk(10),3),(dec(5),3),(dl(S.Until(90, S.Lin(0, 0.2, 50, 40), 0)), 3)
                                              ,(cmk(5),4),(dec(2),4)], bias = True, ibp_init=True, skip_net = True), n.FFNN([512, 512, c], bias=True, last_lin=True, last_zono = True, ibp_init=True, **kargs))
89
+
90
def ResNet18(c, **kargs):
    # Standard ResNet-18 layout ([2,2,2,2] blocks) with a two-layer FC head.
    return n.Seq(n.ResNet([2,2,2,2], bias = True, ibp_init = True), n.FFNN([512, 512, c], bias=True, last_lin=True, last_zono = True, ibp_init = True, **kargs))
92
+
93
+
94
def ResNetLarge_LargeCombo(c, **kargs): # resnetWide also used by mixtrain and scaling provable adversarial defenses
    """Large wide-residual net with an extensive combo of correlation layers
    (CorrMaxK/DecorrMin) and two scheduled DeepLoss taps.

    NOTE(review): unlike the sibling models, the FFNN head here does not pass
    last_zono=True — possibly intentional, left unchanged; confirm.
    """
    def wb(c, bias = True, **kargs):
        # Wide residual block, no batch norm, IBP-friendly init.
        return n.WideBlock(c, False, bias=bias, ibp_init=True, batch_norm = False, **kargs)
    dl = n.DeepLoss
    cmk = n.CorrMaxK
    dec = lambda x: n.DecorrMin(x, num_to_keep = True)
    # Removed unused aliases (CorrMaxPool2D, CorrMaxPool3D) that were
    # assigned but never referenced.
    return n.Seq(n.Conv(16, 3, padding=1, bias=True, ibp_init = True), cmk(4),
                 wb(16), cmk(4), dec(4),
                 wb(32), cmk(4), dec(4),
                 wb(32), dl(S.Until(1, 0, S.Lin(0.5, 0, 50, 3))),
                 wb(32), cmk(4), dec(4),
                 wb(64), cmk(4), dec(2),
                 wb(64), dl(S.Until(24, S.Lin(0, 0.1, 20, 4), S.Lin(0.1, 0, 50))),
                 wb(64),
                 n.FFNN([1000, c], bias=True, last_lin=True, ibp_init = True, **kargs))
111
+
112
+
113
+
114
def ResNet34(c, **kargs):
    # ResNet-34 layout ([3,4,6,3] blocks) with a two-layer FC head.
    return n.Seq(n.ResNet([3,4,6,3], bias = True, ibp_init = True), n.FFNN([512, 512, c], bias=True, last_lin=True, last_zono = True, ibp_init = True, **kargs))
116
+
117
+
118
def DenseNet100(c, **kwargs):
    # DenseNet-BC, depth 100, growth rate 12, 50% transition reduction.
    # NOTE(review): **kwargs is accepted but NOT forwarded to n.DenseNet,
    # unlike the other model constructors — confirm whether intentional.
    return n.DenseNet(growthRate=12, depth=100, reduction=0.5,
                      bottleneck=True, num_classes = c)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy
2
+ six
3
+ future
4
+ forbiddenfruit
5
+ torch==0.4.1
6
+ torchvision==0.2.1
scheduling.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ try:
6
+ from . import helpers as h
7
+ except:
8
+ import helpers as h
9
+
10
+
11
+
12
class Const():
    """A constant schedule value; the base class for all schedules.

    `c` is stored as a float, or None meaning "unset" — in which case
    getVal falls back to the caller-supplied default.
    """
    def __init__(self, c):
        self.c = None if c is None else float(c)

    def getVal(self, c = None, **kargs):
        """Return the stored constant, or the fallback `c` when unset."""
        if self.c is None:
            return c
        return self.c

    def __str__(self):
        return str(self.c)

    def initConst(x):
        """Coerce x into a Const (identity if already one).

        Intentionally takes no `self`: always invoked as Const.initConst(x).
        """
        return x if isinstance(x, Const) else Const(x)
24
+
25
class Lin(Const):
    """Linear ramp schedule: moves from `start` to `end` over `steps` time
    units, beginning at time `initial`; clamped flat outside the ramp.
    When `quant` is true, time is floored to a whole number first.
    """
    def __init__(self, start, end, steps, initial = 0, quant = False):
        self.start = float(start)
        self.end = float(end)
        self.steps = float(steps)
        self.initial = float(initial)
        self.quant = quant

    def getVal(self, time = 0, **kargs):
        """Value of the ramp at `time` (extra kwargs accepted and ignored)."""
        if self.quant:
            time = math.floor(time)
        frac = max(0, min(1, float(time - self.initial) / self.steps))
        return (self.end - self.start) * frac + self.start

    def __str__(self):
        # BUG FIX: previously used '%s' placeholders with str.format(), which
        # returned the template text verbatim; use %-interpolation instead.
        return "Lin(%s,%s,%s,%s, quant=%s)" % (str(self.start), str(self.end), str(self.steps), str(self.initial), str(self.quant))
40
+
41
class Until(Const):
    """Piecewise schedule: follow `a` while time < thresh, then follow `b`
    with the clock restarted at zero."""
    def __init__(self, thresh, a, b):
        self.a = Const.initConst(a)
        self.b = Const.initConst(b)
        self.thresh = thresh

    def getVal(self, *args, time = 0, **kargs):
        if time < self.thresh:
            return self.a.getVal(*args, time = time, **kargs)
        # Restart b's clock relative to the switch point.
        return self.b.getVal(*args, time = time - self.thresh, **kargs)

    def __str__(self):
        return "Until(%s, %s, %s)" % (str(self.thresh), str(self.a), str(self.b))
52
+
53
class Scale(Const): # use with mix when aw = 1, and 0 <= c < 1
    """Odds transform of a sub-schedule: maps c in [0, 1) to c / (1 - c)."""
    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        ratio = self.c.getVal(*args, **kargs)
        if ratio == 0:
            return 0
        # Valid only on [0, 1); diverges as ratio -> 1.
        assert ratio >= 0
        assert ratio < 1
        return ratio / (1 - ratio)

    def __str__(self):
        return "Scale(%s)" % str(self.c)
67
+
68
def MixLin(*args, **kargs):
    # Convenience: a Lin ramp fed through Scale (c -> c/(1-c)), for mixing
    # weights where aw = 1 (see Scale's comment above its class).
    return Scale(Lin(*args, **kargs))
70
+
71
class Normal(Const):
    """Stochastic schedule: |N(0,1)| * c per call; non-deterministic by design."""
    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, shape = [1], **kargs):
        # NOTE: mutable default `shape=[1]` is never mutated here, so it is safe.
        c = self.c.getVal(*args, shape = shape, **kargs)
        # h.device presumably selects CPU/CUDA — set elsewhere in helpers.
        return torch.randn(shape, device = h.device).abs() * c

    def __str__(self):
        return "Normal(%s)" % str(self.c)
81
+
82
class Clip(Const):
    """Clamp schedule `c` into [l, u]; all three may themselves be schedules."""
    def __init__(self, c, l, u):
        self.c = Const.initConst(c)
        self.l = Const.initConst(l)
        self.u = Const.initConst(u)

    def getVal(self, *args, **kargs):
        val = self.c.getVal(*args, **kargs)
        lo = self.l.getVal(*args, **kargs)
        hi = self.u.getVal(*args, **kargs)
        if isinstance(val, float):
            return min(max(val, lo), hi)
        # Non-float (e.g. tensor): use its clamp method.
        return val.clamp(lo, hi)

    def __str__(self):
        return "Clip(%s, %s, %s)" % (str(self.c), str(self.l), str(self.u))
99
+
100
class Fun(Const):
    """Schedule defined by an arbitrary callable; getVal delegates to it."""
    def __init__(self, foo):
        # foo: callable receiving getVal's (*args, **kargs), e.g. time=...
        self.foo = foo
    def getVal(self, *args, **kargs):
        return self.foo(*args, **kargs)

    def __str__(self):
        return "Fun(...)"
108
+
109
class Complement(Const): # use with mix when aw = 1, and 0 <= c < 1
    """Schedule returning 1 - c, for a sub-schedule c constrained to [0, 1]."""
    def __init__(self, c):
        self.c = Const.initConst(c)

    def getVal(self, *args, **kargs):
        value = self.c.getVal(*args, **kargs)
        assert value >= 0
        assert value <= 1
        return 1 - value

    def __str__(self):
        return "Complement(%s)" % str(self.c)