ahsanMah commited on
Commit
bddc1f1
·
1 Parent(s): ffaef20

uploading gmm and score norms

Browse files
Files changed (1) hide show
  1. push_to_hf.py +12 -6
push_to_hf.py CHANGED
@@ -26,9 +26,13 @@ def main(basedir, preset):
26
  modeldir = basedir / preset
27
 
28
  net = build_model_from_pickle(preset)
 
 
 
29
  model = ScoreFlow(
30
  net,
31
- num_flows=8,
 
32
  )
33
  model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
34
 
@@ -46,10 +50,13 @@ def main(basedir, preset):
46
  save_file(model.state_dict(), tmpdir / "model.safetensors")
47
 
48
  # save config
49
- (tmpdir / "config.json").write_text(json.dumps(model.config, sort_keys=True, indent=4))
50
-
51
- # TODO: save gmm and cached score norms
52
 
 
 
 
53
 
54
  # Generate model card
55
  # card = generate_model_card(model)
@@ -57,8 +64,7 @@ def main(basedir, preset):
57
 
58
  # Save logs
59
  shutil.copytree(modeldir / "logs", tmpdir / "logs")
60
-
61
-
62
  # Push to hub
63
  api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
64
 
 
26
  modeldir = basedir / preset
27
 
28
  net = build_model_from_pickle(preset)
29
+ with open(modeldir / "config.json", "rb") as f:
30
+ model_params = json.load(f)
31
+
32
  model = ScoreFlow(
33
  net,
34
+ device="cpu",
35
+ **model_params["PatchFlow"],
36
  )
37
  model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
38
 
 
50
  save_file(model.state_dict(), tmpdir / "model.safetensors")
51
 
52
  # save config
53
+ (tmpdir / "config.json").write_text(
54
+ json.dumps(model.config, sort_keys=True, indent=4)
55
+ )
56
 
57
+ # save gmm and cached score norms
58
+ shutil.copyfile(modeldir / "gmm.pkl", tmpdir / "gmm.pkl")
59
+ shutil.copyfile(modeldir / "refscores.npz", tmpdir / "refscores.npz")
60
 
61
  # Generate model card
62
  # card = generate_model_card(model)
 
64
 
65
  # Save logs
66
  shutil.copytree(modeldir / "logs", tmpdir / "logs")
67
+
 
68
  # Push to hub
69
  api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
70