patrickvonplaten commited on
Commit
d44229e
·
1 Parent(s): c803873
Files changed (2) hide show
  1. app.py +325 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ StableDiffusionPipeline,
3
+ StableDiffusionImg2ImgPipeline,
4
+ DPMSolverMultistepScheduler,
5
+ )
6
+ import gradio as gr
7
+ import torch
8
+ from PIL import Image
9
+ import time
10
+ import psutil
11
+ import random
12
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
13
+
14
+
15
+ start_time = time.time()
16
+ current_steps = 25
17
+
18
+ SAFETY_CHECKER = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16)
19
+
20
+
21
+ class Model:
22
+ def __init__(self, name, path=""):
23
+ self.name = name
24
+ self.path = path
25
+
26
+ if path != "":
27
+ self.pipe_t2i = StableDiffusionPipeline.from_pretrained(
28
+ path, torch_dtype=torch.float16, safety_checker=SAFETY_CHECKER
29
+ )
30
+ self.pipe_t2i.scheduler = DPMSolverMultistepScheduler.from_config(
31
+ self.pipe_t2i.scheduler.config
32
+ )
33
+ self.pipe_i2i = StableDiffusionImg2ImgPipeline(**self.pipe_t2i.components)
34
+ else:
35
+ self.pipe_t2i = None
36
+ self.pipe_i2i = None
37
+
38
+
39
+ models = [
40
+ Model("Protogen v2.2 (Anime)", "darkstorm2150/Protogen_v2.2_Official_Release"),
41
+ Model("Protogen x3.4 (Photorealism)", "darkstorm2150/Protogen_x3.4_Official_Release"),
42
+ Model("Protogen x5.3 (Photorealism)", "darkstorm2150/Protogen_x5.3_Official_Release"),
43
+ Model("Protogen x5.8 Rebuilt (Scifi+Anime)", "darkstorm2150/Protogen_x5.8_Official_Release"),
44
+ Model("Protogen Dragon (RPG Model)", "darkstorm2150/Protogen_Dragon_Official_Release"),
45
+ Model("Protogen Nova", "darkstorm2150/Protogen_Nova_Official_Release"),
46
+ Model("Protogen Eclipse", "darkstorm2150/Protogen_Eclipse_Official_Release"),
47
+ Model("Protogen Infinity", "darkstorm2150/Protogen_Infinity_Official_Release"),
48
+ ]
49
+
50
+ MODELS = {m.name: m for m in models}
51
+
52
+ device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
53
+
54
+
55
+ def error_str(error, title="Error"):
56
+ return (
57
+ f"""#### {title}
58
+ {error}"""
59
+ if error
60
+ else ""
61
+ )
62
+
63
+
64
+ def inference(
65
+ model_name,
66
+ prompt,
67
+ guidance,
68
+ steps,
69
+ n_images=1,
70
+ width=512,
71
+ height=512,
72
+ seed=0,
73
+ img=None,
74
+ strength=0.5,
75
+ neg_prompt="",
76
+ ):
77
+
78
+ print(psutil.virtual_memory()) # print memory usage
79
+
80
+ if seed == 0:
81
+ seed = random.randint(0, 2147483647)
82
+
83
+ generator = torch.Generator("cuda").manual_seed(seed)
84
+
85
+ try:
86
+ if img is not None:
87
+ return (
88
+ img_to_img(
89
+ model_name,
90
+ prompt,
91
+ n_images,
92
+ neg_prompt,
93
+ img,
94
+ strength,
95
+ guidance,
96
+ steps,
97
+ width,
98
+ height,
99
+ generator,
100
+ seed,
101
+ ),
102
+ f"Done. Seed: {seed}",
103
+ )
104
+ else:
105
+ return (
106
+ txt_to_img(
107
+ model_name,
108
+ prompt,
109
+ n_images,
110
+ neg_prompt,
111
+ guidance,
112
+ steps,
113
+ width,
114
+ height,
115
+ generator,
116
+ seed,
117
+ ),
118
+ f"Done. Seed: {seed}",
119
+ )
120
+ except Exception as e:
121
+ return None, error_str(e)
122
+
123
+
124
+ def txt_to_img(
125
+ model_name,
126
+ prompt,
127
+ n_images,
128
+ neg_prompt,
129
+ guidance,
130
+ steps,
131
+ width,
132
+ height,
133
+ generator,
134
+ seed,
135
+ ):
136
+ pipe = MODELS[model_name].pipe_t2i
137
+
138
+ if torch.cuda.is_available():
139
+ pipe = pipe.to("cuda")
140
+ pipe.enable_xformers_memory_efficient_attention()
141
+
142
+ result = pipe(
143
+ prompt,
144
+ negative_prompt=neg_prompt,
145
+ num_images_per_prompt=n_images,
146
+ num_inference_steps=int(steps),
147
+ guidance_scale=guidance,
148
+ width=width,
149
+ height=height,
150
+ generator=generator,
151
+ )
152
+
153
+ pipe.to("cpu")
154
+
155
+ return replace_nsfw_images(result)
156
+
157
+
158
+ def img_to_img(
159
+ model_name,
160
+ prompt,
161
+ n_images,
162
+ neg_prompt,
163
+ img,
164
+ strength,
165
+ guidance,
166
+ steps,
167
+ width,
168
+ height,
169
+ generator,
170
+ seed,
171
+ ):
172
+ pipe = MODELS[model_name].pipe_i2i
173
+
174
+ if torch.cuda.is_available():
175
+ pipe = pipe.to("cuda")
176
+ pipe.enable_xformers_memory_efficient_attention()
177
+
178
+ ratio = min(height / img.height, width / img.width)
179
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
180
+
181
+ result = pipe(
182
+ prompt,
183
+ negative_prompt=neg_prompt,
184
+ num_images_per_prompt=n_images,
185
+ image=img,
186
+ num_inference_steps=int(steps),
187
+ strength=strength,
188
+ guidance_scale=guidance,
189
+ generator=generator,
190
+ )
191
+
192
+ pipe.to("cpu")
193
+
194
+ return replace_nsfw_images(result)
195
+
196
+
197
+ def replace_nsfw_images(results):
198
+ for i in range(len(results.images)):
199
+ if results.nsfw_content_detected[i]:
200
+ results.images[i] = Image.open("nsfw.png")
201
+ return results.images
202
+
203
+
204
+ with gr.Blocks(css="style.css") as demo:
205
+ with gr.Row():
206
+
207
+ with gr.Column(scale=55):
208
+ with gr.Group():
209
+ prompt = gr.Textbox(
210
+ label="Repo id on Hub",
211
+ placeholder="Path to model, e.g. CompVis/stable-diffusion-v1-4",
212
+ )
213
+ with gr.Box(visible=False) as custom_model_group:
214
+ custom_model_path = gr.Textbox(
215
+ label="Custom model path",
216
+ placeholder="Path to model, e.g. darkstorm2150/Protogen_x3.4_Official_Release",
217
+ interactive=True,
218
+ )
219
+ gr.HTML(
220
+ "<div><font size='2'>Custom models have to be downloaded first, so give it some time.</font></div>"
221
+ )
222
+
223
+ with gr.Row():
224
+ prompt = gr.Textbox(
225
+ label="Prompt",
226
+ show_label=False,
227
+ max_lines=2,
228
+ placeholder="Enter prompt.",
229
+ ).style(container=False)
230
+ generate = gr.Button(value="Generate").style(
231
+ rounded=(False, True, True, False)
232
+ )
233
+
234
+ # image_out = gr.Image(height=512)
235
+ gallery = gr.Gallery(
236
+ label="Generated images", show_label=False, elem_id="gallery"
237
+ ).style(grid=[2], height="auto")
238
+
239
+ state_info = gr.Textbox(label="State", show_label=False, max_lines=2).style(
240
+ container=False
241
+ )
242
+ error_output = gr.Markdown()
243
+
244
+ with gr.Column(scale=45):
245
+ with gr.Tab("Options"):
246
+ with gr.Group():
247
+ neg_prompt = gr.Textbox(
248
+ label="Negative prompt",
249
+ placeholder="What to exclude from the image",
250
+ )
251
+
252
+ n_images = gr.Slider(
253
+ label="Images", value=1, minimum=1, maximum=4, step=1
254
+ )
255
+
256
+ with gr.Row():
257
+ guidance = gr.Slider(
258
+ label="Guidance scale", value=7.5, maximum=15
259
+ )
260
+ steps = gr.Slider(
261
+ label="Steps",
262
+ value=current_steps,
263
+ minimum=2,
264
+ maximum=75,
265
+ step=1,
266
+ )
267
+
268
+ with gr.Row():
269
+ width = gr.Slider(
270
+ label="Width", value=512, minimum=64, maximum=1024, step=8
271
+ )
272
+ height = gr.Slider(
273
+ label="Height", value=512, minimum=64, maximum=1024, step=8
274
+ )
275
+
276
+ seed = gr.Slider(
277
+ 0, 2147483647, label="Seed (0 = random)", value=0, step=1
278
+ )
279
+
280
+ with gr.Tab("Image to image"):
281
+ with gr.Group():
282
+ image = gr.Image(
283
+ label="Image", height=256, tool="editor", type="pil"
284
+ )
285
+ strength = gr.Slider(
286
+ label="Transformation strength",
287
+ minimum=0,
288
+ maximum=1,
289
+ step=0.01,
290
+ value=0.5,
291
+ )
292
+
293
+ inputs = [
294
+ model_name,
295
+ prompt,
296
+ guidance,
297
+ steps,
298
+ n_images,
299
+ width,
300
+ height,
301
+ seed,
302
+ image,
303
+ strength,
304
+ neg_prompt,
305
+ ]
306
+ outputs = [gallery, error_output]
307
+ prompt.submit(inference, inputs=inputs, outputs=outputs)
308
+ generate.click(inference, inputs=inputs, outputs=outputs)
309
+
310
+ gr.HTML(
311
+ """
312
+ <div style="border-top: 1px solid #303030;">
313
+ <br>
314
+ <p>Models by <a href="https://huggingface.co/darkstorm2150">@darkstorm2150</a> and others. ❤️</p>
315
+ <p>This space uses the <a href="https://github.com/LuChengTHU/dpm-solver">DPM-Solver++</a> sampler by <a href="https://arxiv.org/abs/2206.00927">Cheng Lu, et al.</a>.</p>
316
+ <p>Space by: Darkstorm (Victor Espinoza)<br>
317
+ <a href="https://www.instagram.com/officialvictorespinoza/">Instagram</a>
318
+ </div>
319
+ """
320
+ )
321
+
322
+ print(f"Space built in {time.time() - start_time:.2f} seconds")
323
+
324
+ demo.queue(concurrency_count=1)
325
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ git+https://github.com/huggingface/diffusers.git
3
+ git+https://github.com/huggingface/transformers
4
+ scipy
5
+ ftfy
6
+ psutil
7
+ accelerate==0.12.0