File size: 1,997 Bytes
8933ee4
1b96548
 
 
 
8933ee4
1b96548
 
 
 
8933ee4
1b96548
 
8933ee4
 
 
 
 
 
 
 
 
 
 
 
 
1b96548
8933ee4
bddc1f1
 
 
8933ee4
 
bddc1f1
 
8933ee4
 
1b96548
8933ee4
 
1b96548
8933ee4
 
1b96548
8933ee4
 
 
1b96548
8933ee4
 
1b96548
8933ee4
bddc1f1
 
 
1b96548
bddc1f1
 
 
8933ee4
 
 
 
 
 
 
bddc1f1
8933ee4
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file

from msma import EDMScorer, ScoreFlow, build_model_from_pickle


def _check_model_dir(modeldir):
    """Raise ``click.ClickException`` if any artifact needed for the upload is missing.

    Runs before any model building or remote call so a bad ``--basedir`` /
    ``--preset`` does not leave a half-created Hub repository behind.
    """
    required_files = ("config.json", "flow.pt", "gmm.pkl", "refscores.npz")
    required_dirs = ("logs",)
    missing = [name for name in required_files if not (modeldir / name).is_file()]
    missing += [name for name in required_dirs if not (modeldir / name).is_dir()]
    if missing:
        raise click.ClickException(
            f"Missing required artifacts in {modeldir}: {', '.join(missing)}"
        )


def _load_score_flow(modeldir, preset):
    """Rebuild the ScoreFlow for ``preset`` on CPU and load its trained flow weights."""
    net = build_model_from_pickle(preset)
    with open(modeldir / "config.json", "rb") as f:
        model_params = json.load(f)

    model = ScoreFlow(
        net,
        device="cpu",
        **model_params["PatchFlow"],
    )
    # map_location matches the device the ScoreFlow was built on and keeps
    # checkpoints saved from a GPU loadable on CPU-only machines.
    model.flow.load_state_dict(torch.load(modeldir / "flow.pt", map_location="cpu"))
    return model


def _stage_artifacts(model, modeldir, dest):
    """Write weights/config and copy the GMM, cached score norms and logs into ``dest``."""
    # Weights in safetensors format.
    save_file(model.state_dict(), str(dest / "model.safetensors"))

    # Model config, written deterministically (sorted keys).
    (dest / "config.json").write_text(
        json.dumps(model.config, sort_keys=True, indent=4), encoding="utf-8"
    )

    # GMM and cached reference score norms, copied verbatim.
    for name in ("gmm.pkl", "refscores.npz"):
        shutil.copyfile(modeldir / name, dest / name)

    # Training logs.
    shutil.copytree(modeldir / "logs", dest / "logs")

    # TODO: add a generated model card (README.md) once generate_model_card exists.


@click.command
@click.option(
    "--basedir",
    help="Directory holding the model weights and logs",
    type=str,
    required=True,
)
@click.option(
    "--preset", help="Preset of the score model used", type=str, required=True
)
@click.option(
    "--repo-name",
    "repo_name",
    help="Hugging Face Hub repository to push to",
    type=str,
    default="ahsanMah/localizing-edm",
    show_default=True,
)
def main(basedir, preset, repo_name="ahsanMah/localizing-edm"):
    """Push a trained ScoreFlow and its artifacts to the Hugging Face Hub.

    Artifacts are read from ``<basedir>/<preset>``. They are uploaded under the
    ``<preset>/`` folder of ``repo_name``: weights as safetensors, config,
    GMM, reference scores and logs.
    """
    modeldir = Path(basedir) / preset

    # Fail fast, before model building and before any remote side effects.
    _check_model_dir(modeldir)

    model = _load_score_flow(modeldir, preset)

    api = HfApi()
    # Create repo if not existing yet and get the associated repo_id
    repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

    # Stage all files in a temporary directory and push them in a single commit.
    with TemporaryDirectory() as tmp:
        staging = Path(tmp)
        _stage_artifacts(model, modeldir, staging)
        api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=staging)


if __name__ == "__main__":
    # Script entry point: the click command parses sys.argv itself.
    main()