import os

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models import AutoencoderKL
import gradio as gr
import spaces

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the models once at startup, outside of the generate_images function.
# MODELS and LORAS are comma-separated lists read from the environment.
model_list = [model.strip() for model in os.environ.get("MODELS", "").split(",") if model.strip()]
lora_list = [model.strip() for model in os.environ.get("LORAS", "").split(",") if model.strip()]
# NOTE: lora_list is only counted for logging; no LoRA weights are loaded below.
print(f"Detected {len(model_list)} models and {len(lora_list)} LoRAs.")

models = {}
for model_name in model_list:
    try:
        print(f"\n\nLoading {model_name}...")
        # fp16-safe VAE to avoid NaN/black-image artifacts when running SDXL in float16.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        # A .safetensors path loads via from_single_file; anything else is
        # treated as a Hub repo id. (The original indexed models[model_name]
        # before it existed, which would raise a KeyError.)
        loader = (
            StableDiffusionXLPipeline.from_single_file
            if model_name.endswith(".safetensors")
            else StableDiffusionXLPipeline.from_pretrained
        )
        models[model_name] = loader(
            model_name,
            vae=vae,
            torch_dtype=torch.float16,
            custom_pipeline="lpw_stable_diffusion_xl",
            use_safetensors=True,
            add_watermarker=False,
        )
        models[model_name].to(device)
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")


@spaces.GPU
def generate_images(
    model_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True),
):
    if prompt is None or prompt.strip() == "":
        gr.Warning("Prompt empty!")
        return []
    pipe = models.get(model_name)
    if pipe is None:
        return []
    print(f"Prompt is: [ {prompt} ]")
    # Generate the requested number of images sequentially and collect them
    # for the output gallery.
    outputs = []
    for _ in range(num_images):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )["images"][0]
        outputs.append(output)
    return outputs


# Create the Gradio blocks
with gr.Blocks(theme="ParityError/Interstellar") as demo:
    with gr.Row(equal_height=False):
        with gr.Column(elem_id="input_column"):
            with gr.Group(elem_id="input_group"):
                model_dropdown = gr.Dropdown(
                    choices=list(models.keys()),
                    value=model_list[0] if model_list else None,
                    label="Model",
                    elem_id="model_dropdown",
                )
                prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                generate_btn = gr.Button("Generate Image", elem_id="generate_button")
            with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
                    elem_id="negative_prompt_textbox",
                )
                num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
        with gr.Column(elem_id="output_column"):
            output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")

    generate_btn.click(
        generate_images,
        inputs=[model_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images],
        outputs=output_gallery,
    )

demo.launch()
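
# Usage sketch: the script expects MODELS (and optionally LORAS) in the
# environment before launch. The values and the filename "app.py" below are
# illustrative assumptions, not defaults shipped with this script. Each MODELS
# entry may be a Hugging Face Hub repo id or a path to a local .safetensors
# checkpoint, comma-separated:
#
#   MODELS="stabilityai/stable-diffusion-xl-base-1.0, ./checkpoints/custom.safetensors" \
#   LORAS="" \
#   python app.py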