Duskfallcrew committed on
Commit
3e84455
·
verified ·
1 Parent(s): 1126c53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -40
app.py CHANGED
@@ -13,16 +13,18 @@ import subprocess
13
  from urllib.parse import urlparse, unquote
14
  from pathlib import Path
15
  import tempfile
16
- from tqdm import tqdm
17
  import psutil
18
  import math
19
  import shutil
20
  import hashlib
21
  from datetime import datetime
22
  from typing import Dict, List, Optional
23
- from huggingface_hub import login, HfApi
 
24
  from huggingface_hub.errors import HfHubHTTPError
25
 
 
26
  # ---------------------- DEPENDENCIES ----------------------
27
  def install_dependencies_gradio():
28
  """Installs the necessary dependencies."""
@@ -55,6 +57,49 @@ def create_model_repo(api, user, orgs_name, model_name, make_private=False):
55
  return repo_id
56
 
57
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def load_sdxl_checkpoint(checkpoint_path):
60
  """Loads an SDXL checkpoint (.ckpt or .safetensors) and returns components."""
@@ -85,61 +130,57 @@ def load_sdxl_checkpoint(checkpoint_path):
85
 
86
  def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None):
87
  """Builds the Diffusers pipeline components from the loaded state dicts."""
88
- # --- Load configurations, create models (empty), load state dicts ---
 
 
 
89
 
90
  # 1. Text Encoders
91
- if reference_model_path:
92
- config_text_encoder1 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder")
93
- config_text_encoder2 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder_2")
94
- else: #Default
95
- config_text_encoder1 = CLIPTextConfig.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
96
- config_text_encoder2 = CLIPTextConfig.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2")
97
 
98
  text_encoder1 = CLIPTextModel(config_text_encoder1)
99
  text_encoder2 = CLIPTextModel(config_text_encoder2)
100
  text_encoder1.load_state_dict(text_encoder1_state)
101
  text_encoder2.load_state_dict(text_encoder2_state)
102
- text_encoder1.to(torch.float16) # Ensure fp16
103
- text_encoder2.to(torch.float16)
104
 
105
  # 2. VAE
106
- if reference_model_path:
107
- vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae")
108
- else:
109
- vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")
110
  vae.load_state_dict(vae_state)
111
- vae.to(torch.float16)
112
 
113
  # 3. UNet
114
- if reference_model_path:
115
- unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet")
116
- else:
117
- unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
118
-
119
  unet.load_state_dict(unet_state)
120
- unet.to(torch.float16)
121
 
122
  return text_encoder1, text_encoder2, vae, unet
123
 
124
 
125
 
126
- def convert_and_save_sdxl_to_diffusers(checkpoint_path, output_path, reference_model_path):
127
- """Converts an SDXL checkpoint to Diffusers format and saves it."""
 
 
 
 
 
 
128
 
129
  text_encoder1_state, text_encoder2_state, vae_state, unet_state = load_sdxl_checkpoint(checkpoint_path)
130
  text_encoder1, text_encoder2, vae, unet = build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path)
131
 
132
 
133
- pipeline = StableDiffusionXLPipeline(
134
- vae=vae,
135
- text_encoder=text_encoder1,
136
- text_encoder_2=text_encoder2,
137
- unet=unet,
138
- # You'll likely need to add tokenizer, scheduler, etc., here from the reference model
139
- tokenizer = pipeline.tokenizer,
140
- tokenizer_2 = pipeline.tokenizer_2,
141
- scheduler = pipeline.scheduler
142
- )
143
  pipeline.save_pretrained(output_path)
144
  print(f"Model saved as Diffusers format: {output_path}")
145
 
@@ -150,22 +191,25 @@ def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_priv
150
  api = HfApi()
151
  user = api.whoami(hf_token)
152
  model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
153
- api.upload_folder(folder_path=model_path, repo_id=model_repo) # Use upload_folder
154
  print(f"Model uploaded to: https://huggingface.co/{model_repo}")
155
 
156
  # ---------------------- GRADIO INTERFACE ----------------------
157
  def main(model_to_load, reference_model, output_path, hf_token, orgs_name, model_name, make_private):
158
  """Main function: SDXL checkpoint to Diffusers, always fp16."""
159
 
160
- convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
161
- upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
 
 
 
 
162
 
163
- return "Conversion and upload completed successfully!"
164
 
165
  with gr.Blocks() as demo:
166
- model_to_load = gr.Textbox(label="SDXL Checkpoint to Load (.ckpt or .safetensors)", placeholder="Path to checkpoint")
167
  reference_model = gr.Textbox(label="Reference Diffusers Model (Optional)", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)")
168
- output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="/content/output") # Clarified label
169
  hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
170
  orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
171
  model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")
 
13
  from urllib.parse import urlparse, unquote
14
  from pathlib import Path
15
  import tempfile
16
+ #from tqdm import tqdm # Removed as not crucial and can break display in gradio.
17
  import psutil
18
  import math
19
  import shutil
20
  import hashlib
21
  from datetime import datetime
22
  from typing import Dict, List, Optional
23
+ from huggingface_hub import login, HfApi, hf_hub_download # Import hf_hub_download
24
+ from huggingface_hub.utils import validate_repo_id, HFValidationError
25
  from huggingface_hub.errors import HfHubHTTPError
26
 
27
+
28
  # ---------------------- DEPENDENCIES ----------------------
29
  def install_dependencies_gradio():
30
  """Installs the necessary dependencies."""
 
57
  return repo_id
58
 
59
  # ---------------------- MODEL LOADING AND CONVERSION ----------------------
60
def _parse_hf_file_url(url):
    """Split a huggingface.co file URL into (repo_id, revision, filename).

    Returns None when the URL is not of the form
    https://huggingface.co/<namespace>/<repo>/(resolve|blob)/<revision>/<path>.
    """
    parsed = urlparse(url)
    if parsed.netloc.lower() not in ("huggingface.co", "www.huggingface.co"):
        return None
    # Split before unquoting so an encoded "/" inside a revision stays in one segment.
    segments = [unquote(segment) for segment in parsed.path.split("/") if segment]
    if len(segments) < 5 or segments[2] not in ("resolve", "blob"):
        return None
    return "/".join(segments[:2]), segments[3], "/".join(segments[4:])


def _fetch_plain_url(url):
    """Download an arbitrary http(s) URL into a fresh temp directory and return the file path."""
    # Local import: urllib.request is only needed on this branch.
    import urllib.request

    # unquote first, then basename, so an encoded "/" cannot escape the temp directory.
    name = os.path.basename(unquote(urlparse(url).path))
    if name in ("", ".", ".."):
        # NOTE(review): the fallback name has no extension; confirm that
        # load_sdxl_checkpoint does not rely on the suffix to pick a loader.
        name = "checkpoint"
    target = os.path.join(tempfile.mkdtemp(prefix="sdxl_checkpoint_"), name)
    urllib.request.urlretrieve(url, target)
    return target


def _pick_checkpoint_file(repo_id):
    """Choose a top-level checkpoint file in a Hub repo (.safetensors preferred over .ckpt)."""
    repo_files = HfApi().list_repo_files(repo_id)
    for extension in (".safetensors", ".ckpt"):
        matches = sorted(
            name for name in repo_files
            if "/" not in name and name.lower().endswith(extension)
        )
        if matches:
            # Deterministic choice: first name in sorted order.
            return matches[0]
    raise ValueError(
        f"No top-level .safetensors or .ckpt file found in {repo_id}; "
        "specify one as user/repo/file.safetensors"
    )


def download_model(model_path_or_url):
    """Resolves a checkpoint reference to a local file, downloading it if needed.

    Args:
        model_path_or_url: Can be a local path, an http(s) URL, a Hugging Face
            repo ID ("user/repo"), or a repo ID with a filename
            (e.g., "user/repo/file.safetensors").
    Returns:
        The local path to the downloaded (or existing) file.
    Raises:
        ValueError: If the input is empty, is not a valid reference, or the
            download fails for any reason.
    """
    source = str(model_path_or_url or "").strip()
    if not source:
        raise ValueError("Invalid model path or URL: no input was provided")

    # Check the filesystem first: a bare filename such as "model.safetensors"
    # is also a syntactically valid repo ID and must not be sent to the Hub.
    if os.path.isfile(source):
        return source

    try:
        if source.startswith(("http://", "https://")):
            # hf_hub_download has no URL parameter; Hub URLs are translated into
            # repo_id/filename/revision, anything else is fetched directly.
            hub_file = _parse_hf_file_url(source)
            if hub_file is None:
                return _fetch_plain_url(source)
            repo_id, revision, filename = hub_file
            return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)

        # "user/repo/sub/file" -> repo "user/repo", file "sub/file".
        # Fewer than three segments means the whole string is the repo ID.
        parts = source.split("/")
        if len(parts) > 2:
            repo_id, filename = "/".join(parts[:2]), "/".join(parts[2:])
        else:
            repo_id, filename = source, None

        try:
            validate_repo_id(repo_id)
        except HFValidationError as err:
            raise ValueError(f"Invalid model path or URL: {model_path_or_url}") from err

        if filename is None:
            # hf_hub_download requires a filename, so look one up in the repo.
            filename = _pick_checkpoint_file(repo_id)
        return hf_hub_download(repo_id=repo_id, filename=filename)

    except Exception as e:
        raise ValueError(f"Error downloading or accessing model: {e}") from e
102
+
103
 
104
  def load_sdxl_checkpoint(checkpoint_path):
105
  """Loads an SDXL checkpoint (.ckpt or .safetensors) and returns components."""
 
130
 
def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path=None):
    """Instantiates the text encoders, VAE and UNet and fills them with the given weights.

    Args:
        text_encoder1_state: State dict loaded into the first text encoder.
        text_encoder2_state: State dict loaded into the second text encoder.
        vae_state: State dict loaded into the VAE.
        unet_state: State dict loaded into the UNet.
        reference_model_path: Diffusers model supplying configs/architectures;
            a falsy value selects "stabilityai/stable-diffusion-xl-base-1.0".
    Returns:
        Tuple (text_encoder1, text_encoder2, vae, unet), each cast to fp16 on CPU.
    """
    # Fall back to SDXL base 1.0 when the caller gives no reference model.
    reference_model_path = reference_model_path or "stabilityai/stable-diffusion-xl-base-1.0"

    # --- Text encoders: fetch both configs, then build empty models from them ---
    encoder_configs = [
        CLIPTextConfig.from_pretrained(reference_model_path, subfolder=subfolder)
        for subfolder in ("text_encoder", "text_encoder_2")
    ]
    # NOTE(review): both encoders are built as CLIPTextModel; confirm that the
    # second state dict from load_sdxl_checkpoint matches this class.
    text_encoder1, text_encoder2 = (CLIPTextModel(config) for config in encoder_configs)
    text_encoder1.load_state_dict(text_encoder1_state)
    text_encoder2.load_state_dict(text_encoder2_state)
    for encoder in (text_encoder1, text_encoder2):
        encoder.to(torch.float16).to("cpu")  # fp16 weights, kept on CPU

    # --- VAE: reference architecture, weights replaced by the checkpoint's ---
    vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae")
    vae.load_state_dict(vae_state)
    vae.to(torch.float16).to("cpu")

    # --- UNet: same approach as the VAE ---
    unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet")
    unet.load_state_dict(unet_state)
    unet.to(torch.float16).to("cpu")

    return text_encoder1, text_encoder2, vae, unet
160
 
161
 
162
 
163
def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, reference_model_path=None):
    """Converts an SDXL checkpoint to Diffusers format and saves it.

    Args:
        checkpoint_path_or_url: The path/URL/repo ID of the checkpoint.
        output_path: Directory the Diffusers pipeline is written to.
        reference_model_path: Diffusers model that supplies the configs,
            tokenizers and scheduler. A blank/None value selects
            "stabilityai/stable-diffusion-xl-base-1.0".
    """
    # Resolve the default here as well as in build_diffusers_model: the same
    # value is passed to StableDiffusionXLPipeline.from_pretrained below, which
    # cannot load from an empty reference.
    if not reference_model_path:
        reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"

    # Download the model if necessary (handles URLs, repo IDs, and local paths)
    checkpoint_path = download_model(checkpoint_path_or_url)

    text_encoder1_state, text_encoder2_state, vae_state, unet_state = load_sdxl_checkpoint(checkpoint_path)
    text_encoder1, text_encoder2, vae, unet = build_diffusers_model(
        text_encoder1_state, text_encoder2_state, vae_state, unet_state, reference_model_path
    )

    # Load tokenizer and scheduler from the reference model; the converted
    # components override the reference model's own weights.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        reference_model_path,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        unet=unet,
        torch_dtype=torch.float16,
    )
    pipeline.to("cpu")

    pipeline.save_pretrained(output_path)
    print(f"Model saved as Diffusers format: {output_path}")
186
 
 
191
  api = HfApi()
192
  user = api.whoami(hf_token)
193
  model_repo = create_model_repo(api, user, orgs_name, model_name, make_private)
194
+ api.upload_folder(folder_path=model_path, repo_id=model_repo)
195
  print(f"Model uploaded to: https://huggingface.co/{model_repo}")
196
 
197
  # ---------------------- GRADIO INTERFACE ----------------------
198
  def main(model_to_load, reference_model, output_path, hf_token, orgs_name, model_name, make_private):
199
  """Main function: SDXL checkpoint to Diffusers, always fp16."""
200
 
201
+ try:
202
+ convert_and_save_sdxl_to_diffusers(model_to_load, output_path, reference_model)
203
+ upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
204
+ return "Conversion and upload completed successfully!"
205
+ except Exception as e:
206
+ return f"An error occurred: {e}" # Return the error message
207
 
 
208
 
209
  with gr.Blocks() as demo:
210
+ model_to_load = gr.Textbox(label="SDXL Checkpoint (Path, URL, or HF Repo)", placeholder="Path, URL, or Hugging Face Repo ID (e.g., my-org/my-model or my-org/my-model/file.safetensors)")
211
  reference_model = gr.Textbox(label="Reference Diffusers Model (Optional)", placeholder="e.g., stabilityai/stable-diffusion-xl-base-1.0 (Leave blank for default)")
212
+ output_path = gr.Textbox(label="Output Path (Diffusers Format)", value="output") # Default changed to "output"
213
  hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Your Hugging Face write token")
214
  orgs_name = gr.Textbox(label="Organization Name (Optional)", placeholder="Your organization name")
215
  model_name = gr.Textbox(label="Model Name", placeholder="The name of your model on Hugging Face")