Spaces:
Runtime error
Runtime error
+ porting in msma files
Browse files+ adding flow model utils
- app.py +1 -1
- dataset.py +269 -0
- flowutils.py +263 -0
- scorer.py → msma.py +203 -53
app.py
CHANGED
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
|
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
|
9 |
-
from
|
10 |
|
11 |
|
12 |
@cache
|
|
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
|
9 |
+
from msma import build_model, config_presets
|
10 |
|
11 |
|
12 |
@cache
|
dataset.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under a Creative Commons
|
4 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
5 |
+
# You should have received a copy of the license along with this
|
6 |
+
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
7 |
+
|
8 |
+
"""Streaming images and labels from datasets created with dataset_tool.py."""
|
9 |
+
|
10 |
+
import json
|
11 |
+
import os
|
12 |
+
import zipfile
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import PIL.Image
|
16 |
+
import torch
|
17 |
+
|
18 |
+
import dnnlib
|
19 |
+
|
20 |
+
try:
|
21 |
+
import pyspng
|
22 |
+
except ImportError:
|
23 |
+
pyspng = None
|
24 |
+
|
25 |
+
# ----------------------------------------------------------------------------
|
26 |
+
# Abstract base class for datasets.
|
27 |
+
|
28 |
+
|
29 |
+
class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets with optional labels.

    Subclasses provide the data by overriding ``_load_raw_image`` and
    ``_load_raw_labels`` (and ``close`` when they hold resources). Indexing
    returns an ``(image, label)`` pair of NumPy arrays; integer class labels
    are returned one-hot encoded as float32.
    """

    def __init__(
        self,
        name,  # Name of the dataset.
        raw_shape,  # Shape of the raw image data (NCHW).
        use_labels=True,  # Enable conditioning labels? False = label dimension is zero.
        max_size=None,  # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        xflip=False,  # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed=0,  # Random seed to use when applying max_size.
        cache=False,  # Cache images in CPU memory?
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._cache = cache
        self._cached_images = dict()  # {raw_idx: np.ndarray, ...}
        self._raw_labels = None  # Filled lazily by _get_raw_labels().
        self._label_shape = None  # Filled lazily by the label_shape property.

        # Apply max_size: keep a seeded random subset, restored to sorted order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index refers to mirrored copies.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Return the raw label array, loading and validating it on first use.

        When labels are disabled or unavailable, a float32 array of shape
        ``[N, 0]`` is used instead.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self):  # to be overridden by subclass
        """Release any resources held by the dataset (no-op in the base class)."""
        pass

    def _load_raw_image(self, raw_idx):  # to be overridden by subclass
        """Return the image at ``raw_idx`` as an ``np.ndarray``."""
        raise NotImplementedError

    def _load_raw_labels(self):  # to be overridden by subclass
        """Return all raw labels as an ``np.ndarray``, or ``None`` if there are none."""
        raise NotImplementedError

    def __getstate__(self):
        # Drop the label array from the pickled state; it is reloaded lazily.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: errors from close() are deliberately ignored,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image is always a fresh copy."""
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self._raw_shape[1:]
        if self._xflip[idx]:
            assert image.ndim == 3  # CHW
            image = image[:, :, ::-1]
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label for ``idx``; int64 class labels become a one-hot float32 vector."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return an ``EasyDict`` with ``raw_idx``, ``xflip`` and ``raw_label`` for ``idx``."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = int(self._xflip[idx]) != 0
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):  # [CHW]
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3  # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3  # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        # For int64 class labels the shape is [number of classes], taken as
        # the largest label value plus one.
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
|
164 |
+
|
165 |
+
|
166 |
+
# ----------------------------------------------------------------------------
|
167 |
+
# Dataset subclass that loads images recursively from the specified directory
|
168 |
+
# or ZIP file.
|
169 |
+
|
170 |
+
|
171 |
+
class ImageFolderDataset(Dataset):
    """Dataset that reads images recursively from a directory or a ZIP file.

    Files whose extension is known to PIL, plus ``.npy`` arrays, are treated
    as images. Labels are read from an optional ``dataset.json`` entry.
    """

    def __init__(
        self,
        path,  # Path to directory or zip.
        resolution=None,  # Ensure specific resolution, None = anything goes.
        **super_kwargs,  # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None

        # Enumerate every file name, relative to the dataset root.
        if os.path.isdir(self._path):
            self._type = "dir"
            discovered = set()
            for root, _dirs, files in os.walk(self._path):
                for entry in files:
                    full_path = os.path.join(root, entry)
                    discovered.add(os.path.relpath(full_path, start=self._path))
            self._all_fnames = discovered
        elif self._file_ext(self._path) == ".zip":
            self._type = "zip"
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError("Path must point to a directory or zip")

        # Keep only the files that can be decoded as images, in sorted order.
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {".npy"}
        image_names = [
            candidate
            for candidate in self._all_fnames
            if self._file_ext(candidate) in supported_ext
        ]
        image_names.sort()
        self._image_fnames = image_names
        if not self._image_fnames:
            raise IOError("No image files found in the specified path")

        # The first image defines the per-sample shape for the whole dataset.
        name = os.path.splitext(os.path.basename(self._path))[0]
        first_shape = list(self._load_raw_image(0).shape)
        raw_shape = [len(self._image_fnames)] + first_shape
        if resolution is not None:
            if raw_shape[2] != resolution or raw_shape[3] != resolution:
                raise IOError("Image files do not match the specified resolution")
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname``, including the dot."""
        _stem, extension = os.path.splitext(fname)
        return extension.lower()

    def _get_zipfile(self):
        """Return the ZIP archive handle, opening it on first use."""
        assert self._type == "zip"
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) for binary reading."""
        if self._type == "zip":
            return self._get_zipfile().open(fname, "r")
        if self._type == "dir":
            return open(os.path.join(self._path, fname), "rb")
        return None

    def close(self):
        """Close the ZIP archive if one is open and forget the handle."""
        try:
            archive = self._zipfile
            if archive is not None:
                archive.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The archive handle is excluded from the pickled state; it is
        # reopened on demand by _get_zipfile().
        state = super().__getstate__()
        state["_zipfile"] = None
        return state

    def _load_raw_image(self, raw_idx):
        """Decode image number ``raw_idx`` into a channels-first array."""
        fname = self._image_fnames[raw_idx]
        ext = self._file_ext(fname)
        with self._open_file(fname) as stream:
            if ext == ".npy":
                # Collapse all leading axes into one; keep the last two as-is.
                array = np.load(stream)
                return array.reshape(-1, *array.shape[-2:])
            if ext == ".png" and pyspng is not None:
                pixels = pyspng.load(stream.read())
            else:
                pixels = np.array(PIL.Image.open(stream))
            # HW or HWC -> CHW
            return pixels.reshape(*pixels.shape[:2], -1).transpose(2, 0, 1)

    def _load_raw_labels(self):
        """Load labels from ``dataset.json``; return ``None`` when absent."""
        meta_name = "dataset.json"
        if meta_name not in self._all_fnames:
            return None
        with self._open_file(meta_name) as stream:
            entries = json.load(stream)["labels"]
        if entries is None:
            return None
        by_name = dict(entries)
        # Label keys use forward slashes regardless of the host OS separator.
        ordered = [
            by_name[image_name.replace("\\", "/")]
            for image_name in self._image_fnames
        ]
        label_array = np.array(ordered)
        # 1-D labels are class indices (int64); 2-D labels are float vectors.
        target_dtype = {1: np.int64, 2: np.float32}[label_array.ndim]
        return label_array.astype(target_dtype)
|
267 |
+
|
268 |
+
|
269 |
+
# ----------------------------------------------------------------------------
|
flowutils.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pdb
|
2 |
+
|
3 |
+
import normflows as nf
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
|
10 |
+
def build_flows(
    latent_size, num_flows=4, num_blocks=2, hidden_units=128, context_size=64
):
    """Build a conditional normalizing flow over ``latent_size`` dimensions.

    The model stacks ``num_flows`` pairs of a coupled rational-quadratic
    spline layer followed by an LU linear permutation, on top of a trainable
    diagonal Gaussian base distribution.

    :param latent_size: Dimensionality of the modelled variable.
    :param num_flows: Number of (spline, permutation) pairs.
    :param num_blocks: ``num_blocks`` passed to each spline layer.
    :param hidden_units: ``num_hidden_channels`` passed to each spline layer.
    :param context_size: ``num_context_channels`` passed to each spline layer.
    :return: An ``nf.ConditionalNormalizingFlow`` instance.
    """
    layers = []
    for _ in range(num_flows):
        spline = nf.flows.CoupledRationalQuadraticSpline(
            latent_size,
            num_blocks=num_blocks,
            num_hidden_channels=hidden_units,
            num_context_channels=context_size,
        )
        layers.append(spline)
        layers.append(nf.flows.LULinearPermute(latent_size))

    base_distribution = nf.distributions.DiagGaussian(latent_size, trainable=True)
    return nf.ConditionalNormalizingFlow(base_distribution, layers)
|
34 |
+
|
35 |
+
|
36 |
+
def get_emb(sin_inp):
    """
    Build a sinusoidal embedding whose last axis alternates sin and cos:
    [sin(a0), cos(a0), sin(a1), cos(a1), ...], doubling that axis' length.
    """
    paired = torch.stack([torch.sin(sin_inp), torch.cos(sin_inp)], dim=-1)
    return paired.flatten(start_dim=-2, end_dim=-1)
|
42 |
+
|
43 |
+
|
44 |
+
class PositionalEncoding2D(nn.Module):
    """Fixed (non-learned) 2D sinusoidal positional encoding.

    The encoding depends only on the spatial size of the input, so the most
    recently computed grid is cached in a non-persistent buffer and reused
    for inputs with the same spatial size, device and dtype.
    """

    def __init__(self, channels):
        """
        :param channels: Requested size of the encoding's last dimension. The
            produced size is ``2 * self.channels``, which equals ``channels``
            when ``channels`` is divisible by 4.
        """
        super().__init__()
        self.org_channels = channels
        # Half of the output channels encode x and half encode y; each half
        # must be even so that sin/cos pairs fit.
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    @staticmethod
    def _interleaved_sin_cos(angles):
        """Return [sin(a0), cos(a0), sin(a1), cos(a1), ...] along the last axis."""
        return torch.stack((angles.sin(), angles.cos()), dim=-1).flatten(-2, -1)

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, ch, x, y). Only its
            spatial size, device and dtype are used.
        :return: Positional encoding of size (x, y, 2 * self.channels); there
            is no batch dimension.
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        _, _, x, y = tensor.shape

        # Reuse the cache only when it was built for this exact spatial size.
        # (The spatial axes are 2 and 3; comparing against axes 1 and 2 would
        # miss nearly always and could return a grid of the wrong size.)
        cached = self.cached_penc
        if (
            cached is not None
            and tuple(cached.shape[:2]) == (x, y)
            and cached.device == tensor.device
            and cached.dtype == tensor.dtype
        ):
            return cached

        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        # emb_x is (x, 1, channels) and emb_y is (y, channels); both broadcast
        # over the (x, y) grid when written into emb below.
        emb_x = self._interleaved_sin_cos(sin_inp_x).unsqueeze(1)
        emb_y = self._interleaved_sin_cos(sin_inp_y)
        emb = torch.zeros(
            (x, y, self.channels * 2),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y

        self.cached_penc = emb
        return emb
|
90 |
+
|
91 |
+
|
92 |
+
class SpatialNormer(nn.Module):
    """Local L2-norm pooling implemented as a frozen all-ones grouped Conv3d."""

    def __init__(
        self,
        in_channels,  # channels will be number of sigma scales in input
        kernel_size=3,
        stride=2,
        padding=1,
    ):
        """
        Note that the convolution will reduce the channel dimension
        So (b, num_sigmas, c, h, w) -> (b, num_sigmas, new_h , new_w)
        """
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            # One group per sigma, so every sigma is pooled independently
            # of the others.
            groups=in_channels,
            bias=False,
        )
        # Frozen all-ones kernel: the convolution is a plain windowed sum.
        nn.init.ones_(self.conv.weight)
        self.conv.weight.requires_grad_(False)

    @torch.no_grad()
    def forward(self, x):
        # Windowed sum of squares, then square root -> L2 norm per window.
        window_sq_sum = self.conv(x.square())
        return window_sq_sum.pow_(0.5).squeeze(2)
|
122 |
+
|
123 |
+
|
124 |
+
class PatchFlow(torch.nn.Module):
    """Conditional normalizing flow over per-patch score norms.

    The input is pooled into local norms by ``SpatialNormer``; every spatial
    location ("patch") is then modelled by one shared flow, conditioned on a
    2D positional encoding of the patch's location.
    """

    def __init__(
        self,
        input_size,
        patch_size=3,
        context_embedding_size=128,
        num_blocks=2,
        hidden_units=128,
    ):
        """
        :param input_size: Tuple ``(num_sigmas, c, h, w)`` describing one sample.
        :param patch_size: Kernel size of the local pooling.
        :param context_embedding_size: Size of the positional context vector.
        :param num_blocks: ``num_blocks`` forwarded to ``build_flows``.
        :param hidden_units: ``hidden_units`` forwarded to ``build_flows``.
        """
        super().__init__()
        num_sigmas, c, h, w = input_size
        self.local_pooler = SpatialNormer(
            in_channels=num_sigmas, kernel_size=patch_size
        )
        # num_blocks / hidden_units were previously accepted but ignored;
        # their defaults match build_flows, so default behavior is unchanged.
        self.flow = build_flows(
            latent_size=num_sigmas,
            num_blocks=num_blocks,
            hidden_units=hidden_units,
            context_size=context_embedding_size,
        )
        self.position_encoding = PositionalEncoding2D(channels=context_embedding_size)

        # Run a dummy input through the pooler to learn the pooled spatial
        # size, then compute the positional encoding once for that size.
        _, _, ctx_h, ctx_w = self.local_pooler(
            torch.empty((1, num_sigmas, c, h, w))
        ).shape
        self.position_encoding(torch.empty(1, 1, ctx_h, ctx_w))
        assert self.position_encoding.cached_penc.shape[-1] == context_embedding_size

    def init_weights(self):
        """Xavier-initialize the flow's linear layers; zero the last one and all biases."""
        linear_modules = [m for m in self.flow.modules() if isinstance(m, nn.Linear)]
        last_idx = len(linear_modules) - 1

        for idx, m in enumerate(linear_modules):
            # Last layer gets init w/ zeros
            if idx == last_idx:
                nn.init.zeros_(m.weight.data)
            else:
                nn.init.xavier_uniform_(m.weight.data)

            if m.bias is not None:
                nn.init.zeros_(m.bias.data)

    def forward(self, x, chunk_size=32):
        """Compute a per-patch log-density map.

        :param x: 5d tensor unpacked as ``(b, s, c, h, w)``.
        :param chunk_size: Upper bound on patch locations evaluated per flow call.
        :return: Contiguous tensor of shape ``(b, 1, new_h, new_w)``.
        """
        b = x.shape[0]
        x_norm = self.local_pooler(x)
        _, _, new_h, new_w = x_norm.shape
        context = self.position_encoding(x_norm)

        # Flatten the spatial grid: one row per patch location.
        local_ctx = rearrange(context, "h w c -> (h w) c")
        patches = rearrange(x_norm, "b c h w -> (h w) b c")

        # Evaluate the flow in chunks of patch locations to bound memory.
        nchunks = (patches.shape[0] + chunk_size - 1) // chunk_size
        patches = patches.chunk(nchunks, dim=0)
        ctx_chunks = local_ctx.chunk(nchunks, dim=0)
        patch_logpx = []

        for p, ctx in zip(patches, ctx_chunks):
            # num patches in chunk (<= chunk_size)
            n = p.shape[0]
            # Every sample in the batch shares the same positional context.
            ctx = repeat(ctx, "n c -> (n b) c", b=b)
            p = rearrange(p, "n b c -> (n b) c")

            # Compute log densities for each patch
            logpx = self.flow.log_prob(p, context=ctx)
            logpx = rearrange(logpx, "(n b) -> n b", n=n, b=b)
            patch_logpx.append(logpx)

        # Convert back to image
        logpx = torch.cat(patch_logpx, dim=0)
        logpx = rearrange(logpx, "(h w) b -> b 1 h w", b=b, h=new_h, w=new_w)

        return logpx.contiguous()

    @staticmethod
    def stochastic_step(
        scores, x_batch, flow_model, opt=None, train=False, n_patches=32, device="cpu"
    ):
        """Run one training or evaluation step on a random subset of patches.

        :param scores: Input passed to ``flow_model.local_pooler``.
        :param x_batch: Forwarded to ``get_random_patches`` (unused there).
        :param flow_model: The ``PatchFlow`` instance to train or evaluate.
        :param opt: Optimizer; required when ``train`` is true.
        :param train: Whether to backpropagate and step the optimizer.
        :param n_patches: Number of random patch locations to sample.
        :param device: Device on which the flow is evaluated.
        :return: The mean negative log-likelihood as a Python float.
        :raises ValueError: If ``train`` is true and ``opt`` is ``None``.
        """
        if train and opt is None:
            raise ValueError("An optimizer (opt) is required when train=True")

        if train:
            flow_model.train()
            opt.zero_grad(set_to_none=True)
        else:
            flow_model.eval()

        patches, context = PatchFlow.get_random_patches(
            scores, x_batch, flow_model, n_patches
        )

        patch_feature = rearrange(patches.to(device), "n b c -> (n b) c")
        context_vector = rearrange(context.to(device), "n b c -> (n b) c")

        z, ldj = flow_model.flow.inverse_and_log_det(
            patch_feature,
            context=context_vector,
        )

        # The loss is scaled by n_patches for the backward pass; the value
        # returned below is divided by n_patches again.
        loss = -torch.mean(flow_model.flow.q0.log_prob(z) + ldj)
        loss *= n_patches

        if train:
            loss.backward()
            opt.step()

        return loss.item() / n_patches

    @staticmethod
    def get_random_patches(scores, x_batch, flow_model, n_patches):
        """Sample ``n_patches`` random patch locations and their contexts.

        :return: ``(patches, context)`` CPU tensors, each indexed as
            ``(patch, batch, channel)``.
        """
        b = scores.shape[0]
        h = flow_model.local_pooler(scores)
        patches = rearrange(h, "b c h w -> (h w) b c")

        context = flow_model.position_encoding(h)
        context = rearrange(context, "h w c -> (h w) c")
        context = repeat(context, "n c -> n b c", b=b)

        # conserve gpu memory
        patches = patches.cpu()
        context = context.cpu()

        # Get random patches
        total_patches = patches.shape[0]
        shuffled_idx = torch.randperm(total_patches)
        rand_idx_batch = shuffled_idx[:n_patches]

        return patches[rand_idx_batch], context[rand_idx_batch]
|
scorer.py → msma.py
RENAMED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import pickle
|
|
|
3 |
from pickle import dump, load
|
4 |
|
5 |
import numpy as np
|
@@ -9,9 +10,12 @@ from sklearn.mixture import GaussianMixture
|
|
9 |
from sklearn.model_selection import GridSearchCV
|
10 |
from sklearn.pipeline import Pipeline
|
11 |
from sklearn.preprocessing import StandardScaler
|
|
|
12 |
from tqdm import tqdm
|
13 |
|
14 |
import dnnlib
|
|
|
|
|
15 |
|
16 |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
|
17 |
|
@@ -22,6 +26,17 @@ config_presets = {
|
|
22 |
}
|
23 |
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
class EDMScorer(torch.nn.Module):
|
26 |
def __init__(
|
27 |
self,
|
@@ -41,6 +56,7 @@ class EDMScorer(torch.nn.Module):
|
|
41 |
self.sigma_max = sigma_max
|
42 |
self.sigma_data = sigma_data
|
43 |
self.net = net.eval()
|
|
|
44 |
|
45 |
# Adjust noise levels based on how far we want to accumulate
|
46 |
self.sigma_min = 1e-1
|
@@ -63,7 +79,7 @@ class EDMScorer(torch.nn.Module):
|
|
63 |
x,
|
64 |
force_fp32=False,
|
65 |
):
|
66 |
-
x = x.to(torch.float32)
|
67 |
|
68 |
batch_scores = []
|
69 |
for sigma in self.sigma_steps:
|
@@ -76,6 +92,29 @@ class EDMScorer(torch.nn.Module):
|
|
76 |
return batch_scores
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
def build_model(preset="edm2-img64-s-fid", device="cpu"):
|
80 |
netpath = config_presets[preset]
|
81 |
with dnnlib.util.open_url(netpath, verbose=1) as f:
|
@@ -85,41 +124,45 @@ def build_model(preset="edm2-img64-s-fid", device="cpu"):
|
|
85 |
return model
|
86 |
|
87 |
|
88 |
-
def
|
89 |
-
|
90 |
-
return np.quantile(gmm.score_samples(X), 0.1)
|
91 |
|
92 |
-
X = torch.load(score_path)
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
clf.fit(X)
|
97 |
-
inlier_nll = -clf.score_samples(X)
|
98 |
-
|
99 |
-
param_grid = dict(
|
100 |
-
GMM__n_components=range(2, 11, 2),
|
101 |
-
)
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
param_grid=param_grid,
|
106 |
-
cv=10,
|
107 |
-
n_jobs=2,
|
108 |
-
verbose=1,
|
109 |
-
scoring=quantile_scorer,
|
110 |
)
|
111 |
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
means = grid_result.cv_results_["mean_test_score"]
|
117 |
-
stds = grid_result.cv_results_["std_test_score"]
|
118 |
-
params = grid_result.cv_results_["params"]
|
119 |
-
for mean, stdev, param in zip(means, stds, params):
|
120 |
-
print("%f (%f) with: %r" % (mean, stdev, param))
|
121 |
-
|
122 |
-
clf = grid.best_estimator_
|
123 |
|
124 |
os.makedirs(outdir, exist_ok=True)
|
125 |
with open(f"{outdir}/refscores.npz", "wb") as f:
|
@@ -134,26 +177,14 @@ def compute_gmm_likelihood(x_score, gmmdir):
|
|
134 |
clf = load(f)
|
135 |
nll = -clf.score_samples(x_score)
|
136 |
|
137 |
-
with np.load(f"{gmmdir}/refscores.npz", "
|
138 |
ref_nll = f["arr_0"]
|
139 |
percentile = (ref_nll < nll).mean()
|
140 |
|
141 |
return nll, percentile
|
142 |
|
143 |
|
144 |
-
def
|
145 |
-
# f = "doge.jpg"
|
146 |
-
f = "goldfish.JPEG"
|
147 |
-
image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
|
148 |
-
image = np.array(image)
|
149 |
-
image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
|
150 |
-
x = torch.from_numpy(image).unsqueeze(0).to(device)
|
151 |
-
model = build_model(device=device)
|
152 |
-
scores = model(x)
|
153 |
-
return scores
|
154 |
-
|
155 |
-
|
156 |
-
def runner(preset, dataset_path, device="cpu"):
|
157 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
158 |
refimg, reflabel = dsobj[0]
|
159 |
print(refimg.shape, refimg.dtype, reflabel)
|
@@ -178,19 +209,138 @@ def runner(preset, dataset_path, device="cpu"):
|
|
178 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
179 |
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
if __name__ == "__main__":
|
182 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
183 |
preset = "edm2-img64-s-fid"
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
185 |
# preset=preset,
|
186 |
# dataset_path="/GROND_STOR/amahmood/datasets/img64/",
|
187 |
# device="cuda",
|
188 |
# )
|
189 |
-
train_gmm(
|
190 |
-
|
191 |
-
)
|
192 |
-
s = test_runner(device=device)
|
193 |
-
s = s.square().sum(dim=(2, 3, 4)) ** 0.5
|
194 |
-
s = s.to("cpu").numpy()
|
195 |
-
nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}")
|
196 |
-
print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
|
|
|
1 |
import os
|
2 |
import pickle
|
3 |
+
from functools import partial
|
4 |
from pickle import dump, load
|
5 |
|
6 |
import numpy as np
|
|
|
10 |
from sklearn.model_selection import GridSearchCV
|
11 |
from sklearn.pipeline import Pipeline
|
12 |
from sklearn.preprocessing import StandardScaler
|
13 |
+
from torch.utils.data import Subset
|
14 |
from tqdm import tqdm
|
15 |
|
16 |
import dnnlib
|
17 |
+
from dataset import ImageFolderDataset
|
18 |
+
from flowutils import PatchFlow
|
19 |
|
20 |
model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
|
21 |
|
|
|
26 |
}
|
27 |
|
28 |
|
29 |
+
class StandardRGBEncoder:
    """Convert between raw [0, 255] pixel values and float32 values around [-1, 1]."""

    def __init__(self):
        super().__init__()

    def encode(self, x):  # raw pixels => final pixels
        """Map raw pixel values to float32 via ``x / 127.5 - 1``."""
        as_float = x.to(torch.float32)
        return as_float / 127.5 - 1

    def decode(self, x):  # final latents => raw pixels
        """Map values back to uint8 pixels, clipping to the range [0, 255]."""
        rescaled = x.to(torch.float32) * 127.5 + 128
        return rescaled.clip(0, 255).to(torch.uint8)
|
38 |
+
|
39 |
+
|
40 |
class EDMScorer(torch.nn.Module):
|
41 |
def __init__(
|
42 |
self,
|
|
|
56 |
self.sigma_max = sigma_max
|
57 |
self.sigma_data = sigma_data
|
58 |
self.net = net.eval()
|
59 |
+
self.encoder = StandardRGBEncoder()
|
60 |
|
61 |
# Adjust noise levels based on how far we want to accumulate
|
62 |
self.sigma_min = 1e-1
|
|
|
79 |
x,
|
80 |
force_fp32=False,
|
81 |
):
|
82 |
+
x = self.encoder.encode(x).to(torch.float32)
|
83 |
|
84 |
batch_scores = []
|
85 |
for sigma in self.sigma_steps:
|
|
|
92 |
return batch_scores
|
93 |
|
94 |
|
95 |
+
class ScoreFlow(torch.nn.Module):
    """Couple a frozen score network with a trainable ``PatchFlow`` on its output."""

    def __init__(
        self,
        scorenet,
        vectorize=False,
        device="cpu",
    ):
        """
        :param scorenet: Score model exposing ``net.img_resolution``,
            ``net.img_channels`` and ``sigma_steps``.
        :param vectorize: Currently unused.
        :param device: Device both sub-modules are moved to.
        """
        super().__init__()

        # The flow's input shape is (num_sigmas, c, h, w), derived from the
        # score network's own configuration.
        resolution = scorenet.net.img_resolution
        channels = scorenet.net.img_channels
        sigma_count = len(scorenet.sigma_steps)
        flow_input_shape = (sigma_count, channels, resolution, resolution)

        self.flow = PatchFlow(flow_input_shape).to(device)
        # The score network is frozen; only the flow has trainable weights.
        self.scorenet = scorenet.to(device).requires_grad_(False)
        self.flow.init_weights()

    def forward(self, x, **score_kwargs):
        """Score ``x`` with the score network, then evaluate the flow on the scores."""
        return self.flow(self.scorenet(x, **score_kwargs))
|
116 |
+
|
117 |
+
|
118 |
def build_model(preset="edm2-img64-s-fid", device="cpu"):
|
119 |
netpath = config_presets[preset]
|
120 |
with dnnlib.util.open_url(netpath, verbose=1) as f:
|
|
|
124 |
return model
|
125 |
|
126 |
|
127 |
+
def quantile_scorer(gmm, X, y=None):
    """Score a fitted model by the 10th percentile of its per-sample scores.

    ``y`` is accepted for compatibility with the scorer call signature and
    is ignored.
    """
    per_sample_scores = gmm.score_samples(X)
    return np.quantile(per_sample_scores, q=0.1)
|
|
|
129 |
|
|
|
130 |
|
131 |
+
def train_gmm(score_path, outdir, grid_search=False):
|
132 |
+
X = torch.load(score_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
+
gm = GaussianMixture(
|
135 |
+
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
|
|
|
|
|
|
|
|
|
|
136 |
)
|
137 |
|
138 |
+
if grid_search:
|
139 |
+
clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
|
140 |
+
param_grid = dict(
|
141 |
+
GMM__n_components=range(2, 11, 1),
|
142 |
+
)
|
143 |
+
|
144 |
+
grid = GridSearchCV(
|
145 |
+
estimator=clf,
|
146 |
+
param_grid=param_grid,
|
147 |
+
cv=5,
|
148 |
+
n_jobs=2,
|
149 |
+
verbose=1,
|
150 |
+
scoring=quantile_scorer,
|
151 |
+
)
|
152 |
+
|
153 |
+
grid_result = grid.fit(X)
|
154 |
+
|
155 |
+
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
|
156 |
+
print("-----" * 15)
|
157 |
+
means = grid_result.cv_results_["mean_test_score"]
|
158 |
+
stds = grid_result.cv_results_["std_test_score"]
|
159 |
+
params = grid_result.cv_results_["params"]
|
160 |
+
for mean, stdev, param in zip(means, stds, params):
|
161 |
+
print("%f (%f) with: %r" % (mean, stdev, param))
|
162 |
+
clf = grid.best_estimator_
|
163 |
|
164 |
+
clf.fit(X)
|
165 |
+
inlier_nll = -clf.score_samples(X)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
os.makedirs(outdir, exist_ok=True)
|
168 |
with open(f"{outdir}/refscores.npz", "wb") as f:
|
|
|
177 |
clf = load(f)
|
178 |
nll = -clf.score_samples(x_score)
|
179 |
|
180 |
+
with np.load(f"{gmmdir}/refscores.npz", "rb") as f:
|
181 |
ref_nll = f["arr_0"]
|
182 |
percentile = (ref_nll < nll).mean()
|
183 |
|
184 |
return nll, percentile
|
185 |
|
186 |
|
187 |
+
def cache_score_norms(preset, dataset_path, device="cpu"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
|
189 |
refimg, reflabel = dsobj[0]
|
190 |
print(refimg.shape, refimg.dtype, reflabel)
|
|
|
209 |
print(f"Computed score norms for {score_norms.shape[0]} samples")
|
210 |
|
211 |
|
212 |
+
def train_flow(dataset_path, preset, device="cuda"):
    """Train the flow component of a ``ScoreFlow`` model on an image folder.

    The dataset is split by index order: the first 90% of samples are used
    for training and the remaining 10% for validation. Only the parameters
    of ``model.flow`` are handed to the optimizer. After one pass over the
    training loader the flow weights are written to
    ``out/msma/{preset}/flow.pt``.

    Args:
        dataset_path: Path forwarded to ``ImageFolderDataset``.
        preset: Preset name forwarded to ``build_model``; also used to
            build the output directory.
        device: Device the batches are moved to and the model is built on.
    """
    dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
    print(f"Loaded {len(dsobj)} samples from {dataset_path}")

    # Subset of training dataset
    val_ratio = 0.1
    train_len = int((1 - val_ratio) * len(dsobj))
    val_len = len(dsobj) - train_len

    print(
        f"Generating train/test split with ratio={val_ratio} -> {train_len}/{val_len}..."
    )
    train_ds = Subset(dsobj, range(train_len))
    val_ds = Subset(dsobj, range(train_len, train_len + val_len))

    trainiter = torch.utils.data.DataLoader(
        train_ds, batch_size=48, num_workers=4, prefetch_factor=2
    )
    testiter = torch.utils.data.DataLoader(
        val_ds, batch_size=48, num_workers=4, prefetch_factor=2
    )

    model = ScoreFlow(build_model(preset=preset), device=device)
    opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
    train_step = partial(
        PatchFlow.stochastic_step,
        flow_model=model.flow,
        opt=opt,
        train=True,
        n_patches=64,
        device=device,
    )
    eval_step = partial(
        PatchFlow.stochastic_step,
        flow_model=model.flow,
        train=False,
        n_patches=128,
        device=device,
    )

    # Iterate over the same bar whose description is updated below; wrapping
    # the loader in a second tqdm would leave this bar stuck at 0.
    pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")

    for step, (x, _) in enumerate(pbar):
        x = x.to(device)
        scores = model.scorenet(x)

        if step == 0:
            # Seed val_loss from the first training batch so the description
            # can be rendered before the first periodic validation runs.
            with torch.inference_mode():
                val_loss = eval_step(scores, x)

        train_loss = train_step(scores, x)

        if (step + 1) % 10 == 0:
            with torch.inference_mode():
                val_total = 0.0
                n_val_batches = 0
                for x_val, _ in testiter:
                    x_val = x_val.to(device)
                    val_total += eval_step(model.scorenet(x_val), x_val)
                    n_val_batches += 1
                    # Only the first validation batch is evaluated.
                    break
                # Keep the previous estimate if the validation split is empty.
                if n_val_batches:
                    val_loss = val_total / n_val_batches

        pbar.set_description(
            f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
        )

    # Make sure the output directory exists so a finished run is not lost
    # to a missing-folder error at save time.
    out_path = f"out/msma/{preset}/flow.pt"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    torch.save(model.flow.state_dict(), out_path)
|
283 |
+
|
284 |
+
|
285 |
+
# Called form of the decorator, so a context-manager instance wraps the
# function instead of the function being passed as the ``mode`` argument.
@torch.inference_mode()
def test_runner(device="cpu"):
    """Run the default model on one hard-coded image and return its output.

    The image is resized to 64x64, rearranged to channel-first layout and
    given a leading batch axis before being passed to the model.

    Args:
        device: Device the input tensor is moved to and the model is built on.

    Returns:
        Whatever ``build_model(device=device)`` returns when called on the
        prepared image tensor.
    """
    # f = "doge.jpg"
    f = "goldfish.JPEG"
    # Close the file handle once the pixels have been copied out.
    with PIL.Image.open(f) as img:
        image = np.array(img.resize((64, 64), PIL.Image.Resampling.LANCZOS))
    # (H, W) or (H, W, C) -> (C, H, W); the reshape adds a channel axis to
    # single-channel images.
    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
    x = torch.from_numpy(image).unsqueeze(0).to(device)
    model = build_model(device=device)
    scores = model(x)

    return scores
|
297 |
+
|
298 |
+
|
299 |
+
def test_flow_runner(device="cpu", load_weights=None):
    """Render a ``ScoreFlow`` heatmap for one hard-coded image.

    The image is resized to 64x64, passed through a ``ScoreFlow`` built on
    top of the default model, and the result is min-max scaled to [0, 255]
    and written to ``heatmap.png`` in the working directory.

    Args:
        device: Device the input tensor is moved to and the models are
            built on.
        load_weights: Optional path to a saved flow ``state_dict``. When
            ``None`` the flow keeps its freshly constructed weights.
    """
    f = "doge.jpg"
    # f = "goldfish.JPEG"
    # Close the file handle once the pixels have been copied out.
    with PIL.Image.open(f) as img:
        image = np.array(img.resize((64, 64), PIL.Image.Resampling.LANCZOS))
    # (H, W) or (H, W, C) -> (C, H, W)
    image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
    x = torch.from_numpy(image).unsqueeze(0).to(device)
    model = build_model(device=device)

    score_flow = ScoreFlow(scorenet=model, device=device)

    if load_weights is not None:
        # Remap tensors onto the requested device so weights saved from a
        # GPU run can be restored on another device.
        state = torch.load(load_weights, map_location=device)
        score_flow.flow.load_state_dict(state)

    # Evaluate once and reuse the result for both the shape report and the
    # rendered image.
    heatmap = score_flow(x)
    print(heatmap.shape)

    heatmap = heatmap.detach().cpu().numpy()
    lo = heatmap.min()
    span = heatmap.max() - lo
    if span > 0:
        heatmap = (heatmap - lo) / span * 255
    else:
        # A constant map has no contrast; avoid dividing by zero.
        heatmap = np.zeros_like(heatmap)
    # Take the first two leading indices (presumably batch and channel).
    im = PIL.Image.fromarray(heatmap[0, 0])
    im.convert("RGB").save(
        "heatmap.png",
    )

    return
|
324 |
+
|
325 |
+
|
326 |
if __name__ == "__main__":
    # Script entry point: train the flow for one preset, then render a
    # heatmap from the saved weights.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    preset = "edm2-img64-s-fid"
    imagenette_path = "/GROND_STOR/amahmood/datasets/img64/"

    train_flow(imagenette_path, preset, device)
    # Use the detected device instead of a hard-coded "cuda" so the script
    # also runs on hosts without a GPU.
    test_flow_runner(device, f"out/msma/{preset}/flow.pt")

    # cache_score_norms(
    #     preset=preset,
    #     dataset_path="/GROND_STOR/amahmood/datasets/img64/",
    #     device="cuda",
    # )
    # train_gmm(
    #     f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
    # )
    # s = test_runner(device=device)
    # s = s.square().sum(dim=(2, 3, 4)) ** 0.5
    # s = s.to("cpu").numpy()
    # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
    # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
|