ahsanMah commited on
Commit
850e111
·
1 Parent(s): bddc1f1

fixed gridsearch arg

Browse files
Files changed (1) hide show
  1. msma.py +13 -5
msma.py CHANGED
@@ -223,16 +223,23 @@ def common_args(func):
223
  return wrapper
224
 
225
  @cmdline.command('train-gmm')
 
 
 
 
 
 
226
  @common_args
227
- def train_gmm(score_path, outdir, grid_search=False):
228
- X = torch.load(score_path)
 
229
 
230
  gm = GaussianMixture(
231
  n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
232
  )
233
  clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
234
 
235
- if grid_search:
236
  param_grid = dict(
237
  GMM__n_components=range(2, 11, 1),
238
  )
@@ -369,6 +376,9 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
369
  with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
370
  json.dump(model.config, f, sort_keys=True, indent=4)
371
 
 
 
 
372
  # totaliters = int(epochs * train_len)
373
  pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
374
  step = 0
@@ -433,8 +443,6 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
433
 
434
  # Save final model
435
  torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
436
- with open(f"{experiment_dir}/config.json", "w") as f:
437
- json.dump(model.config, f, sort_keys=True, indent=4)
438
 
439
  writer.close()
440
 
 
223
  return wrapper
224
 
225
  @cmdline.command('train-gmm')
226
+ @click.option(
227
+ "--gridsearch",
228
+ help="Whether to use a grid search on a number of components to find the best fit",
229
+ is_flag=True,
230
+ default=False,
231
+ )
232
  @common_args
233
+ def train_gmm(preset, outdir, gridsearch=False, **kwargs):
234
+ score_path = f"{outdir}/{preset}/imagenette_score_norms.pt"
235
+ X = torch.load(score_path).numpy()
236
 
237
  gm = GaussianMixture(
238
  n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
239
  )
240
  clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
241
 
242
+ if gridsearch:
243
  param_grid = dict(
244
  GMM__n_components=range(2, 11, 1),
245
  )
 
376
  with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
377
  json.dump(model.config, f, sort_keys=True, indent=4)
378
 
379
+ with open(f"{experiment_dir}/config.json", "w") as f:
380
+ json.dump(model.config, f, sort_keys=True, indent=4)
381
+
382
  # totaliters = int(epochs * train_len)
383
  pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
384
  step = 0
 
443
 
444
  # Save final model
445
  torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
 
 
446
 
447
  writer.close()
448