ahsanMah commited on
Commit
8933ee4
·
1 Parent(s): 6e569fe

utility script for pushing to HF hub

Browse files
Files changed (1) hide show
  1. push_to_hf.py +50 -25
push_to_hf.py CHANGED
@@ -1,42 +1,67 @@
 
1
  import shutil
2
  from pathlib import Path
3
  from tempfile import TemporaryDirectory
4
 
 
5
  import torch
6
  from huggingface_hub import HfApi
7
  from safetensors.torch import save_file
8
 
9
- from msma import ScoreFlow
10
 
11
- basedir = Path("models/condgauss")
12
- preset = "edm2-img64-s-fid"
13
- modeldir = basedir / preset
14
 
15
- model = ScoreFlow(preset)
16
- model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- api = HfApi()
19
- repo_name = "ahsanMah/localizing-edm"
 
 
 
 
20
 
21
- # Create repo if not existing yet and get the associated repo_id
22
- repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
23
 
24
- # Save all files in a temporary directory and push them in a single commit
25
- with TemporaryDirectory() as tmpdir:
26
- tmpdir = Path(tmpdir)
27
 
28
- # Save weights
29
- save_file(model.state_dict(), tmpdir / "model.safetensors")
 
30
 
31
- # Generate model card
32
- # card = generate_model_card(model)
33
- # (tmpdir / "README.md").write_text(card)
34
 
35
- # Save logs
36
- shutil.copytree(modeldir / "logs", tmpdir / "logs")
37
- # Save figures
38
- # Save evaluation metrics
39
- # ...
40
 
41
- # Push to hub
42
- api.upload_folder(repo_id=repo_id, folder_path=tmpdir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import shutil
3
  from pathlib import Path
4
  from tempfile import TemporaryDirectory
5
 
6
+ import click
7
  import torch
8
  from huggingface_hub import HfApi
9
  from safetensors.torch import save_file
10
 
11
+ from msma import EDMScorer, ScoreFlow, build_model_from_pickle
12
 
 
 
 
13
 
14
+ @click.command
15
+ @click.option(
16
+ "--basedir",
17
+ help="Directory holding the model weights and logs",
18
+ type=str,
19
+ required=True,
20
+ )
21
+ @click.option(
22
+ "--preset", help="Preset of the score model used", type=str, required=True
23
+ )
24
+ def main(basedir, preset):
25
+ basedir = Path(basedir)
26
+ modeldir = basedir / preset
27
 
28
+ net = build_model_from_pickle(preset)
29
+ model = ScoreFlow(
30
+ net,
31
+ num_flows=8,
32
+ )
33
+ model.flow.load_state_dict(torch.load(modeldir / "flow.pt"))
34
 
35
+ api = HfApi()
36
+ repo_name = "ahsanMah/localizing-edm"
37
 
38
+ # Create repo if not existing yet and get the associated repo_id
39
+ repo_id = api.create_repo(repo_name, exist_ok=True).repo_id
 
40
 
41
+ # Save all files in a temporary directory and push them in a single commit
42
+ with TemporaryDirectory() as tmpdir:
43
+ tmpdir = Path(tmpdir)
44
 
45
+ # Save weights
46
+ save_file(model.state_dict(), tmpdir / "model.safetensors")
 
47
 
48
+ # save config
49
+ (tmpdir / "config.json").write_text(json.dumps(model.config, sort_keys=True, indent=4))
50
+
51
+ # TODO: save gmm and cached score norms
 
52
 
53
+
54
+ # Generate model card
55
+ # card = generate_model_card(model)
56
+ # (tmpdir / "README.md").write_text(card)
57
+
58
+ # Save logs
59
+ shutil.copytree(modeldir / "logs", tmpdir / "logs")
60
+
61
+
62
+ # Push to hub
63
+ api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
64
+
65
+
66
+ if __name__ == "__main__":
67
+ main()