"""Gradio app: download SDXL checkpoints / LoRAs from CivitAI and generate images."""

import os
import re
from urllib.parse import quote

import gradio as gr
import requests
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models import AutoencoderKL
from tqdm import tqdm

# Display names offered in the UI dropdowns ("None" = generate without a LoRA).
models_list = []
loras_list = ["None"]
# Display name -> local file written by download_civitai_model().
model_files = {}
lora_files = {}
# Loaded pipelines keyed by model name. At most one entry is kept so switching
# models never leaves several SDXL pipelines resident on the GPU.
models = {}
# (model, lora) combination currently held in `models`, or None.
_loaded_config = None

# Seconds to wait for the server to connect / send the next chunk of data.
REQUEST_TIMEOUT = 60


def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    """Stream `url` into `filename`, reporting progress through tqdm.

    Raises:
        requests.HTTPError: the server answered with an error status.
        IOError: fewer bytes arrived than the server announced.
    On any failure the partially written file is removed, so a truncated
    checkpoint can never be picked up later.
    """
    block_size = 1024 * 1024  # 1 MiB chunks; checkpoints are several GB
    with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as response:
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        # With a Content-Encoding the header counts encoded bytes while
        # iter_content yields decoded ones, so the two are not comparable.
        check_size = (
            total_size_in_bytes != 0 and "content-encoding" not in response.headers
        )
        try:
            with open(filename, "wb") as file, tqdm(
                total=total_size_in_bytes, unit="iB", unit_scale=True
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                received = progress_bar.n
            if check_size and received != total_size_in_bytes:
                raise IOError(
                    f"Incomplete download of {filename}: received {received} "
                    f"of {total_size_in_bytes} bytes."
                )
        except BaseException:
            if os.path.exists(filename):
                os.remove(filename)
            raise


def get_civitai_model_info(model_id):
    """Return the CivitAI API JSON for `model_id`, or None on a non-200 reply."""
    # Escape the user-typed ID so it cannot alter the URL path/query.
    model_id = quote(str(model_id).strip(), safe="")
    url = f"https://civitai.com/api/v1/models/{model_id}"
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        return None
    return response.json()


def find_download_url(data, file_extension):
    """Return the download URL of the first file ending in `file_extension`.

    Only the first entry of `data["modelVersions"]` is searched. Returns None
    when there are no versions or no matching file.
    """
    versions = data.get("modelVersions") or [{}]
    for file in versions[0].get("files", []):
        if file.get("name", "").endswith(file_extension):
            return file.get("downloadUrl")
    return None


def _safe_filename(name):
    """Turn a CivitAI display name into a filename safe for the working dir."""
    cleaned = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", name).strip(" .")
    return cleaned or "model"


def download_civitai_model(model_id, lora_id=""):
    """Download a checkpoint and, optionally, a LoRA from CivitAI.

    `.safetensors` files are preferred over `.ckpt`: `.ckpt` files are pickles
    and can execute arbitrary code when loaded.

    Returns a human-readable status message; errors are reported in the
    message rather than raised.
    """
    try:
        model_data = get_civitai_model_info(model_id)
        if model_data is None:
            return f"Error: Model with ID {model_id} not found."
        model_name = model_data["name"]
        model_safetensors_url = find_download_url(model_data, ".safetensors")
        model_ckpt_url = find_download_url(model_data, ".ckpt")
        model_url = model_safetensors_url or model_ckpt_url
        if not model_url:
            return f"Error: No suitable file found for model {model_name}."
        file_extension = ".safetensors" if model_safetensors_url else ".ckpt"
        model_path = _safe_filename(model_name) + file_extension
        download_file(model_url, model_path)
        # Register the checkpoint right away so it stays usable even if the
        # optional LoRA step below fails.
        model_files[model_name] = model_path
        if model_name not in models_list:
            models_list.append(model_name)

        if lora_id and str(lora_id).strip():
            lora_data = get_civitai_model_info(lora_id)
            if lora_data is None:
                return f"Error: LoRA with ID {lora_id} not found."
            lora_name = lora_data["name"]
            lora_safetensors_url = find_download_url(lora_data, ".safetensors")
            if not lora_safetensors_url:
                return f"Error: No suitable file found for LoRA {lora_name}."
            lora_path = _safe_filename(lora_name) + ".safetensors"
            download_file(lora_safetensors_url, lora_path)
            lora_files[lora_name] = lora_path
            if lora_name not in loras_list:
                loras_list.append(lora_name)
        return "Model/LoRA Downloaded!"
    except Exception as e:
        return f"Error downloading model or LoRA: {e}"


def load_model(model, lora="", use_lora=False):
    """Load `model` (optionally with `lora`) onto CUDA and cache it in `models`.

    `model` is a name from `models_list`, resolved to its downloaded single-file
    checkpoint. Unknown names fall back to `from_pretrained` (a Hub repo id or
    a local diffusers directory). `lora` is a name from `loras_list` and is
    resolved the same way. The same (model, lora) combination is not reloaded
    while it is still cached.

    Returns a status message; errors are reported in the message, not raised.
    """
    global _loaded_config
    config = (model, lora if use_lora else "")
    if _loaded_config == config and model in models:
        return "Model/LoRA loaded successfully!"
    try:
        print(f"\n\nLoading {model}...")
        # Drop the previous pipeline first so two SDXL pipelines never
        # coexist in GPU memory.
        models.clear()
        _loaded_config = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        model_path = model_files.get(model)
        if model_path is not None:
            pipeline = StableDiffusionXLPipeline.from_single_file(
                model_path,
                vae=vae,
                torch_dtype=torch.float16,
            )
        else:
            pipeline = StableDiffusionXLPipeline.from_pretrained(
                model,
                vae=vae,
                torch_dtype=torch.float16,
            )

        if use_lora and lora != "":
            lora_path = lora_files.get(lora)
            if lora_path is not None:
                pipeline.load_lora_weights(
                    os.path.dirname(lora_path) or ".",
                    weight_name=os.path.basename(lora_path),
                )
            else:
                pipeline.load_lora_weights(lora)

        pipeline.to("cuda")
        models[model] = pipeline
        _loaded_config = config
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        return f"Error loading model {model}: {e}"


@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate `num_images` images for `prompt` and return them as a list.

    An empty prompt shows a warning toast and returns no images. A model that
    fails to load raises `gr.Error` carrying the loader's message, so the user
    sees why nothing was produced.
    """
    if prompt is None or prompt.strip() == "":
        # gr.Warning must be *called* (it shows a toast); the Gallery output
        # still needs a list.
        gr.Warning("Prompt empty!")
        return []

    use_lora = bool(lora_name) and lora_name != "None" and lora_name in loras_list
    status = load_model(model_name, lora_name if use_lora else "", use_lora)
    pipe = models.get(model_name)
    if pipe is None:
        raise gr.Error(status)

    outputs = []
    # One image per call keeps peak VRAM bounded at large resolutions.
    for _ in range(int(num_images)):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            height=int(height),
            width=int(width),
        ).images[0]
        outputs.append(output)
    return outputs


def _download_and_refresh(model_id, lora_id):
    """Run a CivitAI download, then push the updated choices to both dropdowns."""
    message = download_civitai_model(model_id, lora_id)
    return (
        message,
        gr.update(choices=list(models_list)),
        gr.update(choices=list(loras_list)),
    )


# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(
                        choices=models_list,
                        value=models_list[0] if models_list else None,
                        label="Model",
                        elem_id="model_dropdown",
                    )
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
                        elem_id="negative_prompt_textbox",
                    )
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")
            generate_btn.click(
                generate_images,
                inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images],
                outputs=output_gallery,
            )
        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id = gr.Textbox(label="CivitAI Model ID")
                lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")
                download_output = gr.Textbox(label="Download Output")
            # Also refresh the Generate-tab dropdowns: their choices are copied
            # when the UI is built, so list appends alone never reach the browser.
            download_button.click(
                _download_and_refresh,
                inputs=[model_id, lora_id],
                outputs=[download_output, model_dropdown, lora_dropdown],
            )

if __name__ == "__main__":
    demo.launch()