Spaces:
Running
Running
import os | |
import spaces | |
import argparse | |
from pathlib import Path | |
import os | |
import torch | |
from diffusers import (DiffusionPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline, | |
FluxPipeline, FluxTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline) | |
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPFeatureExtractor, AutoTokenizer, T5EncoderModel, BitsAndBytesConfig as TFBitsAndBytesConfig | |
from huggingface_hub import save_torch_state_dict, snapshot_download | |
from diffusers.loaders.single_file_utils import (convert_flux_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, | |
convert_sd3_t5_checkpoint_to_diffusers) | |
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
import safetensors.torch | |
import gradio as gr | |
import shutil | |
import gc | |
import tempfile | |
# also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning | |
from utils import (get_token, set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo, gate_repo) | |
from sdutils import (SCHEDULER_CONFIG_MAP, get_scheduler_config, fuse_loras, DTYPE_DEFAULT, get_dtype, get_dtypes, get_model_type_from_key, get_process_dtype) | |
def fake_gpu():
    """No-op stub: accepts no arguments, does nothing, and returns None.

    NOTE(review): not referenced in the visible code; presumably kept as a
    placeholder for ZeroGPU (`spaces.GPU`) wiring — confirm before removing.
    """
    return None
# Optional feature flag: `is_nf4` records whether diffusers exposes BitsAndBytesConfig
# (it may be missing, e.g. in older diffusers releases or when the import fails).
try:
    from diffusers import BitsAndBytesConfig
except Exception:
    is_nf4 = False
else:
    is_nf4 = True
# Candidate base repos for FLUX conversions.
FLUX_BASE_REPOS = [
    "camenduru/FLUX.1-dev-diffusers",
    "black-forest-labs/FLUX.1-schnell",
    "John6666/flux1-dev-fp8-flux",
    "John6666/flux1-schnell-fp8-flux",
]
# fp8 T5 text encoder downloaded by the FLUX CPU fp8 converter (saved as text_encoder_2).
FLUX_T5_URL = "https://huggingface.co./camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors"
# Candidate base repos for SD 3.5 conversions.
SD35_BASE_REPOS = [
    "adamo1139/stable-diffusion-3.5-large-ungated",
    "adamo1139/stable-diffusion-3.5-large-turbo-ungated",
]
# fp8 T5 text encoder downloaded by the SD 3.5 CPU fp8 converter (saved as text_encoder_3).
SD35_T5_URL = "https://huggingface.co./adamo1139/stable-diffusion-3.5-large-turbo-ungated/blob/main/text_encoders/t5xxl_fp8_e4m3fn.safetensors"
# Scratch directory created once per process for downloads and intermediate files.
TEMP_DIR = tempfile.mkdtemp()
# Environment flags: running on ZeroGPU (env var present at all) / CUDA device available.
IS_ZERO = "SPACES_ZERO_GPU" in os.environ
IS_CUDA = torch.cuda.is_available()
def safe_clean(path: str):
    """Best-effort removal of a file or a whole directory tree.

    Never raises: the outcome ("Deleted", "File not found", or the failure
    reason) is printed instead, so callers can use it unconditionally.
    """
    try:
        target = Path(path)
        if not target.exists():
            print(f"File not found: {path}")
            return
        if target.is_dir():
            shutil.rmtree(str(target))
        else:
            target.unlink()
        print(f"Deleted: {path}")
    except Exception as e:
        print(f"Failed to delete: {path} {e}")
def save_readme_md(dir, url):
    """Write a model-card ``README.md`` (YAML front matter) into *dir*.

    When *url* is a Hugging Face repo name (linked via huggingface.co) or a
    string containing "http", a ``Converted from [...](...)`` attribution line
    is appended after the front matter; otherwise only the front matter is written.
    """
    card = (
        "---\n"
        "license: other\n"
        "language:\n"
        "- en\n"
        "library_name: diffusers\n"
        "pipeline_tag: text-to-image\n"
        "tags:\n"
        "- text-to-image\n"
        "---\n"
    )
    source_name, source_url = "", ""
    if is_repo_name(url):
        source_name, source_url = url, f"https://huggingface.co./{url}/"
    elif "http" in url:
        source_name = source_url = url
    if source_name and source_url:
        card += f"Converted from [{source_name}]({source_url}).\n"
    Path(dir, "README.md").write_text(card, encoding="utf-8")
def save_module(model, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)): # doesn't work
    """Save a module's weights as sharded safetensors under ``<dir>/<name>``.

    The module is moved to CPU first. When *dtype* is "fp8", floating-point
    tensors are cast to ``torch.float8_e4m3fn``; integer/bool tensors (e.g.
    index buffers) are kept unchanged because fp8 cannot represent them exactly.
    Only weights are written (via ``save_torch_state_dict``); no config file.

    NOTE: the original author flagged this helper as "doesn't work"; the
    converters in this file call ``save_module_sd`` instead.

    Args:
        model: module exposing ``to()`` and ``state_dict()``.
        name: component name; "vae"/"transformer"/"unet" use the
            ``diffusion_pytorch_model*`` filename, everything else ``model*``.
        dir: parent output directory.
        dtype: "fp8" to cast floating tensors; any other value keeps dtypes.
        progress: Gradio progress tracker.
    """
    if name in ["vae", "transformer", "unet"]:
        pattern = "diffusion_pytorch_model{suffix}.safetensors"
    else:
        pattern = "model{suffix}.safetensors"
    # Large denoisers get bigger shards than the auxiliary components.
    size = "10GB" if name in ["transformer", "unet"] else "5GB"
    path = str(Path(f"{dir.removesuffix('/')}/{name}"))
    os.makedirs(path, exist_ok=True)
    progress(0, desc=f"Saving {name} to {dir}...")
    print(f"Saving {name} to {dir}...")
    model.to("cpu")
    sd = dict(model.state_dict())
    new_sd = {}
    for key in list(sd.keys()):
        q = sd.pop(key)
        # Only floating-point tensors may be down-cast; casting integer/bool
        # buffers to fp8 would corrupt their values.
        if dtype == "fp8" and q.is_floating_point() and q.dtype != torch.float8_e4m3fn:
            q = q.to(torch.float8_e4m3fn)
        new_sd[key] = q
    del sd
    gc.collect()
    save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
    del new_sd
    gc.collect()
def save_module_sd(sd: dict, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
    """Save a state dict as sharded safetensors under ``<dir>/<name>``.

    *sd* is consumed: every entry is popped (and moved to CPU), so the caller's
    dict is empty afterwards. When *dtype* is "fp8", floating-point tensors are
    cast to ``torch.float8_e4m3fn``; integer/bool tensors are kept unchanged
    because fp8 cannot represent them exactly. Only weights are written (via
    ``save_torch_state_dict``); no config file.

    Args:
        sd: mapping of parameter names to tensors.
        name: component name; "vae"/"transformer"/"unet" use the
            ``diffusion_pytorch_model*`` filename, everything else ``model*``.
        dir: parent output directory.
        dtype: "fp8" to cast floating tensors; any other value keeps dtypes.
        progress: Gradio progress tracker.
    """
    if name in ["vae", "transformer", "unet"]:
        pattern = "diffusion_pytorch_model{suffix}.safetensors"
    else:
        pattern = "model{suffix}.safetensors"
    # Large denoisers get bigger shards than the auxiliary components.
    size = "10GB" if name in ["transformer", "unet"] else "5GB"
    path = str(Path(f"{dir.removesuffix('/')}/{name}"))
    os.makedirs(path, exist_ok=True)
    progress(0, desc=f"Saving state_dict of {name} to {dir}...")
    print(f"Saving state_dict of {name} to {dir}...")
    new_sd = {}
    for key in list(sd.keys()):
        q = sd.pop(key).to("cpu")
        # Only floating-point tensors may be down-cast; casting integer/bool
        # tensors to fp8 would corrupt their values.
        if dtype == "fp8" and q.is_floating_point() and q.dtype != torch.float8_e4m3fn:
            q = q.to(torch.float8_e4m3fn)
        new_sd[key] = q
    save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
    del new_sd
    gc.collect()
def convert_flux_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
    """Convert a single-file FLUX checkpoint into a diffusers folder without a GPU.

    Steps:
      1. Load *new_file*, convert its transformer weights to diffusers layout and
         save them to ``<new_dir>/transformer`` with ``save_module_sd``.
      2. Download the fp8 T5 from ``FLUX_T5_URL`` and save it as ``text_encoder_2``.
      3. Load the remaining components from *base_repo* (bf16, *kwargs* overrides)
         and ``save_pretrained`` them into *new_dir*.
      4. Download *base_repo*'s non-weight files (configs, ``model_index.json``, ...)
         and copy them over *new_dir*, because steps 1-2 write weights only.

    Args:
        new_file: path of the local single-file checkpoint.
        new_dir: output diffusers directory.
        dtype: output dtype name, forwarded to ``save_module_sd`` ("fp8" casts).
        base_repo: Hugging Face repo supplying configs and the other components.
        civitai_key: forwarded to ``get_download_file`` for the T5 download.
        kwargs: component overrides passed to ``FluxPipeline.from_pretrained``.
        progress: Gradio progress tracker.

    Raises:
        Exception: if the T5 file cannot be downloaded; other errors propagate.
    """
    temp_dir = TEMP_DIR
    down_dir = str(Path(f"{TEMP_DIR}/down"))
    os.makedirs(down_dir, exist_ok=True)
    hf_token = get_token()
    # Shard index files ("*.safetensors.index.json") are excluded: the weights saved here
    # come with their own index (or none), and a copied base-repo index can reference
    # shard files that do not exist in new_dir. ("*.index" alone never matches them.)
    download_ignore = ["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.index.json",
                       "*.jpg", "*.jpeg", "*.png", "*.webp"]
    copy_ignore = shutil.ignore_patterns(".*", "README*", "*.md", "*.index.json", "*.jpg", "*.jpeg", "*.png", "*.webp")
    try:
        progress(0.25, desc=f"Loading {new_file}...")
        orig_sd = safetensors.torch.load_file(new_file)
        progress(0.3, desc=f"Converting {new_file}...")
        conv_sd = convert_flux_transformer_checkpoint_to_diffusers(orig_sd)
        del orig_sd
        gc.collect()
        progress(0.35, desc=f"Saving {new_file}...")
        save_module_sd(conv_sd, "transformer", new_dir, dtype)
        del conv_sd
        gc.collect()

        progress(0.5, desc=f"Loading text_encoder_2 from {FLUX_T5_URL}...")
        t5_file = get_download_file(temp_dir, FLUX_T5_URL, civitai_key)
        if not t5_file:
            raise Exception(f"Safetensors file not found: {FLUX_T5_URL}")
        try:
            t5_sd = safetensors.torch.load_file(t5_file)
        finally:
            safe_clean(t5_file)  # the downloaded file is not needed once loaded (or on failure)
        save_module_sd(t5_sd, "text_encoder_2", new_dir, dtype)
        del t5_sd
        gc.collect()

        progress(0.6, desc=f"Loading other components from {base_repo}...")
        pipe = FluxPipeline.from_pretrained(base_repo, transformer=None, text_encoder_2=None, use_safetensors=True, **kwargs,
                                            torch_dtype=torch.bfloat16, token=hf_token)
        pipe.save_pretrained(new_dir)
        del pipe
        gc.collect()

        progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
        snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
                          ignore_patterns=download_ignore)
        shutil.copytree(down_dir, new_dir, ignore=copy_ignore, dirs_exist_ok=True)
    finally:
        # safe_clean never raises, so it cannot mask an in-flight exception.
        safe_clean(down_dir)
def convert_sd35_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
    """Convert a single-file SD 3.5 checkpoint into a diffusers folder without a GPU.

    Steps:
      1. Load *new_file*, convert its transformer weights to diffusers layout and
         save them to ``<new_dir>/transformer`` with ``save_module_sd``.
      2. Download the fp8 T5 from ``SD35_T5_URL``, convert it to diffusers layout
         and save it as ``text_encoder_3``.
      3. Load the remaining components from *base_repo* (bf16, *kwargs* overrides)
         and ``save_pretrained`` them into *new_dir*.
      4. Download *base_repo*'s non-weight files (configs, ``model_index.json``, ...)
         and copy them over *new_dir*, because steps 1-2 write weights only.

    Args:
        new_file: path of the local single-file checkpoint.
        new_dir: output diffusers directory.
        dtype: output dtype name, forwarded to ``save_module_sd`` ("fp8" casts).
        base_repo: Hugging Face repo supplying configs and the other components.
        civitai_key: forwarded to ``get_download_file`` for the T5 download.
        kwargs: component overrides passed to ``StableDiffusion3Pipeline.from_pretrained``.
        progress: Gradio progress tracker.

    Raises:
        Exception: if the T5 file cannot be downloaded; other errors propagate.
    """
    temp_dir = TEMP_DIR
    down_dir = str(Path(f"{TEMP_DIR}/down"))
    os.makedirs(down_dir, exist_ok=True)
    hf_token = get_token()
    # Shard index files ("*.safetensors.index.json") are excluded: the weights saved here
    # come with their own index (or none), and a copied base-repo index can reference
    # shard files that do not exist in new_dir. ("*.index" alone never matches them.)
    download_ignore = ["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.index.json",
                       "*.jpg", "*.jpeg", "*.png", "*.webp"]
    copy_ignore = shutil.ignore_patterns(".*", "README*", "*.md", "*.index.json", "*.jpg", "*.jpeg", "*.png", "*.webp")
    try:
        progress(0.25, desc=f"Loading {new_file}...")
        orig_sd = safetensors.torch.load_file(new_file)
        progress(0.3, desc=f"Converting {new_file}...")
        conv_sd = convert_sd3_transformer_checkpoint_to_diffusers(orig_sd)
        del orig_sd
        gc.collect()
        progress(0.35, desc=f"Saving {new_file}...")
        save_module_sd(conv_sd, "transformer", new_dir, dtype)
        del conv_sd
        gc.collect()

        progress(0.5, desc=f"Loading text_encoder_3 from {SD35_T5_URL}...")
        t5_file = get_download_file(temp_dir, SD35_T5_URL, civitai_key)
        if not t5_file:
            raise Exception(f"Safetensors file not found: {SD35_T5_URL}")
        try:
            t5_sd = safetensors.torch.load_file(t5_file)
        finally:
            safe_clean(t5_file)  # the downloaded file is not needed once loaded (or on failure)
        conv_t5_sd = convert_sd3_t5_checkpoint_to_diffusers(t5_sd)
        del t5_sd
        gc.collect()
        save_module_sd(conv_t5_sd, "text_encoder_3", new_dir, dtype)
        del conv_t5_sd
        gc.collect()

        progress(0.6, desc=f"Loading other components from {base_repo}...")
        pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=None, text_encoder_3=None, use_safetensors=True, **kwargs,
                                                        torch_dtype=torch.bfloat16, token=hf_token)
        pipe.save_pretrained(new_dir)
        del pipe
        gc.collect()

        progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
        snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
                          ignore_patterns=download_ignore)
        shutil.copytree(down_dir, new_dir, ignore=copy_ignore, dirs_exist_ok=True)
    finally:
        # safe_clean never raises, so it cannot mask an in-flight exception.
        safe_clean(down_dir)
#@spaces.GPU(duration=60)
def load_and_save_pipeline(pipe, model_type: str, url: str, new_file: str, new_dir: str, dtype: str,
                           scheduler: str, ema: bool, image_size: str, is_safety_checker: bool, base_repo: str, civitai_key: str, lora_dict: dict,
                           my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder,
                           kwargs: dict, dkwargs: dict, progress=gr.Progress(track_tqdm=True)):
    """Build the pipeline for *model_type*, fuse LoRAs, and save it in diffusers format to *new_dir*.

    A repo-name *url* is loaded with ``from_pretrained``; otherwise the local
    *new_file* is loaded with ``from_single_file``. FLUX / SD 3.5 with
    ``dtype == "fp8"`` and a local file are delegated to the CPU fp8 converters,
    which write *new_dir* themselves.

    Args:
        pipe: value returned when no pipeline is built here (the visible caller passes None).
        model_type: "SDXL", "SD 1.5", "FLUX" or "SD 3.5"; anything else uses ``DiffusionPipeline``.
        dtype: output dtype name. "NF4" adds a bitsandbytes 4-bit config to the FLUX/SD 3.5
            transformer load (only if ``is_nf4``); "fp8" selects the CPU converters.
        scheduler: scheduler name resolved by ``get_scheduler_config`` (SDXL / SD 1.5 only).
        ema: forwarded as ``extract_ema`` for SD 1.5.
        image_size: SD 1.5 only; values other than "512" are written into the VAE's
            ``sample_size`` config entry.
        is_safety_checker: SD 1.5 only; attach the CompVis safety checker.
        lora_dict: LoRA sources and scales, forwarded to ``fuse_loras``.
        my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder:
            not read here; the visible caller already places them in *kwargs*.
        kwargs: component overrides for the loaders (mutated for "SD 1.5").
        dkwargs: dtype kwargs (e.g. ``torch_dtype``) for the loaders.
        progress: Gradio progress tracker.

    Returns:
        The built pipeline, or *pipe* unchanged on the fp8 CPU paths.

    Raises:
        Exception: "Failed to load pipeline. ..." wrapping any load/save error.
    """
    try:
        hf_token = get_token()
        temp_dir = TEMP_DIR
        qkwargs = {}
        # Build the NF4 config only when NF4 output is requested (and diffusers supports it),
        # so other dtypes never construct quantization configs they don't use.
        if dtype == "NF4" and is_nf4:
            qkwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                                                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        if model_type == "SDXL":
            if is_repo_name(url):
                pipe = StableDiffusionXLPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else:
                pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            sconf = get_scheduler_config(scheduler)
            pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
            pipe.save_pretrained(new_dir)
        elif model_type == "SD 1.5":
            if is_safety_checker:
                safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
                feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
                kwargs["requires_safety_checker"] = True
                kwargs["safety_checker"] = safety_checker
                kwargs["feature_extractor"] = feature_extractor
            else:
                kwargs["requires_safety_checker"] = False
            if is_repo_name(url):
                pipe = StableDiffusionPipeline.from_pretrained(url, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else:
                pipe = StableDiffusionPipeline.from_single_file(new_file, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            sconf = get_scheduler_config(scheduler)
            pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
            if image_size != "512":
                # Update the config in place: AutoencoderKL.from_config() would build a
                # freshly (randomly) initialised VAE and discard the loaded weights.
                pipe.vae.register_to_config(sample_size=int(image_size))
            pipe.save_pretrained(new_dir)
        elif model_type == "FLUX":
            if dtype != "fp8":
                if is_repo_name(url):
                    transformer = FluxTransformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = FluxPipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                else:
                    transformer = FluxTransformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = FluxPipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
                pipe.save_pretrained(new_dir)
            elif not is_repo_name(url):
                convert_flux_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
            # NOTE(review): fp8 with a repo-name url writes nothing here — confirm this is intended.
        elif model_type == "SD 3.5":
            if dtype != "fp8":
                if is_repo_name(url):
                    transformer = SD3Transformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = StableDiffusion3Pipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                else:
                    transformer = SD3Transformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
                pipe.save_pretrained(new_dir)
            elif not is_repo_name(url):
                convert_sd35_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
            # NOTE(review): fp8 with a repo-name url writes nothing here — confirm this is intended.
        else:  # unknown model type
            if is_repo_name(url):
                pipe = DiffusionPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else:
                pipe = DiffusionPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            pipe.save_pretrained(new_dir)
    except Exception as e:
        print(f"Failed to load pipeline. {e}")
        raise Exception(f"Failed to load pipeline. {e}") from e
    # Returning outside `finally`: a `return` in `finally` would discard the exception above.
    return pipe
def convert_url_to_diffusers(url: str, civitai_key: str="", is_upload_sf: bool=False, dtype: str="fp16", vae: str="", clip: str="", t5: str="",
                             scheduler: str="Euler a", ema: bool=True, image_size: str="768", safety_checker: bool=False,
                             base_repo: str="", mtype: str="", lora_dict: "dict | None"=None, is_local: bool=True, progress=gr.Progress(track_tqdm=True)):
    """Convert a checkpoint (existing HF repo name or downloadable URL) into a local diffusers folder.

    An existing repo name is used as-is with model type *mtype*; anything else
    is downloaded and must have a .safetensors/.ckpt/.bin/.sft suffix, and its
    model type is detected with ``get_model_type_from_key``. Optional VAE / CLIP /
    T5 overrides are loaded (repo name or downloaded file) and passed to
    ``load_and_save_pipeline`` as pipeline component kwargs.

    Args:
        url: repo name or download URL of the checkpoint.
        civitai_key: key forwarded to ``get_download_file``.
        is_upload_sf: when not *is_local*, move the downloaded file into the output folder
            instead of deleting it.
        dtype: output dtype name; non-default values set ``torch_dtype`` via ``get_process_dtype``.
        vae, clip, t5: optional component overrides (repo name or URL); "" to skip.
        scheduler, ema, image_size, safety_checker, base_repo: forwarded to ``load_and_save_pipeline``.
        mtype: model type used when *url* is a repo name.
        lora_dict: LoRA sources and scales; None means no LoRAs.
        is_local: when False, the downloaded checkpoint is moved or deleted afterwards.
        progress: Gradio progress tracker.

    Returns:
        Name of the output directory (relative to the CWD), derived from the checkpoint's file stem.

    Raises:
        Exception: "Failed to convert. ..." wrapping any error.
    """
    pipe = None  # bound before `try` so the `finally` cleanup can always `del` it
    if lora_dict is None:
        lora_dict = {}
    try:
        hf_token = get_token()
        progress(0, desc="Start converting...")
        temp_dir = TEMP_DIR
        if is_repo_name(url) and is_repo_exists(url):
            new_file = url
            model_type = mtype
        else:
            new_file = get_download_file(temp_dir, url, civitai_key)
            if not new_file or Path(new_file).suffix.lower() not in {".safetensors", ".ckpt", ".bin", ".sft"}:
                safe_clean(new_file)
                raise Exception(f"Safetensors file not found: {url}")
            model_type = get_model_type_from_key(new_file)
        # Output folder name: file stem with spaces, commas and dots replaced by "_".
        new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")

        kwargs = {}   # pipeline component overrides (vae, tokenizers, text encoders)
        dkwargs = {}  # dtype kwargs shared by every loader
        if dtype != DTYPE_DEFAULT:
            dkwargs["torch_dtype"] = get_process_dtype(dtype, model_type)
        print(f"Model type: {model_type} / VAE: {vae} / CLIP: {clip} / T5: {t5} / Scheduler: {scheduler} / dtype: {dtype} / EMA: {ema} / Base repo: {base_repo} / LoRAs: {lora_dict}")

        my_vae = None
        if vae:
            progress(0, desc=f"Loading VAE: {vae}...")
            if is_repo_name(vae):
                my_vae = AutoencoderKL.from_pretrained(vae, **dkwargs, token=hf_token)
            else:
                new_vae_file = get_download_file(temp_dir, vae, civitai_key)
                my_vae = AutoencoderKL.from_single_file(new_vae_file, **dkwargs) if new_vae_file else None
                safe_clean(new_vae_file)
        if my_vae:
            kwargs["vae"] = my_vae

        my_clip_tokenizer = None
        my_clip_encoder = None
        if clip:
            progress(0, desc=f"Loading CLIP: {clip}...")
            if is_repo_name(clip):
                my_clip_tokenizer = CLIPTokenizer.from_pretrained(clip, token=hf_token)
                if model_type == "SD 3.5":
                    my_clip_encoder = CLIPTextModelWithProjection.from_pretrained(clip, **dkwargs, token=hf_token)
                else:
                    my_clip_encoder = CLIPTextModel.from_pretrained(clip, **dkwargs, token=hf_token)
            else:
                new_clip_file = get_download_file(temp_dir, clip, civitai_key)
                if model_type == "SD 3.5":
                    my_clip_encoder = CLIPTextModelWithProjection.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
                else:
                    my_clip_encoder = CLIPTextModel.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
                safe_clean(new_clip_file)
        if model_type == "SD 3.5":
            # SD 3.5 has two CLIP slots; the same override fills both.
            if my_clip_tokenizer:
                kwargs["tokenizer"] = my_clip_tokenizer
                kwargs["tokenizer_2"] = my_clip_tokenizer
            if my_clip_encoder:
                kwargs["text_encoder"] = my_clip_encoder
                kwargs["text_encoder_2"] = my_clip_encoder
        else:
            if my_clip_tokenizer: kwargs["tokenizer"] = my_clip_tokenizer
            if my_clip_encoder: kwargs["text_encoder"] = my_clip_encoder

        my_t5_tokenizer = None
        my_t5_encoder = None
        if t5:
            progress(0, desc=f"Loading T5: {t5}...")
            if is_repo_name(t5):
                my_t5_tokenizer = AutoTokenizer.from_pretrained(t5, token=hf_token)
                my_t5_encoder = T5EncoderModel.from_pretrained(t5, **dkwargs, token=hf_token)
            else:
                new_t5_file = get_download_file(temp_dir, t5, civitai_key)
                my_t5_encoder = T5EncoderModel.from_single_file(new_t5_file, **dkwargs) if new_t5_file else None
                safe_clean(new_t5_file)
        # T5 occupies slot 3 in SD 3.5 and slot 2 elsewhere (e.g. FLUX).
        if model_type == "SD 3.5":
            if my_t5_tokenizer: kwargs["tokenizer_3"] = my_t5_tokenizer
            if my_t5_encoder: kwargs["text_encoder_3"] = my_t5_encoder
        else:
            if my_t5_tokenizer: kwargs["tokenizer_2"] = my_t5_tokenizer
            if my_t5_encoder: kwargs["text_encoder_2"] = my_t5_encoder

        pipe = load_and_save_pipeline(pipe, model_type, url, new_file, new_dir, dtype, scheduler, ema, image_size, safety_checker, base_repo, civitai_key, lora_dict,
                                      my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder, kwargs, dkwargs)
        if Path(new_dir).exists():
            save_readme_md(new_dir, url)
        # Only a downloaded file is moved/deleted; a repo name is not a local path.
        if not is_local and not is_repo_name(new_file):
            if is_upload_sf:
                shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
            else:
                safe_clean(new_file)
        progress(1, desc="Converted.")
        return new_dir
    except Exception as e:
        print(f"Failed to convert. {e}")
        raise Exception(f"Failed to convert. {e}") from e
    finally:
        del pipe
        torch.cuda.empty_cache()
        gc.collect()
def convert_url_to_diffusers_repo(dl_url: str, hf_user: str, hf_repo: str, hf_token: str, civitai_key="", is_private: bool=True,
                                  gated: str="False", is_overwrite: bool=False, is_pr: bool=False,
                                  is_upload_sf: bool=False, urls: "list | None"=None, dtype: str="fp16", vae: str="", clip: str="", t5: str="", scheduler: str="Euler a",
                                  ema: bool=True, image_size: str="768", safety_checker: bool=False,
                                  base_repo: str="", mtype: str="", lora1: str="", lora1s=1.0, lora2: str="", lora2s=1.0, lora3: str="", lora3s=1.0,
                                  lora4: str="", lora4s=1.0, lora5: str="", lora5s=1.0, args: str="", progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: convert *dl_url* to diffusers, upload it to the Hub, and list the new repo.

    Empty *civitai_key* / *hf_token* fall back to the CIVITAI_API_KEY / HF_TOKEN
    environment variables. The repo id is ``<hf_user>/<hf_repo>`` or, when
    *hf_repo* is "", ``<hf_user>/<output folder name>``. The local output folder
    is deleted after a successful upload.

    Args:
        urls: previously created repo URLs; the new URL is appended (None/empty starts a new list).
        gated: passed to ``gate_repo``; any value other than "False" requires a public repo.
        lora1..lora5, lora1s..lora5s: LoRA sources and their scales.
        args: unused.
        (other arguments are forwarded to ``convert_url_to_diffusers`` / ``upload_repo``.)

    Returns:
        A pair of ``gr.update`` values: the repo-URL list (value and choices) and a
        Markdown summary. Both are no-op updates if conversion produced no folder.

    Raises:
        gr.Error: on invalid input or any failure.
    """
    try:
        is_local = False
        if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
        if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
        if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
        if gated != "False" and is_private: raise gr.Error("Gated repo must be public")
        set_token(hf_token)
        lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
        new_path = convert_url_to_diffusers(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, t5, scheduler, ema, image_size, safety_checker, base_repo, mtype, lora_dict, is_local)
        # Keep the two-output shape of the normal return so Gradio gets a value per output.
        if not new_path: return gr.update(), gr.update()
        new_repo_id = f"{hf_user}/{Path(new_path).stem}"
        if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
        if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
        if not is_overwrite and is_repo_exists(new_repo_id) and not is_pr: raise gr.Error(f"Repo already exists: {new_repo_id}")
        repo_url = upload_repo(new_repo_id, new_path, is_private, is_pr)
        gate_repo(new_repo_id, gated)
        safe_clean(new_path)
        if not urls: urls = []
        urls.append(repo_url)
        md = "### Your new repo:\n"
        for u in urls:
            md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
        return gr.update(value=urls, choices=urls), gr.update(value=md)
    except Exception as e:
        print(f"Error occured. {e}")
        raise gr.Error(f"Error occured. {e}") from e
if __name__ == "__main__":
    # Command-line entry point: convert one checkpoint into a local diffusers folder.
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, required=True, help="URL of the model to convert.")
    parser.add_argument("--dtype", default="fp16", type=str, choices=get_dtypes(), help='Output data type. (Default: "fp16")')
    parser.add_argument("--scheduler", default="Euler a", type=str, choices=list(SCHEDULER_CONFIG_MAP.keys()), required=False, help="Scheduler name to use.")
    parser.add_argument("--vae", default="", type=str, required=False, help="URL or Repo ID of the VAE to use.")
    parser.add_argument("--clip", default="", type=str, required=False, help="URL or Repo ID of the CLIP to use.")
    parser.add_argument("--t5", default="", type=str, required=False, help="URL or Repo ID of the T5 to use.")
    parser.add_argument("--base", default="", type=str, required=False, help="Repo ID of the base repo.")
    parser.add_argument("--nonema", action="store_true", default=False, help="Don't extract EMA (for SD 1.5).")
    parser.add_argument("--civitai_key", default="", type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
    parser.add_argument("--lora1", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora1s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora1.")
    parser.add_argument("--lora2", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora2s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora2.")
    parser.add_argument("--lora3", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora3s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora3.")
    parser.add_argument("--lora4", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora4s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora4.")
    parser.add_argument("--lora5", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora5s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora5.")
    parser.add_argument("--loras", default="", type=str, required=False, help="Folder of the LoRA to use.")
    args = parser.parse_args()  # --url is required=True, so argparse already rejects a missing URL
    is_local = True
    lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
    # Every *.safetensors file found recursively under --loras is added with scale 1.0.
    if args.loras and Path(args.loras).exists():
        for p in Path(args.loras).glob('**/*.safetensors'):
            lora_dict[str(p)] = 1.0
    ema = not args.nonema
    mtype = "SDXL"  # model type used when --url is a repo name; downloaded files are auto-detected
    # Pass by keyword so each value lands on the intended parameter
    # (positionally, --dtype would have been taken as is_upload_sf, and so on).
    convert_url_to_diffusers(args.url, civitai_key=args.civitai_key, dtype=args.dtype, vae=args.vae, clip=args.clip, t5=args.t5,
                             scheduler=args.scheduler, ema=ema, base_repo=args.base, mtype=mtype, lora_dict=lora_dict, is_local=is_local)