Spaces:
Runtime error
Runtime error
utility script for pushing to HF hub
Browse files- push_to_hf.py +50 -25
push_to_hf.py
CHANGED
@@ -1,42 +1,67 @@
|
|
|
|
1 |
import shutil
|
2 |
from pathlib import Path
|
3 |
from tempfile import TemporaryDirectory
|
4 |
|
|
|
5 |
import torch
|
6 |
from huggingface_hub import HfApi
|
7 |
from safetensors.torch import save_file
|
8 |
|
9 |
-
from msma import ScoreFlow
|
10 |
|
11 |
-
basedir = Path("models/condgauss")
|
12 |
-
preset = "edm2-img64-s-fid"
|
13 |
-
modeldir = basedir / preset
|
14 |
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
|
23 |
|
24 |
-
#
|
25 |
-
|
26 |
-
tmpdir = Path(tmpdir)
|
27 |
|
28 |
-
# Save
|
29 |
-
|
|
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
# (tmpdir / "README.md").write_text(card)
|
34 |
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
# ...
|
40 |
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
import shutil
|
3 |
from pathlib import Path
|
4 |
from tempfile import TemporaryDirectory
|
5 |
|
6 |
+
import click
|
7 |
import torch
|
8 |
from huggingface_hub import HfApi
|
9 |
from safetensors.torch import save_file
|
10 |
|
11 |
+
from msma import EDMScorer, ScoreFlow, build_model_from_pickle
|
12 |
|
|
|
|
|
|
|
13 |
|
14 |
+
@click.command
|
15 |
+
@click.option(
|
16 |
+
"--basedir",
|
17 |
+
help="Directory holding the model weights and logs",
|
18 |
+
type=str,
|
19 |
+
required=True,
|
20 |
+
)
|
21 |
+
@click.option(
|
22 |
+
"--preset", help="Preset of the score model used", type=str, required=True
|
23 |
+
)
|
24 |
+
def main(basedir, preset):
|
25 |
+
basedir = Path(basedir)
|
26 |
+
modeldir = basedir / preset
|
27 |
|
28 |
+
net = build_model_from_pickle(preset)
|
29 |
+
model = ScoreFlow(
|
30 |
+
net,
|
31 |
+
num_flows=8,
|
32 |
+
)
|
33 |
+
model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
|
34 |
|
35 |
+
api = HfApi()
|
36 |
+
repo_name = "ahsanMah/localizing-edm"
|
37 |
|
38 |
+
# Create repo if not existing yet and get the associated repo_id
|
39 |
+
repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
|
|
|
40 |
|
41 |
+
# Save all files in a temporary directory and push them in a single commit
|
42 |
+
with TemporaryDirectory() as tmpdir:
|
43 |
+
tmpdir = Path(tmpdir)
|
44 |
|
45 |
+
# Save weights
|
46 |
+
save_file(model.state_dict(), tmpdir / "model.safetensors")
|
|
|
47 |
|
48 |
+
# save config
|
49 |
+
(tmpdir / "config.json").write_text(json.dumps(model.config, sort_keys=True, indent=4))
|
50 |
+
|
51 |
+
# TODO: save gmm and cached score norms
|
|
|
52 |
|
53 |
+
|
54 |
+
# Generate model card
|
55 |
+
# card = generate_model_card(model)
|
56 |
+
# (tmpdir / "README.md").write_text(card)
|
57 |
+
|
58 |
+
# Save logs
|
59 |
+
shutil.copytree(modeldir / "logs", tmpdir / "logs")
|
60 |
+
|
61 |
+
|
62 |
+
# Push to hub
|
63 |
+
api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
main()
|