File size: 5,199 Bytes
746c674 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import future
import builtins
import past
import six
from timeit import default_timer as timer
from datetime import datetime
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import Dataset
import inspect
from inspect import getargspec
import os
import helpers as h
from helpers import Timer
import copy
import random
from itertools import count
from components import *
import models
import goals
from goals import *
import math
from torch.serialization import SourceChangeWarning
import warnings
parser = argparse.ArgumentParser(description='Convert a pickled PyTorch DiffAI net to an abstract onyx net which returns the interval concretization around the final logits. The first dimension of the output is the natural center, the second dimension is the lb, the third is the ub', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-n', '--net', type=str, default=None, metavar='N', help='Saved and pickled net to use, in pynet format', required=True)
parser.add_argument('-d', '--domain', type=str, default="Point()", help='picks which abstract goals to use for testing. Uses box. Doesn\'t use time, so don\'t use Lin. Unless point, should specify a width w.')
parser.add_argument('-b', '--batch-size', type=int, default=1, help='The batch size to export. Not sure this matters.')
parser.add_argument('-o', '--out', type=str, default="convert_out/", metavar='F', help='Where to save the net.')
parser.add_argument('--update-net', type=h.str2bool, nargs='?', const=True, default=False, help="should update test net")
parser.add_argument('--net-name', type=str, choices = h.getMethodNames(models), default=None, help="update test net name")
parser.add_argument('--save-name', type=str, default=None, help="name to save the net with. Defaults to <domain>___<netfile-.pynet>.onyx")
parser.add_argument('-D', '--dataset', choices = [n for (n,k) in inspect.getmembers(datasets, inspect.isclass) if issubclass(k, Dataset)]
, default="MNIST", help='picks which dataset to use.')
parser.add_argument('--map-to-cpu', type=h.str2bool, nargs='?', const=True, default=False, help="map cuda operations in save back to cpu; enables to run on a computer without a GPU")
parser.add_argument('--tf-input', type=h.str2bool, nargs='?', const=True, default=False, help="change the shape of the input data from batch-channels-height-width (standard in pytroch) to batch-height-width-channels (standard in tf)")
args = parser.parse_args()
out_dir = args.out
if not os.path.exists(out_dir):
os.makedirs(out_dir)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always", SourceChangeWarning)
if args.map_to_cpu:
net = torch.load(args.net, map_location='cpu')
else:
net = torch.load(args.net)
net_name = None
if args.net_name is not None:
net_name = args.net_name
elif args.update_net and 'name' in dir(net):
net_name = net.name
def buildNet(n, input_dims, num_classes):
n = n(num_classes)
if args.dataset in ["MNIST"]:
n = Seq(Normalize([0.1307], [0.3081] ), n)
elif args.dataset in ["CIFAR10", "CIFAR100"]:
n = Seq(Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), n)
elif dataset in ["SVHN"]:
n = Seq(Normalize([0.5,0.5,0.5], [0.2, 0.2, 0.2]), n)
elif dataset in ["Imagenet12"]:
n = Seq(Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225]), n)
n = n.infer(input_dims)
n.clip_norm()
return n
if net_name is not None:
n = getattr(models,net_name)
n = buildNet(n, net.inShape, net.outShape)
n.load_state_dict(net.state_dict())
net = n
net = net.to(h.device)
net.remove_norm()
domain = eval(args.domain)
if args.save_name is None:
save_name = h.prepareDomainNameForFile(args.domain) + "___" + os.path.basename(args.net)[:-6] + ".onyx"
else:
save_name = args.save_name
def abstractNet(inpt):
if args.tf_input:
inpt = inpt.permute(0, 3, 1, 2)
dom = domain.box(inpt, w = None)
o = net(dom, onyx=True).unsqueeze(1)
out = torch.cat([o.vanillaTensorPart(), o.lb().vanillaTensorPart(), o.ub().vanillaTensorPart()], dim=1)
return out
input_shape = [args.batch_size] + list(net.inShape)
if args.tf_input:
input_shape = [args.batch_size] + list(net.inShape)[1:] + [net.inShape[0]]
dummy = h.zeros(input_shape)
abstractNet(dummy)
class AbstractNet(nn.Module):
def __init__(self, domain, net, abstractNet):
super(AbstractNet, self).__init__()
self.net = net
self.abstractNet = abstractNet
if hasattr(domain, "net") and domain.net is not None:
self.netDom = domain.net
def forward(self, inpt):
return self.abstractNet(inpt)
absNet = AbstractNet(domain, net, abstractNet)
out_path = os.path.join(out_dir, save_name)
print("Saving:", out_path)
param_list = ["param"+str(i) for i in range(len(list(absNet.parameters())))]
torch.onnx.export(absNet, dummy, out_path, verbose=False, input_names=["actual_input"] + param_list, output_names=["output"])
|