ahsanMah committed on
Commit
7358cfe
·
1 Parent(s): 1dcd56b

consolidated cmdline args

Browse files
Files changed (1) hide show
  1. msma.py +72 -83
msma.py CHANGED
@@ -2,7 +2,7 @@ import datetime
2
  import json
3
  import os
4
  import pickle
5
- from functools import partial
6
  from pickle import dump, load
7
  from typing import Literal
8
 
@@ -135,50 +135,6 @@ def quantile_scorer(gmm, X, y=None):
135
  return np.quantile(gmm.score_samples(X), 0.1)
136
 
137
 
138
- def train_gmm(score_path, outdir, grid_search=False):
139
- X = torch.load(score_path)
140
-
141
- gm = GaussianMixture(
142
- n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
143
- )
144
- clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
145
-
146
- if grid_search:
147
- param_grid = dict(
148
- GMM__n_components=range(2, 11, 1),
149
- )
150
-
151
- grid = GridSearchCV(
152
- estimator=clf,
153
- param_grid=param_grid,
154
- cv=5,
155
- n_jobs=2,
156
- verbose=1,
157
- scoring=quantile_scorer,
158
- )
159
-
160
- grid_result = grid.fit(X)
161
-
162
- print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
163
- print("-----" * 15)
164
- means = grid_result.cv_results_["mean_test_score"]
165
- stds = grid_result.cv_results_["std_test_score"]
166
- params = grid_result.cv_results_["params"]
167
- for mean, stdev, param in zip(means, stds, params):
168
- print("%f (%f) with: %r" % (mean, stdev, param))
169
- clf = grid.best_estimator_
170
-
171
- clf.fit(X)
172
- inlier_nll = -clf.score_samples(X)
173
-
174
- os.makedirs(outdir, exist_ok=True)
175
- with open(f"{outdir}/refscores.npz", "wb") as f:
176
- np.savez_compressed(f, inlier_nll)
177
-
178
- with open(f"{outdir}/gmm.pkl", "wb") as f:
179
- dump(clf, f, protocol=5)
180
-
181
-
182
  def compute_gmm_likelihood(x_score, gmmdir):
183
  with open(f"{gmmdir}/gmm.pkl", "rb") as f:
184
  clf = load(f)
@@ -237,8 +193,9 @@ def cmdline():
237
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
238
 
239
 
240
- @cmdline.command(name="cache-scores")
241
- @click.option(
 
242
  "--preset",
243
  help="Configuration preset",
244
  metavar="STR",
@@ -246,20 +203,73 @@ def cmdline():
246
  default="edm2-img64-s-fid",
247
  show_default=True,
248
  )
249
- @click.option(
250
- "--dataset_path",
251
- help="Path to the dataset",
252
- metavar="ZIP|DIR",
253
- type=str,
254
- default=None,
255
- )
256
- @click.option(
257
- "--outdir",
258
- help="Where to load/save the results",
259
- metavar="DIR",
260
- type=str,
261
- required=True,
262
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  def cache_score_norms(preset, dataset_path, outdir):
264
  device = DEVICE
265
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
@@ -290,28 +300,6 @@ def cache_score_norms(preset, dataset_path, outdir):
290
 
291
 
292
  @cmdline.command(name="train-flow")
293
- @click.option(
294
- "--dataset_path",
295
- help="Path to the dataset",
296
- metavar="ZIP|DIR",
297
- type=str,
298
- default=None,
299
- )
300
- @click.option(
301
- "--outdir",
302
- help="Where to load/save the results",
303
- metavar="DIR",
304
- type=str,
305
- required=True,
306
- )
307
- @click.option(
308
- "--preset",
309
- help="Configuration preset",
310
- metavar="STR",
311
- type=str,
312
- default="edm2-img64-s-fid",
313
- show_default=True,
314
- )
315
  @click.option(
316
  "--epochs",
317
  help="Number of epochs",
@@ -328,6 +316,7 @@ def cache_score_norms(preset, dataset_path, outdir):
328
  default=4,
329
  show_default=True,
330
  )
 
331
  def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
332
  print("using device:", DEVICE)
333
  device = DEVICE
 
2
  import json
3
  import os
4
  import pickle
5
+ from functools import partial, wraps
6
  from pickle import dump, load
7
  from typing import Literal
8
 
 
135
  return np.quantile(gmm.score_samples(X), 0.1)
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def compute_gmm_likelihood(x_score, gmmdir):
139
  with open(f"{gmmdir}/gmm.pkl", "rb") as f:
140
  clf = load(f)
 
193
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
194
 
195
 
196
+ def common_args(func):
197
+ @wraps(func)
198
+ @click.option(
199
  "--preset",
200
  help="Configuration preset",
201
  metavar="STR",
 
203
  default="edm2-img64-s-fid",
204
  show_default=True,
205
  )
206
+ @click.option(
207
+ "--dataset_path",
208
+ help="Path to the dataset",
209
+ metavar="ZIP|DIR",
210
+ type=str,
211
+ default=None,
212
+ )
213
+ @click.option(
214
+ "--outdir",
215
+ help="Where to load/save the results",
216
+ metavar="DIR",
217
+ type=str,
218
+ required=True,
219
+ )
220
+ def wrapper(*args, **kwargs):
221
+ return func(*args, **kwargs)
222
+
223
+ return wrapper
224
+
225
+ @cmdline.command('train-gmm')
226
+ @common_args
227
+ def train_gmm(score_path, outdir, grid_search=False):
228
+ X = torch.load(score_path)
229
+
230
+ gm = GaussianMixture(
231
+ n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
232
+ )
233
+ clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
234
+
235
+ if grid_search:
236
+ param_grid = dict(
237
+ GMM__n_components=range(2, 11, 1),
238
+ )
239
+
240
+ grid = GridSearchCV(
241
+ estimator=clf,
242
+ param_grid=param_grid,
243
+ cv=5,
244
+ n_jobs=2,
245
+ verbose=1,
246
+ scoring=quantile_scorer,
247
+ )
248
+
249
+ grid_result = grid.fit(X)
250
+
251
+ print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
252
+ print("-----" * 15)
253
+ means = grid_result.cv_results_["mean_test_score"]
254
+ stds = grid_result.cv_results_["std_test_score"]
255
+ params = grid_result.cv_results_["params"]
256
+ for mean, stdev, param in zip(means, stds, params):
257
+ print("%f (%f) with: %r" % (mean, stdev, param))
258
+ clf = grid.best_estimator_
259
+
260
+ clf.fit(X)
261
+ inlier_nll = -clf.score_samples(X)
262
+
263
+ os.makedirs(outdir, exist_ok=True)
264
+ with open(f"{outdir}/refscores.npz", "wb") as f:
265
+ np.savez_compressed(f, inlier_nll)
266
+
267
+ with open(f"{outdir}/gmm.pkl", "wb") as f:
268
+ dump(clf, f, protocol=5)
269
+
270
+
271
+ @cmdline.command(name="cache-scores")
272
+ @common_args
273
  def cache_score_norms(preset, dataset_path, outdir):
274
  device = DEVICE
275
  dsobj = ImageFolderDataset(path=dataset_path, resolution=64)
 
300
 
301
 
302
  @cmdline.command(name="train-flow")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  @click.option(
304
  "--epochs",
305
  help="Number of epochs",
 
316
  default=4,
317
  show_default=True,
318
  )
319
+ @common_args
320
  def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
321
  print("using device:", DEVICE)
322
  device = DEVICE