Spaces:
Sleeping
Sleeping
remove ema parts in checkpoints.
Browse files- app.py +3 -3
- model_states.pt → model_wo_ema.ckpt +2 -2
- transfer.py +8 -2
app.py
CHANGED
@@ -66,11 +66,11 @@ def process_multi_wrapper_only_show_rendered(rendered_txt_0, rendered_txt_1, ren
|
|
66 |
shared_eta, shared_a_prompt, shared_n_prompt,
|
67 |
only_show_rendered_image=True)
|
68 |
|
69 |
-
# cfg = OmegaConf.load("config.yaml")
|
70 |
-
# model = load_model_from_config(cfg, "model_states.pt", verbose=True)
|
71 |
|
72 |
cfg = OmegaConf.load("config.yaml")
|
73 |
-
model = load_model_from_config(cfg, "
|
|
|
|
|
74 |
|
75 |
ddim_sampler = DDIMSampler(model)
|
76 |
render_tool = Render_Text(model)
|
|
|
66 |
shared_eta, shared_a_prompt, shared_n_prompt,
|
67 |
only_show_rendered_image=True)
|
68 |
|
|
|
|
|
69 |
|
70 |
cfg = OmegaConf.load("config.yaml")
|
71 |
+
model = load_model_from_config(cfg, "model_wo_ema.ckpt", verbose=True)
|
72 |
+
# model = load_model_from_config(cfg, "model_states.pt", verbose=True)
|
73 |
+
# model = load_model_from_config(cfg, "model.ckpt", verbose=True)
|
74 |
|
75 |
ddim_sampler = DDIMSampler(model)
|
76 |
render_tool = Render_Text(model)
|
model_states.pt → model_wo_ema.ckpt
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b86b22188bf580e80773a5ae101bf9787eb258349f3f1acf0ae50fd10cb3fec
|
3 |
+
size 6671922039
|
transfer.py
CHANGED
@@ -6,9 +6,15 @@ model = load_model_from_config(cfg, "model_states.pt", verbose=True)
|
|
6 |
|
7 |
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
with model.ema_scope("store ema weights"):
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
file_content = {
|
10 |
-
'state_dict':
|
11 |
}
|
12 |
-
torch.save(file_content, "
|
13 |
print("has stored the transfered ckpt.")
|
14 |
print("trial ends!")
|
|
|
6 |
|
7 |
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
with model.ema_scope("store ema weights"):
|
9 |
+
model_sd = model.state_dict()
|
10 |
+
store_sd = {}
|
11 |
+
for key in model_sd:
|
12 |
+
if "ema" in key:
|
13 |
+
continue
|
14 |
+
store_sd[key] = model_sd[key]
|
15 |
file_content = {
|
16 |
+
'state_dict': store_sd
|
17 |
}
|
18 |
+
torch.save(file_content, "model_wo_ema.ckpt")
|
19 |
print("has stored the transfered ckpt.")
|
20 |
print("trial ends!")
|