Spaces:
Runtime error
Runtime error
uploading gmm and score norms
Browse files- push_to_hf.py +12 -6
push_to_hf.py
CHANGED
@@ -26,9 +26,13 @@ def main(basedir, preset):
|
|
26 |
modeldir = basedir / preset
|
27 |
|
28 |
net = build_model_from_pickle(preset)
|
|
|
|
|
|
|
29 |
model = ScoreFlow(
|
30 |
net,
|
31 |
-
|
|
|
32 |
)
|
33 |
model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
|
34 |
|
@@ -46,10 +50,13 @@ def main(basedir, preset):
|
|
46 |
save_file(model.state_dict(), tmpdir / "model.safetensors")
|
47 |
|
48 |
# save config
|
49 |
-
(tmpdir / "config.json").write_text(
|
50 |
-
|
51 |
-
|
52 |
|
|
|
|
|
|
|
53 |
|
54 |
# Generate model card
|
55 |
# card = generate_model_card(model)
|
@@ -57,8 +64,7 @@ def main(basedir, preset):
|
|
57 |
|
58 |
# Save logs
|
59 |
shutil.copytree(modeldir / "logs", tmpdir / "logs")
|
60 |
-
|
61 |
-
|
62 |
# Push to hub
|
63 |
api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
|
64 |
|
|
|
26 |
modeldir = basedir / preset
|
27 |
|
28 |
net = build_model_from_pickle(preset)
|
29 |
+
with open(modeldir / "config.json", "rb") as f:
|
30 |
+
model_params = json.load(f)
|
31 |
+
|
32 |
model = ScoreFlow(
|
33 |
net,
|
34 |
+
device="cpu",
|
35 |
+
**model_params["PatchFlow"],
|
36 |
)
|
37 |
model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
|
38 |
|
|
|
50 |
save_file(model.state_dict(), tmpdir / "model.safetensors")
|
51 |
|
52 |
# save config
|
53 |
+
(tmpdir / "config.json").write_text(
|
54 |
+
json.dumps(model.config, sort_keys=True, indent=4)
|
55 |
+
)
|
56 |
|
57 |
+
# save gmm and cached score norms
|
58 |
+
shutil.copyfile(modeldir / "gmm.pkl", tmpdir / "gmm.pkl")
|
59 |
+
shutil.copyfile(modeldir / "refscores.npz", tmpdir / "refscores.npz")
|
60 |
|
61 |
# Generate model card
|
62 |
# card = generate_model_card(model)
|
|
|
64 |
|
65 |
# Save logs
|
66 |
shutil.copytree(modeldir / "logs", tmpdir / "logs")
|
67 |
+
|
|
|
68 |
# Push to hub
|
69 |
api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
|
70 |
|