Spaces:
Build error
Build error
better conditioning on weight porting.
Browse files- convert.py +8 -22
convert.py
CHANGED
@@ -6,8 +6,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import \
|
|
6 |
StableDiffusionSafetyChecker
|
7 |
from transformers import CLIPTextModel
|
8 |
|
9 |
-
from conversion_utils import
|
10 |
-
run_assertion)
|
11 |
|
12 |
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
|
13 |
REVISION = None
|
@@ -68,31 +67,18 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
|
|
68 |
print("Loading fine-tuned text encoder weights.")
|
69 |
text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
|
70 |
tf_text_encoder.load_weights(text_encoder_weights_path)
|
|
|
|
|
|
|
71 |
if unet_weights is not None:
|
72 |
print("Loading fine-tuned UNet weights.")
|
73 |
unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
|
74 |
tf_unet.load_weights(unet_weights_path)
|
|
|
|
|
|
|
75 |
|
76 |
-
|
77 |
-
unet_state_dict_from_tf = populate_unet(tf_unet)
|
78 |
-
print("Conversion done, now running optional assertions...")
|
79 |
-
|
80 |
-
# Since we cannot compare the fine-tuned weights.
|
81 |
-
if text_encoder_weights is None:
|
82 |
-
text_encoder_state_dict_from_pt = pt_text_encoder.state_dict()
|
83 |
-
run_assertion(text_encoder_state_dict_from_pt, text_encoder_state_dict_from_tf)
|
84 |
-
if unet_weights is None:
|
85 |
-
unet_state_dict_from_pt = pt_unet.state_dict()
|
86 |
-
run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
|
87 |
-
|
88 |
-
if text_encoder_weights is None or unet_weights is None:
|
89 |
-
print(
|
90 |
-
"Assertions successful, populating the converted parameters into the diffusers models..."
|
91 |
-
)
|
92 |
-
pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
|
93 |
-
pt_unet.load_state_dict(unet_state_dict_from_tf)
|
94 |
-
|
95 |
-
print("Parameters ported, preparing StabelDiffusionPipeline...")
|
96 |
pipeline = StableDiffusionPipeline.from_pretrained(
|
97 |
PRETRAINED_CKPT,
|
98 |
unet=pt_unet,
|
|
|
6 |
StableDiffusionSafetyChecker
|
7 |
from transformers import CLIPTextModel
|
8 |
|
9 |
+
from conversion_utils import populate_text_encoder, populate_unet
|
|
|
10 |
|
11 |
PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
|
12 |
REVISION = None
|
|
|
67 |
print("Loading fine-tuned text encoder weights.")
|
68 |
text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
|
69 |
tf_text_encoder.load_weights(text_encoder_weights_path)
|
70 |
+
text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
|
71 |
+
pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
|
72 |
+
print("Populated PT text encoder from TF weights.")
|
73 |
if unet_weights is not None:
|
74 |
print("Loading fine-tuned UNet weights.")
|
75 |
unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
|
76 |
tf_unet.load_weights(unet_weights_path)
|
77 |
+
unet_state_dict_from_tf = populate_unet(tf_unet)
|
78 |
+
pt_unet.load_state_dict(unet_state_dict_from_tf)
|
79 |
+
print("Populated PT UNet from TF weights.")
|
80 |
|
81 |
+
print("Weights ported, preparing StabelDiffusionPipeline...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
pipeline = StableDiffusionPipeline.from_pretrained(
|
83 |
PRETRAINED_CKPT,
|
84 |
unet=pt_unet,
|