nsfwalex commited on
Commit
6fd5d66
1 Parent(s): adc9c0a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import os
4
+ import random
5
+ import json
6
+ import uuid
7
+ from huggingface_hub import snapshot_download
8
+ from diffusers import AutoencoderKL
9
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler, AutoPipelineForText2Image, DiffusionPipeline
10
+ from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler
11
+ from diffusers.models.attention_processor import AttnProcessor2_0
12
+ import torch
13
+ from typing import Tuple
14
+ from datetime import datetime
15
+ import requests
16
+ import torch
17
+ from diffusers import DiffusionPipeline
18
+ import importlib
19
+
20
+ MAX_SEED = 12211231
21
+ CACHE_EXAMPLES = "1"
22
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4192"))
23
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "1") == "1"
24
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
25
+
26
+ NUM_IMAGES_PER_PROMPT = 1
27
+
28
+
29
+ cfg = json.load(open("app.conf"))
30
+
31
+ def load_pipeline_and_scheduler():
32
+ clip_skip = cfg.get("clip_skip", 0)
33
+
34
+ # Download the model files
35
+ ckpt_dir = snapshot_download(repo_id=cfg["model_id"])
36
+
37
+ # Load the models
38
+ vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)
39
+
40
+ pipe = StableDiffusionXLPipeline.from_pretrained(
41
+ ckpt_dir,
42
+ vae=vae,
43
+ torch_dtype=torch.float16,
44
+ use_safetensors=True,
45
+ variant="fp16"
46
+ )
47
+ pipe = pipe.to("cuda")
48
+
49
+ pipe.unet.set_attn_processor(AttnProcessor2_0())
50
+
51
+ # Define samplers
52
+ samplers = {
53
+ "Euler a": EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config),
54
+ "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
55
+ }
56
+ # Set the scheduler based on the selected sampler
57
+ pipe.scheduler = samplers[cfg.get("sampler","DPM++ SDE Karras")]
58
+
59
+ # Set clip skip
60
+ pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)
61
+
62
+ if USE_TORCH_COMPILE:
63
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
64
+ print("Model Compiled!")
65
+ return pipe
66
+ pipe = load_pipeline_and_scheduler()
67
+ css = '''
68
+ .gradio-container{max-width: 560px !important}
69
+ body {
70
+ background-color: rgb(3, 7, 18);
71
+ color: white;
72
+ }
73
+ .gradio-container {
74
+ background-color: rgb(3, 7, 18) !important;
75
+ border: none !important;
76
+ }
77
+ '''
78
+ js = '''
79
+ <script src="https://raw.githubusercontent.com/insanensfwdev/hf-gradio-text2img-card/main/prompt.js"></script>
80
+ <script>
81
+ window.g=function(){
82
+ const conditions = {
83
+ "tag": ["normal", "sexy", "porn"],
84
+ "exclude_category": ["Clothing"],
85
+ "count_per_tag": 1
86
+ };
87
+ prompt = generateSexyPrompt()
88
+ console.log(prompt);
89
+ return prompt
90
+ }
91
+ window.postMessageToParent = function(prompt, event, source, value) {
92
+ // Construct the message object with the provided parameters
93
+ console.log("post start",event, source, value);
94
+ const message = {
95
+ event: event,
96
+ source: source,
97
+ value: value
98
+ };
99
+
100
+ // Post the message to the parent window
101
+ window.parent.postMessage(message, '*');
102
+ console.log("post finish");
103
+ return prompt;
104
+ }
105
+ function uploadImage(prompt, images, event, source, value) {
106
+ // Ensure we're in an iframe
107
+ console.log("uploadImage", prompt, images && images.length > 0 ? images[0].image.url : null, event, source, value);
108
+ if (window.self !== window.top) {
109
+ // Get the first image from the gallery (assuming it's an array)
110
+ let imageUrl = images && images.length > 0 ? images[0].image.url : null;
111
+
112
+ // Prepare the data to send
113
+ let data = {
114
+ event: event,
115
+ source: source,
116
+ prompt: prompt,
117
+ image: imageUrl
118
+ };
119
+
120
+ // Post the message to the parent window
121
+ window.parent.postMessage(JSON.stringify(data), '*');
122
+ } else {
123
+ console.log("Not in an iframe, can't post to parent");
124
+ }
125
+ }
126
+ </script>
127
+ '''
128
+ def save_image(img):
129
+ unique_name = str(uuid.uuid4()) + ".png"
130
+ img.save(unique_name)
131
+ return unique_name
132
+
133
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
134
+ if randomize_seed:
135
+ seed = random.randint(0, MAX_SEED)
136
+ return seed
137
+
138
+ @spaces.GPU(duration=60)
139
+ def generate(prompt, progress=gr.Progress(track_tqdm=True)):
140
+ negative_prompt = cfg.get("negative_prompt", "")
141
+ style_selection = ""
142
+ use_negative_prompt = True
143
+ seed = 0
144
+ width = cfg.get("width", 1024)
145
+ height = cfg.get("width", 768)
146
+ inference_steps = cfg.get("inference_steps", 30)
147
+ randomize_seed = True
148
+ guidance_scale = cfg.get("guidance_scale", 7.5)
149
+ prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", prompt)
150
+
151
+ seed = int(randomize_seed_fn(seed, randomize_seed))
152
+ generator = torch.Generator(pipe.device).manual_seed(seed)
153
+
154
+ images = pipe(
155
+ prompt=prompt_str,
156
+ negative_prompt=negative_prompt,
157
+ width=width,
158
+ height=height,
159
+ guidance_scale=guidance_scale,
160
+ num_inference_steps=inference_steps,
161
+ generator=generator,
162
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
163
+ output_type="pil",
164
+ ).images
165
+
166
+ image_paths = [save_image(img) for img in images]
167
+ print(image_paths)
168
+ return image_paths
169
+
170
+
171
+ with gr.Blocks(css=css,head=js,fill_height=True) as demo:
172
+ with gr.Row(equal_height=False):
173
+ with gr.Group():
174
+ result = gr.Gallery(value=cfg.get("cover_path",""),
175
+ label="Result", show_label=False, columns=1, rows=1, show_share_button=True,
176
+ show_download_button=True,allow_preview=True,interactive=False, min_width=cfg.get("window_min_width", 340)
177
+ )
178
+ with gr.Row():
179
+ prompt = gr.Text(
180
+ show_label=False,
181
+ max_lines=2,
182
+ lines=2,
183
+ placeholder="Enter what you want to see",
184
+ container=False,
185
+ scale=5,
186
+ min_width=100,
187
+ )
188
+ random_button = gr.Button("Surprise Me", scale=1, min_width=10)
189
+ run_button = gr.Button( "GO!", scale=1, min_width=20)
190
+
191
+ random_button.click(fn=lambda x:x, inputs=[prompt], outputs=[prompt], js='''()=>window.g()''')
192
+ run_button.click(generate, inputs=[prompt], outputs=[result], js=f'''(p)=>window.postMessageToParent(p,"process_started","demo_hf_{cfg.get("name")}_card", "click_go")''')
193
+ result.change(fn=lambda x:x, inputs=[prompt,result], outputs=[], js=f'''(p,img)=>window.uploadImage(p, img,"process_started","demo_hf_{cfg.get("name")}_card", "finish")''')
194
+
195
+ if __name__ == "__main__":
196
+ demo.queue(max_size=200).launch()