File size: 7,896 Bytes
08203ce
0925cf1
3871549
a58c598
809d6c5
c63d488
1aa42ec
5bde01d
0925cf1
08203ce
 
 
 
1aa42ec
e14fae9
1aa42ec
fd34bb6
1aa42ec
 
e14fae9
 
d2dc338
1aa42ec
 
 
 
 
 
 
 
 
09073c1
c3967c0
 
 
 
 
 
 
 
 
 
303d4ef
 
c3967c0
f52f10c
3871549
c3967c0
 
 
 
d2dc338
c3967c0
 
 
3871549
c3967c0
 
 
 
08203ce
d2dc338
3871549
 
c3967c0
 
 
 
d2dc338
c3967c0
 
 
3871549
08203ce
 
09073c1
 
d2dc338
 
3871549
09073c1
 
 
 
d2dc338
09073c1
3871549
 
1aa42ec
e47b9ec
 
 
d2dc338
c92c1fc
d2dc338
 
05c550f
 
 
 
 
b54b151
d2dc338
05c550f
 
 
 
3871549
08203ce
 
e14fae9
 
d2dc338
e14fae9
c92c1fc
d2dc338
0925cf1
5bde01d
c6747cf
e66a721
1aa42ec
e66a721
 
 
 
c6747cf
 
874cb7c
65dc494
c6747cf
3871549
8f724dc
 
 
c8f91a3
8f724dc
db07984
dfe65d8
 
 
 
 
 
 
 
 
 
e66a721
c28f29b
 
3871549
edf126d
92ec9db
563066a
a031477
1aa42ec
 
 
 
 
e47b9ec
1aa42ec
 
 
 
 
 
 
 
 
 
 
e14fae9
e47b9ec
1aa42ec
e14fae9
1aa42ec
 
3871549
 
1aa42ec
e14fae9
1aa42ec
e14fae9
09073c1
82d2444
d2dc338
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import torch
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import requests
import spaces

# Create the on-disk folders that downloaded checkpoints and LoRAs are saved to
for _directory in ('models', 'loras'):
    os.makedirs(_directory, exist_ok=True)

# Shared, mutable application state
models_list = []        # entries offered by the "Model" dropdown
loras_list = ["None"]   # entries offered by the "LoRA" dropdown; "None" means no LoRA
models = {}             # loaded pipelines, looked up when generating images

def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    """Stream the resource at ``url`` into ``filename`` with a tqdm progress bar.

    The payload is written to ``<filename>.part`` first and only moved onto
    ``filename`` once the transfer has finished, so an aborted download never
    leaves a truncated file under the final name.

    Args:
        url: HTTP(S) address to fetch.
        filename: Destination path on disk.
        progress: Gradio progress tracker. It is not used directly here;
            presumably ``track_tqdm=True`` lets Gradio mirror the tqdm bar.

    Raises:
        requests.RequestException: On connection failures, timeouts or an
            HTTP error status.
        OSError: If the destination cannot be written.
    """
    block_size = 1024 * 1024  # 1 MiB chunks keep per-chunk overhead low
    partial_path = f"{filename}.part"
    try:
        # (connect, read) timeouts: never hang forever on a stalled server.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTML error page as a model file.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, tqdm(
                total=total_size_in_bytes, unit='iB', unit_scale=True
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                received = progress_bar.n
        os.replace(partial_path, filename)
    finally:
        # Only still present when the transfer failed part-way through.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    if total_size_in_bytes != 0 and received != total_size_in_bytes:
        print("ERROR, something went wrong")

def get_civitai_model_info(model_id):
    """Fetch the CivitAI metadata record for ``model_id``.

    Args:
        model_id: Identifier of the model (or LoRA) on CivitAI. Surrounding
            whitespace is ignored.

    Returns:
        The decoded JSON response, or ``None`` when the id is empty or the
        API answers with anything other than HTTP 200.

    Raises:
        requests.RequestException: On connection failures or timeouts.
        ValueError: If a 200 response does not contain valid JSON.
    """
    from urllib.parse import quote

    model_ref = "" if model_id is None else str(model_id).strip()
    if not model_ref:
        # An empty id would request the bare ".../models/" path instead of a
        # single model, so treat it as "not found" without a network call.
        return None

    # Escape the user-supplied id so it cannot alter the request path.
    url = f"https://civitai.com/api/v1/models/{quote(model_ref, safe='')}"
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        return None
    return response.json()

def find_download_url(data, file_extension):
    """Return the download URL of the first matching file in the first model version.

    Only ``data['modelVersions'][0]`` is inspected. Every file of that version
    is checked in order; the extension comparison ignores case.

    Args:
        data: Model metadata dict with an optional ``modelVersions`` list whose
            entries carry a ``files`` list of ``{'name', 'downloadUrl'}`` dicts.
        file_extension: Suffix to look for, e.g. ``'.safetensors'``.

    Returns:
        The ``downloadUrl`` of the first matching file, or ``None`` when there
        are no versions, no files, or no file with that extension.
    """
    versions = data.get('modelVersions') or []
    if not versions:
        return None

    wanted_suffix = file_extension.lower()
    for file in versions[0].get('files') or []:
        download_url = file.get('downloadUrl')
        # Keep scanning on a mismatch: the wanted file is not always listed first.
        if download_url and file.get('name', '').lower().endswith(wanted_suffix):
            return download_url
    return None

def _safe_filename(name):
    """Return ``name`` with characters unsuitable for a file name replaced by ``_``."""
    cleaned = "".join(
        ch if ch.isalnum() or ch in " ._-()" else "_" for ch in str(name)
    ).strip(" .")
    return cleaned or "unnamed"


def download_and_load_civitai_model(model_id, lora_id="", progress=gr.Progress(track_tqdm=True)):
    """Download a CivitAI model (and optionally a LoRA), then load it.

    The checkpoint is saved under ``models/`` and the LoRA under ``loras/``,
    both named after the sanitised CivitAI name. After a successful load the
    pipeline is also registered in ``models`` under the model name, because
    that name is what the "Model" dropdown hands to ``generate_images``.

    Args:
        model_id: CivitAI id of the base model.
        lora_id: Optional CivitAI id of a LoRA; empty means no LoRA.
        progress: Gradio progress tracker (not used directly).

    Returns:
        A human-readable status string; failures are reported as
        ``"Error ..."`` messages rather than raised.
    """
    try:
        model_id = str(model_id or "").strip()
        lora_id = str(lora_id or "").strip()

        model_data = get_civitai_model_info(model_id)
        if model_data is None:
            return f"Error: Model with ID {model_id} not found."

        # Remote names may contain path separators; never use them raw on disk.
        model_name = _safe_filename(model_data['name'])
        # NOTE(review): .ckpt is preferred over .safetensors here. .ckpt
        # checkpoints are typically pickle-based and may run code when loaded;
        # consider preferring .safetensors -- confirm before changing.
        model_ckpt_url = find_download_url(model_data, '.ckpt')
        model_safetensors_url = find_download_url(model_data, '.safetensors')
        model_url = model_ckpt_url or model_safetensors_url

        if not model_url:
            return f"Error: No suitable file found for model {model_name}."

        file_extension = '.ckpt' if model_ckpt_url else '.safetensors'
        model_filename = os.path.join('models', f"{model_name}{file_extension}")
        download_file(model_url, model_filename)

        lora_name = "None"
        if lora_id:
            lora_data = get_civitai_model_info(lora_id)
            if lora_data is None:
                return f"Error: LoRA with ID {lora_id} not found."

            # load_model rebuilds the path as loras/<name>.safetensors, so the
            # same sanitised name must be used for the file and the list entry.
            lora_name = _safe_filename(lora_data['name'])
            lora_safetensors_url = find_download_url(lora_data, '.safetensors')
            if not lora_safetensors_url:
                return f"Error: No suitable file found for LoRA {lora_name}."

            lora_filename = os.path.join('loras', f"{lora_name}.safetensors")
            download_file(lora_safetensors_url, lora_filename)
            if lora_name not in loras_list:
                loras_list.append(lora_name)

        # Load model after downloading
        load_result = load_model(model_filename, lora_name, use_lora=(lora_name != "None"))
        if model_filename not in models or str(load_result).startswith("Error"):
            # Do not claim success (or offer the model) when loading failed.
            return f"Model/LoRA downloaded, but loading failed. {load_result}"

        # load_model keys the pipeline by file path; alias it under the name
        # shown in the dropdown so generate_images can find it.
        models[model_name] = models[model_filename]
        if model_name not in models_list:
            models_list.append(model_name)

        return f"Model/LoRA Downloaded and Loaded! {load_result}"
    except Exception as e:
        return f"Error downloading model or LoRA: {e}"

def refresh_dropdowns():
    """Build the Gradio updates that re-sync both dropdowns with the current lists.

    Returns:
        A ``(model_update, lora_update)`` pair whose choices are the current
        contents of ``models_list`` and ``loras_list``.
    """
    model_update = gr.update(choices=models_list)
    lora_update = gr.update(choices=loras_list)
    return model_update, lora_update

def load_model(model, lora="", use_lora=False, progress=gr.Progress(track_tqdm=True)):
    """Load an SDXL checkpoint, optionally apply a LoRA, and cache the pipeline.

    The pipeline is built from the single checkpoint file ``model`` together
    with the ``madebyollin/sdxl-vae-fp16-fix`` VAE, moved to CUDA and stored
    in ``models`` under the key ``model``.

    Args:
        model: Path of the checkpoint file to load.
        lora: Name of a LoRA stored as ``loras/<lora>.safetensors``. Empty,
            ``None`` and the ``"None"`` dropdown sentinel all mean "no LoRA".
        use_lora: Whether to apply ``lora`` at all.
        progress: Gradio progress tracker (not used directly).

    Returns:
        A status string; failures are reported as ``"Error loading model ..."``
        rather than raised.
    """
    try:
        print(f"\n\nLoading {model}...")
        gr.Info(f"Loading {model}, it may take a while.")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )

        pipeline = StableDiffusionXLPipeline.from_single_file(
            model,
            vae=vae,
            torch_dtype=torch.float16,
        )

        # "None" is the placeholder entry of loras_list, not a real LoRA file.
        if use_lora and lora not in (None, "", "None"):
            lora_path = os.path.join('loras', f"{lora}.safetensors")
            pipeline.load_lora_weights(lora_path)

        pipeline.to("cuda")
        models[model] = pipeline
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        return f"Error loading model {model}: {e}"

@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate ``num_images`` images with an already loaded pipeline.

    Args:
        model_name: Key of the pipeline in ``models``.
        lora_name: Selected LoRA. Unused here; LoRA weights are applied when
            the pipeline is loaded.
        prompt: Text prompt; must not be empty or whitespace-only.
        negative_prompt: Negative prompt passed through to the pipeline.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed through to the pipeline.
        height: Output height in pixels.
        width: Output width in pixels.
        num_images: How many images to generate, one pipeline call each.
        progress: Gradio progress tracker (not used directly).

    Returns:
        A list of generated images. The list is empty, and a warning is shown,
        when the prompt is empty or the model is not loaded.
    """
    if prompt is None or prompt.strip() == "":
        # Show the warning as a side effect and still hand the gallery a
        # valid (empty) value instead of the warning call's result.
        gr.Warning("Prompt empty!")
        return []

    pipe = models.get(model_name)
    if pipe is None:
        gr.Warning(f"Model {model_name} is not loaded.")
        return []

    # Slider values may arrive as floats; range() and the pipeline need ints.
    steps = int(num_inference_steps)
    image_height = int(height)
    image_width = int(width)

    outputs = []
    for _ in range(int(num_images)):
        output = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            height=image_height,
            width=image_width
        )["images"][0]
        outputs.append(output)

    return outputs

# Build the Gradio UI: a "Generate" tab for running a loaded model and a
# "Download Custom Model" tab for fetching models/LoRAs from CivitAI.
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            # Left column: model/LoRA selection, prompt and advanced settings.
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    # The lists change after a download; this button re-syncs the dropdowns.
                    refresh_btn = gr.Button("Refresh Dropdowns")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            # Right column: gallery that receives the generated images.
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")

            # Event wiring; the inputs list follows generate_images' parameter order.
            refresh_btn.click(refresh_dropdowns, outputs=[model_dropdown, lora_dropdown])
            generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)

        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id = gr.Textbox(label="CivitAI Model ID")
                lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")

            # Shows the status string returned by download_and_load_civitai_model.
            download_output = gr.Textbox(label="Download Output")

            download_button.click(download_and_load_civitai_model, inputs=[model_id, lora_id], outputs=download_output)

    # NOTE(review): launch() is called inside the Blocks context; Gradio
    # examples usually call it after the `with` block -- confirm this is intended.
    demo.launch()