sayakpaul HF staff commited on
Commit
7081a39
·
1 Parent(s): 2b2693f

better conditioning on weight porting.

Browse files
Files changed (1) hide show
  1. convert.py +8 -22
convert.py CHANGED
@@ -6,8 +6,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import \
6
  StableDiffusionSafetyChecker
7
  from transformers import CLIPTextModel
8
 
9
- from conversion_utils import (populate_text_encoder, populate_unet,
10
- run_assertion)
11
 
12
  PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
13
  REVISION = None
@@ -68,31 +67,18 @@ def run_conversion(text_encoder_weights: str = None, unet_weights: str = None):
68
  print("Loading fine-tuned text encoder weights.")
69
  text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
70
  tf_text_encoder.load_weights(text_encoder_weights_path)
 
 
 
71
  if unet_weights is not None:
72
  print("Loading fine-tuned UNet weights.")
73
  unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
74
  tf_unet.load_weights(unet_weights_path)
 
 
 
75
 
76
- text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
77
- unet_state_dict_from_tf = populate_unet(tf_unet)
78
- print("Conversion done, now running optional assertions...")
79
-
80
- # Since we cannot compare the fine-tuned weights.
81
- if text_encoder_weights is None:
82
- text_encoder_state_dict_from_pt = pt_text_encoder.state_dict()
83
- run_assertion(text_encoder_state_dict_from_pt, text_encoder_state_dict_from_tf)
84
- if unet_weights is None:
85
- unet_state_dict_from_pt = pt_unet.state_dict()
86
- run_assertion(unet_state_dict_from_pt, unet_state_dict_from_tf)
87
-
88
- if text_encoder_weights is None or unet_weights is None:
89
- print(
90
- "Assertions successful, populating the converted parameters into the diffusers models..."
91
- )
92
- pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
93
- pt_unet.load_state_dict(unet_state_dict_from_tf)
94
-
95
- print("Parameters ported, preparing StabelDiffusionPipeline...")
96
  pipeline = StableDiffusionPipeline.from_pretrained(
97
  PRETRAINED_CKPT,
98
  unet=pt_unet,
 
6
  StableDiffusionSafetyChecker
7
  from transformers import CLIPTextModel
8
 
9
+ from conversion_utils import populate_text_encoder, populate_unet
 
10
 
11
  PRETRAINED_CKPT = "CompVis/stable-diffusion-v1-4"
12
  REVISION = None
 
67
  print("Loading fine-tuned text encoder weights.")
68
  text_encoder_weights_path = tf.keras.utils.get_file(origin=text_encoder_weights)
69
  tf_text_encoder.load_weights(text_encoder_weights_path)
70
+ text_encoder_state_dict_from_tf = populate_text_encoder(tf_text_encoder)
71
+ pt_text_encoder.load_state_dict(text_encoder_state_dict_from_tf)
72
+ print("Populated PT text encoder from TF weights.")
73
  if unet_weights is not None:
74
  print("Loading fine-tuned UNet weights.")
75
  unet_weights_path = tf.keras.utils.get_file(origin=unet_weights)
76
  tf_unet.load_weights(unet_weights_path)
77
+ unet_state_dict_from_tf = populate_unet(tf_unet)
78
+ pt_unet.load_state_dict(unet_state_dict_from_tf)
79
+ print("Populated PT UNet from TF weights.")
80
 
81
+ print("Weights ported, preparing StabelDiffusionPipeline...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  pipeline = StableDiffusionPipeline.from_pretrained(
83
  PRETRAINED_CKPT,
84
  unet=pt_unet,