Spaces:
Runtime error
Runtime error
fixed gridsearch arg
Browse files
msma.py
CHANGED
@@ -223,16 +223,23 @@ def common_args(func):
|
|
223 |
return wrapper
|
224 |
|
225 |
@cmdline.command('train-gmm')
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
@common_args
|
227 |
-
def train_gmm(
|
228 |
-
|
|
|
229 |
|
230 |
gm = GaussianMixture(
|
231 |
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
232 |
)
|
233 |
clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
|
234 |
|
235 |
-
if
|
236 |
param_grid = dict(
|
237 |
GMM__n_components=range(2, 11, 1),
|
238 |
)
|
@@ -369,6 +376,9 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
369 |
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
370 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
371 |
|
|
|
|
|
|
|
372 |
# totaliters = int(epochs * train_len)
|
373 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
374 |
step = 0
|
@@ -433,8 +443,6 @@ def train_flow(dataset_path, preset, outdir, epochs, **flow_kwargs):
|
|
433 |
|
434 |
# Save final model
|
435 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
436 |
-
with open(f"{experiment_dir}/config.json", "w") as f:
|
437 |
-
json.dump(model.config, f, sort_keys=True, indent=4)
|
438 |
|
439 |
writer.close()
|
440 |
|
|
|
223 |
return wrapper
|
224 |
|
225 |
@cmdline.command('train-gmm')
|
226 |
+
@click.option(
|
227 |
+
"--gridsearch",
|
228 |
+
help="Whether to use a grid search on a number of components to find the best fit",
|
229 |
+
is_flag=True,
|
230 |
+
default=False,
|
231 |
+
)
|
232 |
@common_args
|
233 |
+
def train_gmm(preset, outdir, gridsearch=False, **kwargs):
|
234 |
+
score_path = f"{outdir}/{preset}/imagenette_score_norms.pt"
|
235 |
+
X = torch.load(score_path).numpy()
|
236 |
|
237 |
gm = GaussianMixture(
|
238 |
n_components=7, init_params="kmeans", covariance_type="full", max_iter=100000
|
239 |
)
|
240 |
clf = Pipeline([("scaler", StandardScaler()), ("GMM", gm)])
|
241 |
|
242 |
+
if gridsearch:
|
243 |
param_grid = dict(
|
244 |
GMM__n_components=range(2, 11, 1),
|
245 |
)
|
|
|
376 |
with open(f"{experiment_dir}/logs/{timestamp}/config.json", "w") as f:
|
377 |
json.dump(model.config, f, sort_keys=True, indent=4)
|
378 |
|
379 |
+
with open(f"{experiment_dir}/config.json", "w") as f:
|
380 |
+
json.dump(model.config, f, sort_keys=True, indent=4)
|
381 |
+
|
382 |
# totaliters = int(epochs * train_len)
|
383 |
pbar = tqdm(range(epochs), desc="Train Loss: ? - Val Loss: ?")
|
384 |
step = 0
|
|
|
443 |
|
444 |
# Save final model
|
445 |
torch.save(model.flow.state_dict(), f"{experiment_dir}/flow.pt")
|
|
|
|
|
446 |
|
447 |
writer.close()
|
448 |
|