Spaces:
Runtime error
Runtime error
File size: 1,997 Bytes
8933ee4 1b96548 8933ee4 1b96548 8933ee4 1b96548 8933ee4 1b96548 8933ee4 bddc1f1 8933ee4 bddc1f1 8933ee4 1b96548 8933ee4 1b96548 8933ee4 1b96548 8933ee4 1b96548 8933ee4 1b96548 8933ee4 bddc1f1 1b96548 bddc1f1 8933ee4 bddc1f1 8933ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import json
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
import click
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from msma import EDMScorer, ScoreFlow, build_model_from_pickle
# Parenthesized form works on every click version; the bare `@click.command`
# form is only accepted by click >= 8.1.
@click.command()
@click.option(
    "--basedir",
    help="Directory holding the model weights and logs",
    type=str,
    required=True,
)
@click.option(
    "--preset", help="Preset of the score model used", type=str, required=True
)
@click.option(
    "--repo-name",
    help="Hugging Face Hub repository to push the exported model to",
    type=str,
    default="ahsanMah/localizing-edm",
    show_default=True,
)
def main(basedir, preset, repo_name):
    """Export a trained ScoreFlow model and push it to the Hugging Face Hub.

    Expects ``<basedir>/<preset>/`` to contain ``config.json``, ``flow.pt``,
    ``gmm.pkl``, ``refscores.npz`` and a ``logs/`` directory. The model is
    rebuilt on CPU from the ``"PatchFlow"`` section of ``config.json`` and the
    flow weights in ``flow.pt``. Its full state dict is then written as
    ``model.safetensors`` together with ``model.config``, the GMM pickle, the
    cached reference scores and the logs. Everything is uploaded under the
    ``<preset>/`` folder of ``repo_name`` in a single commit. The repository
    is created if it does not exist yet.

    Raises:
        click.ClickException: if any required input file/directory is missing.
    """
    basedir = Path(basedir)
    modeldir = basedir / preset

    # Validate every input before touching the Hub. This way a missing
    # artifact cannot leave a freshly created (empty) remote repository behind.
    required = ["config.json", "flow.pt", "gmm.pkl", "refscores.npz", "logs"]
    missing = [name for name in required if not (modeldir / name).exists()]
    if missing:
        raise click.ClickException(
            f"Missing required file(s) in {modeldir}: {', '.join(missing)}"
        )

    net = build_model_from_pickle(preset)
    with open(modeldir / "config.json", "rb") as f:
        model_params = json.load(f)
    model = ScoreFlow(
        net,
        device="cpu",
        **model_params["PatchFlow"],
    )
    # The model lives on CPU. Without map_location, a checkpoint saved from a
    # GPU process fails to load on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    flow_state = torch.load(modeldir / "flow.pt", map_location="cpu")
    model.flow.load_state_dict(flow_state)

    api = HfApi()
    # Create the repo if it does not exist yet and get its canonical repo_id.
    repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

    # Stage all artifacts in a temporary directory so they go up in one commit.
    with TemporaryDirectory() as tmp:
        tmpdir = Path(tmp)
        # Weights. str() keeps this working with safetensors releases whose
        # save_file expects a string filename.
        save_file(model.state_dict(), str(tmpdir / "model.safetensors"))
        # Model config
        (tmpdir / "config.json").write_text(
            json.dumps(model.config, sort_keys=True, indent=4)
        )
        # GMM and cached score norms
        shutil.copyfile(modeldir / "gmm.pkl", tmpdir / "gmm.pkl")
        shutil.copyfile(modeldir / "refscores.npz", tmpdir / "refscores.npz")
        # TODO: generate a model card and write it to tmpdir / "README.md".
        # Training logs
        shutil.copytree(modeldir / "logs", tmpdir / "logs")
        # Push everything under the <preset>/ folder of the repository.
        api.upload_folder(repo_id=repo_id, path_in_repo=preset, folder_path=tmpdir)
if __name__ == "__main__":
    # Script entry point: invoke the `main` click command defined in this file.
    main()
|