File size: 7,088 Bytes
0925cf1
3871549
a58c598
e14fae9
c63d488
1aa42ec
5bde01d
0925cf1
1aa42ec
e14fae9
1aa42ec
fd34bb6
1aa42ec
 
e14fae9
 
1aa42ec
 
 
 
 
 
 
 
 
 
c3967c0
 
 
 
 
 
 
 
 
 
 
 
 
e14fae9
3871549
c3967c0
 
 
 
 
 
 
 
3871549
c3967c0
 
 
 
 
3871549
 
c3967c0
 
 
 
0adcbb8
c3967c0
 
 
3871549
c3967c0
3871549
 
 
 
 
 
 
 
1aa42ec
e14fae9
c92c1fc
1aa42ec
05c550f
 
 
 
 
e14fae9
1aa42ec
05c550f
 
 
 
3871549
e14fae9
 
 
 
 
c92c1fc
e14fae9
0925cf1
5bde01d
c6747cf
e66a721
1aa42ec
e66a721
 
 
 
c6747cf
 
874cb7c
65dc494
c6747cf
3871549
1aa42ec
 
 
 
e14fae9
8f724dc
 
 
c8f91a3
8f724dc
db07984
dfe65d8
 
 
 
 
 
 
 
 
 
e66a721
c28f29b
 
3871549
edf126d
92ec9db
563066a
a031477
1aa42ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e14fae9
1aa42ec
e14fae9
1aa42ec
 
3871549
 
1aa42ec
e14fae9
1aa42ec
e14fae9
3871549
82d2444
761d42b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionXLPipeline
import gradio as gr
import requests
import spaces

# Base names of checkpoints fetched through the "Download Custom Model" tab.
models_list = list()
# Selectable LoRA names; the literal string "None" means "generate without a LoRA".
loras_list = ["None"]
# Loaded pipelines keyed by model name (filled in by load_model).
models = dict()

def download_file(url, filename, progress=gr.Progress(track_tqdm=True)):
    """Stream ``url`` into ``filename``, showing a tqdm progress bar.

    ``progress`` is never used directly: because it is created with
    ``track_tqdm=True``, Gradio mirrors the tqdm bar below into the UI.

    Raises:
        requests.HTTPError: the server answered with an error status.
        OSError: fewer bytes arrived than the advertised Content-Length.

    On any failure, a file this call started writing is removed so that no
    truncated or bogus weights file is left behind.
    """
    import os

    block_size = 1024  # 1 Kibibyte
    created = False
    try:
        # Using the response as a context manager releases the connection
        # even when we bail out early.
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this check an error page (e.g. a 401/404 body) would be
            # written to disk as if it were the requested file.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            # iter_content() decodes compressed bodies, so the byte count only
            # matches Content-Length when no Content-Encoding was applied.
            size_check = total_size_in_bytes != 0 and 'content-encoding' not in response.headers
            with open(filename, 'wb') as file:
                created = True
                with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                    received = progress_bar.n
        if size_check and received != total_size_in_bytes:
            raise OSError(
                f"Incomplete download of {filename}: "
                f"received {received} of {total_size_in_bytes} bytes"
            )
    except Exception:
        if created and os.path.exists(filename):
            os.remove(filename)
        raise

def get_civitai_model_info(model_id):
    """Fetch the CivitAI metadata JSON for a single model.

    Args:
        model_id: The CivitAI model ID as typed by the user (str or int).

    Returns:
        The decoded JSON on HTTP 200; ``None`` for a blank ID or any other
        status code.
    """
    from urllib.parse import quote

    model_id = "" if model_id is None else str(model_id).strip()
    if not model_id:
        # An empty ID would hit the model *list* endpoint and return the
        # wrong document entirely.
        return None
    # The ID comes from a free-text box; percent-encode it so characters such
    # as "/" or "?" cannot change which API path is requested.
    url = f"https://civitai.com/api/v1/models/{quote(model_id, safe='')}"
    # Without a timeout a stalled connection would block the handler forever.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        return None
    return response.json()

def find_download_url(data, file_extension):
    """Return the download URL of the first matching file in the first model version.

    Args:
        data: CivitAI model metadata as returned by ``get_civitai_model_info``.
        file_extension: Suffix the file name must end with, e.g. ``'.safetensors'``.

    Returns:
        The ``downloadUrl`` of the first file in ``data['modelVersions'][0]['files']``
        whose ``name`` ends with ``file_extension``, or ``None`` if there is no
        such file (including when versions or files are missing or empty).
    """
    # `or` also covers an explicit empty list / null, which would otherwise
    # raise IndexError / TypeError on the [0] lookup.
    versions = data.get('modelVersions') or [{}]
    first_version = versions[0] or {}
    for file in first_version.get('files') or []:
        name = file.get('name') or ''
        url = file.get('downloadUrl')
        if name.endswith(file_extension) and url:
            return url
    return None

def _safe_filename(name):
    """Turn a CivitAI display name into a base name that is safe to use as a file name."""
    import re

    # Names come from the remote API and may contain "/" or other characters
    # that would break (or escape) the target path.
    cleaned = re.sub(r'[^\w\-. ]', '_', str(name)).strip(' .')
    return cleaned or 'model'


def download_civitai_model(model_id, lora_id=""):
    """Download a CivitAI checkpoint and, optionally, a LoRA into the working directory.

    The checkpoint is saved as ``<name>.ckpt`` when the first model version
    offers one, otherwise as ``<name>.safetensors``; ``<name>`` is then added
    to ``models_list``. A LoRA is saved as ``<name>.safetensors`` and added to
    ``loras_list``.

    Args:
        model_id: CivitAI model ID of the checkpoint.
        lora_id: Optional CivitAI model ID of a LoRA; blank means none.

    Returns:
        A human-readable status string. Errors are reported in the string,
        not raised.
    """
    model_id = str(model_id or "").strip()
    lora_id = str(lora_id or "").strip()
    if not model_id:
        return "Error: Please enter a CivitAI model ID."
    try:
        model_data = get_civitai_model_info(model_id)
        if model_data is None:
            return f"Error: Model with ID {model_id} not found."

        model_name = _safe_filename(model_data['name'])
        model_ckpt_url = find_download_url(model_data, '.ckpt')
        model_safetensors_url = find_download_url(model_data, '.safetensors')
        model_url = model_ckpt_url or model_safetensors_url

        if not model_url:
            return f"Error: No suitable file found for model {model_name}."

        file_extension = '.ckpt' if model_ckpt_url else '.safetensors'
        download_file(model_url, f"{model_name}{file_extension}")
        # Register the checkpoint immediately so a later LoRA failure does not
        # hide a model that is already on disk.
        if model_name not in models_list:
            models_list.append(model_name)

        if lora_id:
            lora_data = get_civitai_model_info(lora_id)
            if lora_data is None:
                return f"Error: LoRA with ID {lora_id} not found."

            lora_name = _safe_filename(lora_data['name'])
            lora_safetensors_url = find_download_url(lora_data, '.safetensors')
            if not lora_safetensors_url:
                return f"Error: No suitable file found for LoRA {lora_name}."

            download_file(lora_safetensors_url, f"{lora_name}.safetensors")
            if lora_name not in loras_list:
                loras_list.append(lora_name)

        return "Model/LoRA Downloaded!"
    except Exception as e:
        return f"Error downloading model or LoRA: {e}"

def load_model(model, lora="", use_lora=False):
    """Build an SDXL fp16 pipeline for ``model`` (optionally with a LoRA) and cache it in ``models``.

    Args:
        model: A Hugging Face repo id / diffusers directory, a path to a
            single-file checkpoint, or the base name of a checkpoint saved by
            ``download_civitai_model`` (``<model>.safetensors`` or
            ``<model>.ckpt`` in the working directory).
        lora: Base name of a downloaded ``<lora>.safetensors`` file, or
            anything ``load_lora_weights`` accepts.
        use_lora: Whether to apply ``lora``.

    Returns:
        A status string. On failure it starts with ``"Error"`` and no pipeline
        is cached under ``model``.
    """
    import os

    try:
        # Drop any previously cached pipeline first. This frees its memory
        # before a new one is built, and ensures a failed load cannot leave a
        # stale pipeline (possibly with a different LoRA) for callers to use.
        models.pop(model, None)

        print(f"\n\nLoading {model}...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )

        single_file = next(
            (path for path in (model, f"{model}.safetensors", f"{model}.ckpt")
             if os.path.isfile(path)),
            None,
        )
        if single_file is not None:
            # Downloaded CivitAI checkpoints are single weight files, which
            # from_pretrained (repo/directory loader) cannot read.
            pipeline = StableDiffusionXLPipeline.from_single_file(
                single_file,
                vae=vae,
                torch_dtype=torch.float16,
            )
        else:
            pipeline = StableDiffusionXLPipeline.from_pretrained(
                model,
                vae=vae,
                torch_dtype=torch.float16,
            )

        if use_lora and lora:
            lora_file = f"{lora}.safetensors"
            if os.path.isfile(lora_file):
                pipeline.load_lora_weights(
                    os.path.dirname(os.path.abspath(lora_file)),
                    weight_name=os.path.basename(lora_file),
                )
            else:
                pipeline.load_lora_weights(lora)

        pipeline.to("cuda")
        models[model] = pipeline
        return "Model/LoRA loaded successfully!"
    except Exception as e:
        return f"Error loading model {model}: {e}"

@spaces.GPU
def generate_images(
    model_name,
    lora_name,
    prompt,
    negative_prompt,
    num_inference_steps,
    guidance_scale,
    height,
    width,
    num_images=4,
    progress=gr.Progress(track_tqdm=True)
):
    """Load the selected model (and LoRA) and generate ``num_images`` images for ``prompt``.

    ``lora_name`` equal to ``"None"``, empty, or not present in ``loras_list``
    means the model is loaded without a LoRA. Slider values are coerced to
    ``int`` where the pipeline needs whole numbers.

    Returns:
        A list of generated images for the Gallery; ``[]`` if the prompt is
        empty, no model is selected, or no pipeline is available.

    Raises:
        gr.Error: if loading the model/LoRA fails, so the message reaches the UI.
    """
    if prompt is None or prompt.strip() == "":
        # gr.Warning displays a toast; it is not a value the Gallery can show.
        gr.Warning("Prompt empty!")
        return []
    if not model_name:
        gr.Warning("No model selected!")
        return []

    use_lora = bool(lora_name) and lora_name != "None" and lora_name in loras_list
    status = load_model(model_name, lora_name if use_lora else "", use_lora)
    if status.startswith("Error"):
        raise gr.Error(status)

    pipe = models.get(model_name)
    if pipe is None:
        return []

    outputs = []
    for _ in range(int(num_images)):
        result = pipe(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            height=int(height),
            width=int(width),
        )
        outputs.append(result["images"][0])

    return outputs

def _refresh_model_choices():
    """Return dropdown updates so newly downloaded models/LoRAs become selectable."""
    # The dropdowns' choices are fixed when the UI is built; appending to
    # models_list / loras_list later does not reach the browser by itself.
    return gr.update(choices=models_list), gr.update(choices=loras_list)


# Create the Gradio blocks
with gr.Blocks(theme='ParityError/Interstellar') as demo:
    with gr.Row(equal_height=False):
        with gr.Tab("Generate"):
            with gr.Column(elem_id="input_column"):
                with gr.Group(elem_id="input_group"):
                    model_dropdown = gr.Dropdown(choices=models_list, value=models_list[0] if models_list else None, label="Model", elem_id="model_dropdown")
                    lora_dropdown = gr.Dropdown(choices=loras_list, value=loras_list[0], label="LoRA")
                    prompt = gr.Textbox(label="Prompt", elem_id="prompt_textbox")
                    generate_btn = gr.Button("Generate Image", elem_id="generate_button")
                with gr.Accordion("Advanced", open=False, elem_id="advanced_accordion"):
                    negative_prompt = gr.Textbox(label="Negative Prompt", value="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]", elem_id="negative_prompt_textbox")
                    num_inference_steps = gr.Slider(minimum=10, maximum=50, step=1, value=25, label="Number of Inference Steps", elem_id="num_inference_steps_slider")
                    guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=7.5, label="Guidance Scale", elem_id="guidance_scale_slider")
                    height = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Height", elem_id="height_slider")
                    width = gr.Slider(minimum=1024, maximum=2048, step=256, value=1024, label="Width", elem_id="width_slider")
                    num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images", elem_id="num_images_slider")
            with gr.Column(elem_id="output_column"):
                output_gallery = gr.Gallery(label="Generated Images", height=480, scale=1, elem_id="output_gallery")

            generate_btn.click(generate_images, inputs=[model_dropdown, lora_dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, height, width, num_images], outputs=output_gallery)

        with gr.Tab("Download Custom Model"):
            with gr.Group():
                model_id = gr.Textbox(label="CivitAI Model ID")
                lora_id = gr.Textbox(label="CivitAI LoRA ID (Optional)")
                download_button = gr.Button("Download Model")

            download_output = gr.Textbox(label="Download Output")

            # After the download finishes, push the updated lists into both
            # dropdowns so the new entries can actually be selected.
            download_button.click(
                download_civitai_model, inputs=[model_id, lora_id], outputs=download_output
            ).then(_refresh_model_choices, outputs=[model_dropdown, lora_dropdown])

demo.launch()