ahsanMah commited on
Commit
b1602ac
·
1 Parent(s): 06232ec

+ porting in msma files

Browse files

+ adding flow model utils

Files changed (4) hide show
  1. app.py +1 -1
  2. dataset.py +269 -0
  3. flowutils.py +263 -0
  4. scorer.py → msma.py +203 -53
app.py CHANGED
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
6
  import numpy as np
7
  import torch
8
 
9
- from scorer import build_model, config_presets
10
 
11
 
12
  @cache
 
6
  import numpy as np
7
  import torch
8
 
9
+ from msma import build_model, config_presets
10
 
11
 
12
  @cache
dataset.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Streaming images and labels from datasets created with dataset_tool.py."""
9
+
10
+ import json
11
+ import os
12
+ import zipfile
13
+
14
+ import numpy as np
15
+ import PIL.Image
16
+ import torch
17
+
18
+ import dnnlib
19
+
20
+ try:
21
+ import pyspng
22
+ except ImportError:
23
+ pyspng = None
24
+
25
+ # ----------------------------------------------------------------------------
26
+ # Abstract base class for datasets.
27
+
28
+
29
+ class Dataset(torch.utils.data.Dataset):
30
+ def __init__(
31
+ self,
32
+ name, # Name of the dataset.
33
+ raw_shape, # Shape of the raw image data (NCHW).
34
+ use_labels=True, # Enable conditioning labels? False = label dimension is zero.
35
+ max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
36
+ xflip=False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
37
+ random_seed=0, # Random seed to use when applying max_size.
38
+ cache=False, # Cache images in CPU memory?
39
+ ):
40
+ self._name = name
41
+ self._raw_shape = list(raw_shape)
42
+ self._use_labels = use_labels
43
+ self._cache = cache
44
+ self._cached_images = dict() # {raw_idx: np.ndarray, ...}
45
+ self._raw_labels = None
46
+ self._label_shape = None
47
+
48
+ # Apply max_size.
49
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
50
+ if (max_size is not None) and (self._raw_idx.size > max_size):
51
+ np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
52
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
53
+
54
+ # Apply xflip.
55
+ self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
56
+ if xflip:
57
+ self._raw_idx = np.tile(self._raw_idx, 2)
58
+ self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
59
+
60
+ def _get_raw_labels(self):
61
+ if self._raw_labels is None:
62
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
63
+ if self._raw_labels is None:
64
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
65
+ assert isinstance(self._raw_labels, np.ndarray)
66
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
67
+ assert self._raw_labels.dtype in [np.float32, np.int64]
68
+ if self._raw_labels.dtype == np.int64:
69
+ assert self._raw_labels.ndim == 1
70
+ assert np.all(self._raw_labels >= 0)
71
+ return self._raw_labels
72
+
73
+ def close(self): # to be overridden by subclass
74
+ pass
75
+
76
+ def _load_raw_image(self, raw_idx): # to be overridden by subclass
77
+ raise NotImplementedError
78
+
79
+ def _load_raw_labels(self): # to be overridden by subclass
80
+ raise NotImplementedError
81
+
82
+ def __getstate__(self):
83
+ return dict(self.__dict__, _raw_labels=None)
84
+
85
+ def __del__(self):
86
+ try:
87
+ self.close()
88
+ except:
89
+ pass
90
+
91
+ def __len__(self):
92
+ return self._raw_idx.size
93
+
94
+ def __getitem__(self, idx):
95
+ raw_idx = self._raw_idx[idx]
96
+ image = self._cached_images.get(raw_idx, None)
97
+ if image is None:
98
+ image = self._load_raw_image(raw_idx)
99
+ if self._cache:
100
+ self._cached_images[raw_idx] = image
101
+ assert isinstance(image, np.ndarray)
102
+ assert list(image.shape) == self._raw_shape[1:]
103
+ if self._xflip[idx]:
104
+ assert image.ndim == 3 # CHW
105
+ image = image[:, :, ::-1]
106
+ return image.copy(), self.get_label(idx)
107
+
108
+ def get_label(self, idx):
109
+ label = self._get_raw_labels()[self._raw_idx[idx]]
110
+ if label.dtype == np.int64:
111
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
112
+ onehot[label] = 1
113
+ label = onehot
114
+ return label.copy()
115
+
116
+ def get_details(self, idx):
117
+ d = dnnlib.EasyDict()
118
+ d.raw_idx = int(self._raw_idx[idx])
119
+ d.xflip = int(self._xflip[idx]) != 0
120
+ d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
121
+ return d
122
+
123
+ @property
124
+ def name(self):
125
+ return self._name
126
+
127
+ @property
128
+ def image_shape(self): # [CHW]
129
+ return list(self._raw_shape[1:])
130
+
131
+ @property
132
+ def num_channels(self):
133
+ assert len(self.image_shape) == 3 # CHW
134
+ return self.image_shape[0]
135
+
136
+ @property
137
+ def resolution(self):
138
+ assert len(self.image_shape) == 3 # CHW
139
+ assert self.image_shape[1] == self.image_shape[2]
140
+ return self.image_shape[1]
141
+
142
+ @property
143
+ def label_shape(self):
144
+ if self._label_shape is None:
145
+ raw_labels = self._get_raw_labels()
146
+ if raw_labels.dtype == np.int64:
147
+ self._label_shape = [int(np.max(raw_labels)) + 1]
148
+ else:
149
+ self._label_shape = raw_labels.shape[1:]
150
+ return list(self._label_shape)
151
+
152
+ @property
153
+ def label_dim(self):
154
+ assert len(self.label_shape) == 1
155
+ return self.label_shape[0]
156
+
157
+ @property
158
+ def has_labels(self):
159
+ return any(x != 0 for x in self.label_shape)
160
+
161
+ @property
162
+ def has_onehot_labels(self):
163
+ return self._get_raw_labels().dtype == np.int64
164
+
165
+
166
+ # ----------------------------------------------------------------------------
167
+ # Dataset subclass that loads images recursively from the specified directory
168
+ # or ZIP file.
169
+
170
+
171
+ class ImageFolderDataset(Dataset):
172
+ def __init__(
173
+ self,
174
+ path, # Path to directory or zip.
175
+ resolution=None, # Ensure specific resolution, None = anything goes.
176
+ **super_kwargs, # Additional arguments for the Dataset base class.
177
+ ):
178
+ self._path = path
179
+ self._zipfile = None
180
+
181
+ if os.path.isdir(self._path):
182
+ self._type = "dir"
183
+ self._all_fnames = {
184
+ os.path.relpath(os.path.join(root, fname), start=self._path)
185
+ for root, _dirs, files in os.walk(self._path)
186
+ for fname in files
187
+ }
188
+ elif self._file_ext(self._path) == ".zip":
189
+ self._type = "zip"
190
+ self._all_fnames = set(self._get_zipfile().namelist())
191
+ else:
192
+ raise IOError("Path must point to a directory or zip")
193
+
194
+ PIL.Image.init()
195
+ supported_ext = PIL.Image.EXTENSION.keys() | {".npy"}
196
+ self._image_fnames = sorted(
197
+ fname
198
+ for fname in self._all_fnames
199
+ if self._file_ext(fname) in supported_ext
200
+ )
201
+ if len(self._image_fnames) == 0:
202
+ raise IOError("No image files found in the specified path")
203
+
204
+ name = os.path.splitext(os.path.basename(self._path))[0]
205
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
206
+ if resolution is not None and (
207
+ raw_shape[2] != resolution or raw_shape[3] != resolution
208
+ ):
209
+ raise IOError("Image files do not match the specified resolution")
210
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
211
+
212
+ @staticmethod
213
+ def _file_ext(fname):
214
+ return os.path.splitext(fname)[1].lower()
215
+
216
+ def _get_zipfile(self):
217
+ assert self._type == "zip"
218
+ if self._zipfile is None:
219
+ self._zipfile = zipfile.ZipFile(self._path)
220
+ return self._zipfile
221
+
222
+ def _open_file(self, fname):
223
+ if self._type == "dir":
224
+ return open(os.path.join(self._path, fname), "rb")
225
+ if self._type == "zip":
226
+ return self._get_zipfile().open(fname, "r")
227
+ return None
228
+
229
+ def close(self):
230
+ try:
231
+ if self._zipfile is not None:
232
+ self._zipfile.close()
233
+ finally:
234
+ self._zipfile = None
235
+
236
+ def __getstate__(self):
237
+ return dict(super().__getstate__(), _zipfile=None)
238
+
239
+ def _load_raw_image(self, raw_idx):
240
+ fname = self._image_fnames[raw_idx]
241
+ ext = self._file_ext(fname)
242
+ with self._open_file(fname) as f:
243
+ if ext == ".npy":
244
+ image = np.load(f)
245
+ image = image.reshape(-1, *image.shape[-2:])
246
+ elif ext == ".png" and pyspng is not None:
247
+ image = pyspng.load(f.read())
248
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
249
+ else:
250
+ image = np.array(PIL.Image.open(f))
251
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
252
+ return image
253
+
254
+ def _load_raw_labels(self):
255
+ fname = "dataset.json"
256
+ if fname not in self._all_fnames:
257
+ return None
258
+ with self._open_file(fname) as f:
259
+ labels = json.load(f)["labels"]
260
+ if labels is None:
261
+ return None
262
+ labels = dict(labels)
263
+ labels = [labels[fname.replace("\\", "/")] for fname in self._image_fnames]
264
+ labels = np.array(labels)
265
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
266
+ return labels
267
+
268
+
269
+ # ----------------------------------------------------------------------------
flowutils.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+
3
+ import normflows as nf
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange, repeat
8
+
9
+
10
+ def build_flows(
11
+ latent_size, num_flows=4, num_blocks=2, hidden_units=128, context_size=64
12
+ ):
13
+ # Define flows
14
+
15
+ flows = []
16
+ for i in range(num_flows):
17
+ flows += [
18
+ nf.flows.CoupledRationalQuadraticSpline(
19
+ latent_size,
20
+ num_blocks=num_blocks,
21
+ num_hidden_channels=hidden_units,
22
+ num_context_channels=context_size,
23
+ )
24
+ ]
25
+ flows += [nf.flows.LULinearPermute(latent_size)]
26
+
27
+ # Set base distribution
28
+ q0 = nf.distributions.DiagGaussian(latent_size, trainable=True)
29
+
30
+ # Construct flow model
31
+ model = nf.ConditionalNormalizingFlow(q0, flows)
32
+
33
+ return model
34
+
35
+
36
+ def get_emb(sin_inp):
37
+ """
38
+ Gets a base embedding for one dimension with sin and cos intertwined
39
+ """
40
+ emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
41
+ return torch.flatten(emb, -2, -1)
42
+
43
+
44
+ class PositionalEncoding2D(nn.Module):
45
+ def __init__(self, channels):
46
+ """
47
+ :param channels: The last dimension of the tensor you want to apply pos emb to.
48
+ """
49
+ super(PositionalEncoding2D, self).__init__()
50
+ self.org_channels = channels
51
+ channels = int(np.ceil(channels / 4) * 2)
52
+ self.channels = channels
53
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
54
+ self.register_buffer("inv_freq", inv_freq)
55
+ self.register_buffer("cached_penc", None, persistent=False)
56
+
57
+ def forward(self, tensor):
58
+ """
59
+ :param tensor: A 4d tensor of size (batch_size, x, y, ch)
60
+ :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
61
+ """
62
+ if len(tensor.shape) != 4:
63
+ raise RuntimeError("The input tensor has to be 4d!")
64
+
65
+ if (
66
+ self.cached_penc is not None
67
+ and self.cached_penc.shape[:2] == tensor.shape[1:3]
68
+ ):
69
+ return self.cached_penc
70
+
71
+ self.cached_penc = None
72
+ batch_size, orig_ch, x, y = tensor.shape
73
+ pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
74
+ pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
75
+ sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
76
+ sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
77
+ emb_x = get_emb(sin_inp_x).unsqueeze(1)
78
+ emb_y = get_emb(sin_inp_y)
79
+ emb = torch.zeros(
80
+ (x, y, self.channels * 2),
81
+ device=tensor.device,
82
+ dtype=tensor.dtype,
83
+ )
84
+ emb[:, :, : self.channels] = emb_x
85
+ emb[:, :, self.channels : 2 * self.channels] = emb_y
86
+
87
+ self.cached_penc = emb
88
+
89
+ return self.cached_penc
90
+
91
+
92
+ class SpatialNormer(nn.Module):
93
+ def __init__(
94
+ self,
95
+ in_channels, # channels will be number of sigma scales in input
96
+ kernel_size=3,
97
+ stride=2,
98
+ padding=1,
99
+ ):
100
+ """
101
+ Note that the convolution will reduce the channel dimension
102
+ So (b, num_sigmas, c, h, w) -> (b, num_sigmas, new_h , new_w)
103
+ """
104
+ super().__init__()
105
+ self.conv = nn.Conv3d(
106
+ in_channels,
107
+ in_channels,
108
+ kernel_size,
109
+ # This is the real trick that ensures each
110
+ # sigma dimension is normed separately
111
+ groups=in_channels,
112
+ stride=(1, stride, stride),
113
+ padding=(0, padding, padding),
114
+ bias=False,
115
+ )
116
+ self.conv.weight.data.fill_(1) # all ones weights
117
+ self.conv.weight.requires_grad = False # freeze weights
118
+
119
+ @torch.no_grad()
120
+ def forward(self, x):
121
+ return self.conv(x.square()).pow_(0.5).squeeze(2)
122
+
123
+
124
+ class PatchFlow(torch.nn.Module):
125
+ def __init__(
126
+ self,
127
+ input_size,
128
+ patch_size=3,
129
+ context_embedding_size=128,
130
+ num_blocks=2,
131
+ hidden_units=128,
132
+ ):
133
+ super().__init__()
134
+ num_sigmas, c, h, w = input_size
135
+ self.local_pooler = SpatialNormer(
136
+ in_channels=num_sigmas, kernel_size=patch_size
137
+ )
138
+ self.flow = build_flows(
139
+ latent_size=num_sigmas, context_size=context_embedding_size
140
+ )
141
+ self.position_encoding = PositionalEncoding2D(channels=context_embedding_size)
142
+
143
+ # caching pos encs
144
+ _, _, ctx_h, ctw_w = self.local_pooler(
145
+ torch.empty((1, num_sigmas, c, h, w))
146
+ ).shape
147
+ self.position_encoding(torch.empty(1, 1, ctx_h, ctw_w))
148
+ assert self.position_encoding.cached_penc.shape[-1] == context_embedding_size
149
+
150
+ def init_weights(self):
151
+ # Initialize weights with Xavier
152
+ linear_modules = list(
153
+ filter(lambda m: isinstance(m, nn.Linear), self.flow.modules())
154
+ )
155
+ total = len(linear_modules)
156
+
157
+ for idx, m in enumerate(linear_modules):
158
+ # Last layer gets init w/ zeros
159
+ if idx == total - 1:
160
+ nn.init.zeros_(m.weight.data)
161
+ else:
162
+ nn.init.xavier_uniform_(m.weight.data)
163
+
164
+ if m.bias is not None:
165
+ nn.init.zeros_(m.bias.data)
166
+
167
+ def forward(self, x, chunk_size=32):
168
+ b, s, c, h, w = x.shape
169
+ x_norm = self.local_pooler(x)
170
+ _, _, new_h, new_w = x_norm.shape
171
+ context = self.position_encoding(x_norm)
172
+
173
+ # (Patches * batch) x channels
174
+ local_ctx = rearrange(context, "h w c -> (h w) c")
175
+ patches = rearrange(x_norm, "b c h w -> (h w) b c")
176
+
177
+ nchunks = (patches.shape[0] + chunk_size - 1) // chunk_size
178
+ patches = patches.chunk(nchunks, dim=0)
179
+ ctx_chunks = local_ctx.chunk(nchunks, dim=0)
180
+ patch_logpx = []
181
+
182
+ # gc = repeat(global_ctx, "b c -> (n b) c", n=self.patch_batch_size)
183
+
184
+ for p, ctx in zip(patches, ctx_chunks):
185
+
186
+ # num patches in chunk (<= chunk_size)
187
+ n = p.shape[0]
188
+ ctx = repeat(ctx, "n c -> (n b) c", b=b)
189
+ p = rearrange(p, "n b c -> (n b) c")
190
+
191
+ # Compute log densities for each patch
192
+ logpx = self.flow.log_prob(p, context=ctx)
193
+ logpx = rearrange(logpx, "(n b) -> n b", n=n, b=b)
194
+ patch_logpx.append(logpx)
195
+ # del ctx, p
196
+
197
+ # print(p[:4], ctx[:4], logpx)
198
+ # Convert back to image
199
+ logpx = torch.cat(patch_logpx, dim=0)
200
+ logpx = rearrange(logpx, "(h w) b -> b 1 h w", b=b, h=new_h, w=new_w)
201
+
202
+ return logpx.contiguous()
203
+
204
+ @staticmethod
205
+ def stochastic_step(
206
+ scores, x_batch, flow_model, opt=None, train=False, n_patches=32, device="cpu"
207
+ ):
208
+ if train:
209
+ flow_model.train()
210
+ opt.zero_grad(set_to_none=True)
211
+ else:
212
+ flow_model.eval()
213
+
214
+ patches, context = PatchFlow.get_random_patches(
215
+ scores, x_batch, flow_model, n_patches
216
+ )
217
+
218
+ patch_feature = patches.to(device)
219
+ context_vector = context.to(device)
220
+ patch_feature = rearrange(patch_feature, "n b c -> (n b) c")
221
+ context_vector = rearrange(context_vector, "n b c -> (n b) c")
222
+
223
+ # global_pooled_image = flow_model.global_pooler(x_batch)
224
+ # global_context = flow_model.global_attention(global_pooled_image)
225
+ # gctx = repeat(global_context, "b c -> (n b) c", n=n_patches)
226
+
227
+ # # Concatenate global context to local context
228
+ # context_vector = torch.cat([context_vector, gctx], dim=1)
229
+
230
+ z, ldj = flow_model.flow.inverse_and_log_det(
231
+ patch_feature,
232
+ context=context_vector,
233
+ )
234
+
235
+ loss = -torch.mean(flow_model.flow.q0.log_prob(z) + ldj)
236
+ loss *= n_patches
237
+
238
+ if train:
239
+ loss.backward()
240
+ opt.step()
241
+
242
+ return loss.item() / n_patches
243
+
244
+ @staticmethod
245
+ def get_random_patches(scores, x_batch, flow_model, n_patches):
246
+ b = scores.shape[0]
247
+ h = flow_model.local_pooler(scores)
248
+ patches = rearrange(h, "b c h w -> (h w) b c")
249
+
250
+ context = flow_model.position_encoding(h)
251
+ context = rearrange(context, "h w c -> (h w) c")
252
+ context = repeat(context, "n c -> n b c", b=b)
253
+
254
+ # conserve gpu memory
255
+ patches = patches.cpu()
256
+ context = context.cpu()
257
+
258
+ # Get random patches
259
+ total_patches = patches.shape[0]
260
+ shuffled_idx = torch.randperm(total_patches)
261
+ rand_idx_batch = shuffled_idx[:n_patches]
262
+
263
+ return patches[rand_idx_batch], context[rand_idx_batch]
scorer.py → msma.py RENAMED
@@ -1,5 +1,6 @@
1
  import os
2
  import pickle
 
3
  from pickle import dump, load
4
 
5
  import numpy as np
@@ -9,9 +10,12 @@ from sklearn.mixture import GaussianMixture
9
  from sklearn.model_selection import GridSearchCV
10
  from sklearn.pipeline import Pipeline
11
  from sklearn.preprocessing import StandardScaler
 
12
  from tqdm import tqdm
13
 
14
  import dnnlib
 
 
15
 
16
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
17
 
@@ -22,6 +26,17 @@ config_presets = {
22
  }
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
25
  class EDMScorer(torch.nn.Module):
26
  def __init__(
27
  self,
@@ -41,6 +56,7 @@ class EDMScorer(torch.nn.Module):
41
  self.sigma_max = sigma_max
42
  self.sigma_data = sigma_data
43
  self.net = net.eval()
 
44
 
45
  # Adjust noise levels based on how far we want to accumulate
46
  self.sigma_min = 1e-1
@@ -63,7 +79,7 @@ class EDMScorer(torch.nn.Module):
63
  x,
64
  force_fp32=False,
65
  ):
66
- x = x.to(torch.float32)
67
 
68
  batch_scores = []
69
  for sigma in self.sigma_steps:
@@ -76,6 +92,29 @@ class EDMScorer(torch.nn.Module):
76
  return batch_scores
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def build_model(preset="edm2-img64-s-fid", device="cpu"):
80
  netpath = config_presets[preset]
81
  with dnnlib.util.open_url(netpath, verbose=1) as f:
@@ -85,41 +124,45 @@ def build_model(preset="edm2-img64-s-fid", device="cpu"):
85
  return model
86
 
87
 
88
- def train_gmm(score_path, outdir):
89
- def quantile_scorer(gmm, X, y=None):
90
- return np.quantile(gmm.score_samples(X), 0.1)
91
 
92
- X = torch.load(score_path)
93
 
94
- gm = GaussianMixture(init_params="kmeans", covariance_type="full", max_iter=100000)
95
- clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
96
- clf.fit(X)
97
- inlier_nll = -clf.score_samples(X)
98
-
99
- param_grid = dict(
100
- GMM__n_components=range(2, 11, 2),
101
- )
102
 
103
- grid = GridSearchCV(
104
- estimator=clf,
105
- param_grid=param_grid,
106
- cv=10,
107
- n_jobs=2,
108
- verbose=1,
109
- scoring=quantile_scorer,
110
  )
111
 
112
- grid_result = grid.fit(X)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
115
- print("-----" * 15)
116
- means = grid_result.cv_results_["mean_test_score"]
117
- stds = grid_result.cv_results_["std_test_score"]
118
- params = grid_result.cv_results_["params"]
119
- for mean, stdev, param in zip(means, stds, params):
120
- print("%f (%f) with: %r" % (mean, stdev, param))
121
-
122
- clf = grid.best_estimator_
123
 
124
  os.makedirs(outdir, exist_ok=True)
125
  with open(f"{outdir}/refscores.npz", "wb") as f:
@@ -134,26 +177,14 @@ def compute_gmm_likelihood(x_score, gmmdir):
134
  clf = load(f)
135
  nll = -clf.score_samples(x_score)
136
 
137
- with np.load(f"{gmmdir}/refscores.npz", "wb") as f:
138
  ref_nll = f["arr_0"]
139
  percentile = (ref_nll < nll).mean()
140
 
141
  return nll, percentile
142
 
143
 
144
- def test_runner(device="cpu"):
145
- # f = "doge.jpg"
146
- f = "goldfish.JPEG"
147
- image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
148
- image = np.array(image)
149
- image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
150
- x = torch.from_numpy(image).unsqueeze(0).to(device)
151
- model = build_model(device=device)
152
- scores = model(x)
153
- return scores
154
-
155
-
156
- def runner(preset, dataset_path, device="cpu"):
157
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
158
  refimg, reflabel = dsobj[0]
159
  print(refimg.shape, refimg.dtype, reflabel)
@@ -178,19 +209,138 @@ def runner(preset, dataset_path, device="cpu"):
178
  print(f"Computed score norms for {score_norms.shape[0]} samples")
179
 
180
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  if __name__ == "__main__":
182
  device = "cuda" if torch.cuda.is_available() else "cpu"
183
  preset = "edm2-img64-s-fid"
184
- # runner(
 
 
 
 
 
185
  # preset=preset,
186
  # dataset_path="/GROND_STOR/amahmood/datasets/img64/",
187
  # device="cuda",
188
  # )
189
- train_gmm(
190
- f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
191
- )
192
- s = test_runner(device=device)
193
- s = s.square().sum(dim=(2, 3, 4)) ** 0.5
194
- s = s.to("cpu").numpy()
195
- nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}")
196
- print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")
 
1
  import os
2
  import pickle
3
+ from functools import partial
4
  from pickle import dump, load
5
 
6
  import numpy as np
 
10
  from sklearn.model_selection import GridSearchCV
11
  from sklearn.pipeline import Pipeline
12
  from sklearn.preprocessing import StandardScaler
13
+ from torch.utils.data import Subset
14
  from tqdm import tqdm
15
 
16
  import dnnlib
17
+ from dataset import ImageFolderDataset
18
+ from flowutils import PatchFlow
19
 
20
  model_root = "https://nvlabs-fi-cdn.nvidia.com/edm2/posthoc-reconstructions"
21
 
 
26
  }
27
 
28
 
29
+ class StandardRGBEncoder:
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def encode(self, x): # raw pixels => final pixels
34
+ return x.to(torch.float32) / 127.5 - 1
35
+
36
+ def decode(self, x): # final latents => raw pixels
37
+ return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)
38
+
39
+
40
  class EDMScorer(torch.nn.Module):
41
  def __init__(
42
  self,
 
56
  self.sigma_max = sigma_max
57
  self.sigma_data = sigma_data
58
  self.net = net.eval()
59
+ self.encoder = StandardRGBEncoder()
60
 
61
  # Adjust noise levels based on how far we want to accumulate
62
  self.sigma_min = 1e-1
 
79
  x,
80
  force_fp32=False,
81
  ):
82
+ x = self.encoder.encode(x).to(torch.float32)
83
 
84
  batch_scores = []
85
  for sigma in self.sigma_steps:
 
92
  return batch_scores
93
 
94
 
95
+ class ScoreFlow(torch.nn.Module):
96
+ def __init__(
97
+ self,
98
+ scorenet,
99
+ vectorize=False,
100
+ device="cpu",
101
+ ):
102
+ super().__init__()
103
+
104
+ h = w = scorenet.net.img_resolution
105
+ c = scorenet.net.img_channels
106
+ num_sigmas = len(scorenet.sigma_steps)
107
+ self.flow = PatchFlow((num_sigmas, c, h, w))
108
+
109
+ self.flow = self.flow.to(device)
110
+ self.scorenet = scorenet.to(device).requires_grad_(False)
111
+ self.flow.init_weights()
112
+
113
+ def forward(self, x, **score_kwargs):
114
+ x_scores = self.scorenet(x, **score_kwargs)
115
+ return self.flow(x_scores)
116
+
117
+
118
  def build_model(preset="edm2-img64-s-fid", device="cpu"):
119
  netpath = config_presets[preset]
120
  with dnnlib.util.open_url(netpath, verbose=1) as f:
 
124
  return model
125
 
126
 
127
+ def quantile_scorer(gmm, X, y=None):
128
+ return np.quantile(gmm.score_samples(X), 0.1)
 
129
 
 
130
 
131
+ def train_gmm(score_path, outdir, grid_search=False):
132
+ X = torch.load(score_path)
 
 
 
 
 
 
133
 
134
+ gm = GaussianMixture(
135
+ n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
 
 
 
 
 
136
  )
137
 
138
+ if grid_search:
139
+ clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
140
+ param_grid = dict(
141
+ GMM__n_components=range(2, 11, 1),
142
+ )
143
+
144
+ grid = GridSearchCV(
145
+ estimator=clf,
146
+ param_grid=param_grid,
147
+ cv=5,
148
+ n_jobs=2,
149
+ verbose=1,
150
+ scoring=quantile_scorer,
151
+ )
152
+
153
+ grid_result = grid.fit(X)
154
+
155
+ print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
156
+ print("-----" * 15)
157
+ means = grid_result.cv_results_["mean_test_score"]
158
+ stds = grid_result.cv_results_["std_test_score"]
159
+ params = grid_result.cv_results_["params"]
160
+ for mean, stdev, param in zip(means, stds, params):
161
+ print("%f (%f) with: %r" % (mean, stdev, param))
162
+ clf = grid.best_estimator_
163
 
164
+ clf.fit(X)
165
+ inlier_nll = -clf.score_samples(X)
 
 
 
 
 
 
 
166
 
167
  os.makedirs(outdir, exist_ok=True)
168
  with open(f"{outdir}/refscores.npz", "wb") as f:
 
177
  clf = load(f)
178
  nll = -clf.score_samples(x_score)
179
 
180
+ with np.load(f"{gmmdir}/refscores.npz", "rb") as f:
181
  ref_nll = f["arr_0"]
182
  percentile = (ref_nll < nll).mean()
183
 
184
  return nll, percentile
185
 
186
 
187
+ def cache_score_norms(preset, dataset_path, device="cpu"):
 
 
 
 
 
 
 
 
 
 
 
 
188
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
189
  refimg, reflabel = dsobj[0]
190
  print(refimg.shape, refimg.dtype, reflabel)
 
209
  print(f"Computed score norms for {score_norms.shape[0]} samples")
210
 
211
 
212
+ def train_flow(dataset_path, preset, device="cuda"):
213
+ dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
214
+ refimg, reflabel = dsobj[0]
215
+ print(f"Loaded {len(dsobj)} samples from {dataset_path}")
216
+
217
+ # Subset of training dataset
218
+ val_ratio = 0.1
219
+ train_len = int((1 - val_ratio) * len(dsobj))
220
+ val_len = len(dsobj) - train_len
221
+
222
+ print(
223
+ f"Generating train/test split with ratio={val_ratio} -> {train_len}/{val_len}..."
224
+ )
225
+ train_ds = Subset(dsobj, range(train_len))
226
+ val_ds = Subset(dsobj, range(train_len, train_len + val_len))
227
+
228
+ trainiter = torch.utils.data.DataLoader(
229
+ train_ds, batch_size=48, num_workers=4, prefetch_factor=2
230
+ )
231
+ testiter = torch.utils.data.DataLoader(
232
+ val_ds, batch_size=48, num_workers=4, prefetch_factor=2
233
+ )
234
+
235
+ model = ScoreFlow(build_model(preset=preset), device=device)
236
+ opt = torch.optim.AdamW(model.flow.parameters(), lr=3e-4, weight_decay=1e-5)
237
+ train_step = partial(
238
+ PatchFlow.stochastic_step,
239
+ flow_model=model.flow,
240
+ opt=opt,
241
+ train=True,
242
+ n_patches=64,
243
+ device=device,
244
+ )
245
+ eval_step = partial(
246
+ PatchFlow.stochastic_step,
247
+ flow_model=model.flow,
248
+ train=False,
249
+ n_patches=128,
250
+ device=device,
251
+ )
252
+
253
+ pbar = tqdm(trainiter, desc="Train Loss: ? - Val Loss: ?")
254
+ step = 0
255
+
256
+ for x, _ in tqdm(trainiter):
257
+ x = x.to(device)
258
+ scores = model.scorenet(x)
259
+
260
+ if step == 0:
261
+ with torch.inference_mode():
262
+ val_loss = eval_step(scores, x)
263
+
264
+ train_loss = train_step(scores, x)
265
+
266
+ if (step + 1) % 10 == 0:
267
+
268
+ with torch.inference_mode():
269
+ val_loss = 0.0
270
+ for i, (x, _) in enumerate(testiter):
271
+ x = x.to(device)
272
+ scores = model.scorenet(x)
273
+ val_loss += eval_step(scores, x)
274
+ break
275
+ val_loss /= i + 1
276
+
277
+ pbar.set_description(
278
+ f"Step: {step:d} - Train: {train_loss:.3f} - Val: {val_loss:.3f}"
279
+ )
280
+ step += 1
281
+
282
+ torch.save(model.flow.state_dict(), f"out/msma/{preset}/flow.pt")
283
+
284
+
285
+ @torch.inference_mode
286
+ def test_runner(device="cpu"):
287
+ # f = "doge.jpg"
288
+ f = "goldfish.JPEG"
289
+ image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
290
+ image = np.array(image)
291
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
292
+ x = torch.from_numpy(image).unsqueeze(0).to(device)
293
+ model = build_model(device=device)
294
+ scores = model(x)
295
+
296
+ return scores
297
+
298
+
299
+ def test_flow_runner(device="cpu", load_weights=None):
300
+ f = "doge.jpg"
301
+ # f = "goldfish.JPEG"
302
+ image = (PIL.Image.open(f)).resize((64, 64), PIL.Image.Resampling.LANCZOS)
303
+ image = np.array(image)
304
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
305
+ x = torch.from_numpy(image).unsqueeze(0).to(device)
306
+ model = build_model(device=device)
307
+
308
+ score_flow = ScoreFlow(scorenet=model, device=device)
309
+
310
+ if load_weights is not None:
311
+ score_flow.flow.load_state_dict(torch.load(load_weights))
312
+
313
+ heatmap = score_flow(x)
314
+ print(heatmap.shape)
315
+
316
+ heatmap = score_flow(x).detach().cpu().numpy()
317
+ heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) * 255
318
+ im = PIL.Image.fromarray(heatmap[0, 0])
319
+ im.convert("RGB").save(
320
+ "heatmap.png",
321
+ )
322
+
323
+ return
324
+
325
+
326
  if __name__ == "__main__":
327
  device = "cuda" if torch.cuda.is_available() else "cpu"
328
  preset = "edm2-img64-s-fid"
329
+ imagenette_path = "/GROND_STOR/amahmood/datasets/img64/"
330
+
331
+ train_flow(imagenette_path, preset, device)
332
+ test_flow_runner("cuda", f"out/msma/{preset}/flow.pt")
333
+
334
+ # cache_score_norms(
335
  # preset=preset,
336
  # dataset_path="/GROND_STOR/amahmood/datasets/img64/",
337
  # device="cuda",
338
  # )
339
+ # train_gmm(
340
+ # f"out/msma/{preset}_imagenette_score_norms.pt", outdir=f"out/msma/{preset}"
341
+ # )
342
+ # s = test_runner(device=device)
343
+ # s = s.square().sum(dim=(2, 3, 4)) ** 0.5
344
+ # s = s.to("cpu").numpy()
345
+ # nll, pct = compute_gmm_likelihood(s, gmmdir=f"out/msma/{preset}/")
346
+ # print(f"Anomaly score for image: {nll[0]:.3f} @ {pct*100:.2f} percentile")