Oranblock commited on
Commit
e8a3c76
1 Parent(s): 0e09841

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -61
app.py CHANGED
@@ -8,33 +8,20 @@ import re
8
  import gradio as gr
9
  import numpy as np
10
  from PIL import Image
 
11
  import torch
12
  from diffusers import DiffusionPipeline
13
  from typing import Tuple
14
- import logging
15
-
16
- # Set up logging
17
- logging.basicConfig(level=logging.INFO)
18
- logger = logging.getLogger(__name__)
19
-
20
- # Check for GPU availability and fall back to CPU if necessary
21
- if torch.cuda.is_available():
22
- device = torch.device("cuda")
23
- logger.info("GPU detected. Using CUDA.")
24
- else:
25
- device = torch.device("cpu")
26
- logger.warning("No GPU detected. Falling back to CPU. This will be slower.")
27
 
28
  # Setup rules for bad words (ensure the prompts are kid-friendly)
29
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
30
  default_negative = os.getenv("default_negative","")
31
 
32
  def check_text(prompt, negative=""):
33
- restricted_words = []
34
- for word in bad_words:
35
- if word in prompt.lower() or word in negative.lower():
36
- restricted_words.append(word)
37
- return restricted_words
38
 
39
  # Kid-friendly styles
40
  style_list = [
@@ -76,33 +63,23 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
76
  DESCRIPTION = """## Children's Sticker Generator
77
 
78
  Generate fun and playful stickers for children using AI.
79
-
80
  """
81
- DESCRIPTION += "🚀 Running on GPU for faster generation." if device.type == "cuda" else "⚠️ Running on CPU. This may be slower."
 
 
82
 
83
  MAX_SEED = np.iinfo(np.int32).max
84
- CACHE_EXAMPLES = True
 
 
85
 
86
  # Initialize the DiffusionPipeline
87
- try:
88
- if device.type == "cuda":
89
- pipe = DiffusionPipeline.from_pretrained(
90
- "stabilityai/stable-diffusion-xl-base-1.0",
91
- torch_dtype=torch.float16,
92
- use_safetensors=True,
93
- variant="fp16",
94
- ).to(device)
95
- pipe.enable_xformers_memory_efficient_attention()
96
- else:
97
- pipe = DiffusionPipeline.from_pretrained(
98
- "runwayml/stable-diffusion-v1-5",
99
- torch_dtype=torch.float32,
100
- use_safetensors=True,
101
- ).to(device)
102
- logger.info("DiffusionPipeline initialized successfully")
103
- except Exception as e:
104
- logger.error(f"Error initializing DiffusionPipeline: {str(e)}")
105
- raise
106
 
107
  # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
108
  def mm_to_pixels(mm, dpi=300):
@@ -143,6 +120,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
143
  seed = random.randint(0, MAX_SEED)
144
  return seed
145
 
 
146
  def generate(
147
  prompt: str,
148
  negative_prompt: str = "",
@@ -150,27 +128,27 @@ def generate(
150
  style: str = DEFAULT_STYLE_NAME,
151
  seed: int = 0,
152
  size: str = "75mm",
153
- guidance_scale: float = 7.5,
154
  randomize_seed: bool = False,
155
  background: str = "transparent",
156
  progress=gr.Progress(track_tqdm=True),
157
  ):
158
- restricted_words = check_text(prompt, negative_prompt)
159
- if restricted_words:
160
- return [], seed, f"Prompt contains restricted words: {', '.join(restricted_words)}"
161
-
162
  # Ensure prompt is 2-3 words long
163
- prompt = " ".join(prompt.split()[:3])
164
 
165
  # Apply style
166
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
167
  seed = int(randomize_seed_fn(seed, randomize_seed))
168
- generator = torch.Generator(device=device).manual_seed(seed)
169
 
 
170
  width, height = size_map.get(size, (1024, 1024))
171
 
172
  if not use_negative_prompt:
173
- negative_prompt = ""
174
 
175
  options = {
176
  "prompt": prompt,
@@ -178,9 +156,9 @@ def generate(
178
  "width": width,
179
  "height": height,
180
  "guidance_scale": guidance_scale,
181
- "num_inference_steps": 30 if device.type == "cuda" else 20,
182
  "generator": generator,
183
- "num_images_per_prompt": 4 if device.type == "cuda" else 1,
184
  "output_type": "pil",
185
  }
186
 
@@ -188,7 +166,7 @@ def generate(
188
  images = pipe(**options).images
189
  image_paths = [save_image(img, background) for img in images]
190
 
191
- return image_paths, seed, None
192
 
193
  examples = [
194
  "cute bunny",
@@ -220,7 +198,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
220
  )
221
  run_button = gr.Button("Run")
222
  result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
223
- error_output = gr.Textbox(label="Error", visible=False)
224
  with gr.Accordion("Advanced options", open=False):
225
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
226
  negative_prompt = gr.Text(
@@ -256,16 +233,16 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
256
  )
257
  guidance_scale = gr.Slider(
258
  label="Guidance Scale",
259
- minimum=1.0,
260
  maximum=20.0,
261
  step=0.1,
262
- value=7.5,
263
  )
264
 
265
  gr.Examples(
266
  examples=examples,
267
  inputs=prompt,
268
- outputs=[result, seed, error_output],
269
  fn=generate,
270
  cache_examples=CACHE_EXAMPLES,
271
  )
@@ -288,13 +265,9 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
288
  randomize_seed,
289
  background_selection,
290
  ],
291
- outputs=[result, seed, error_output],
292
  api_name="run",
293
  )
294
 
295
  if __name__ == "__main__":
296
- try:
297
- demo.queue(max_size=20).launch()
298
- except Exception as e:
299
- logger.error(f"Error launching Gradio interface: {str(e)}")
300
- raise
 
8
  import gradio as gr
9
  import numpy as np
10
  from PIL import Image
11
+ import spaces
12
  import torch
13
  from diffusers import DiffusionPipeline
14
  from typing import Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Setup rules for bad words (ensure the prompts are kid-friendly)
17
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
18
  default_negative = os.getenv("default_negative","")
19
 
20
  def check_text(prompt, negative=""):
21
+ for i in bad_words:
22
+ if i in prompt:
23
+ return True
24
+ return False
 
25
 
26
  # Kid-friendly styles
27
  style_list = [
 
63
  DESCRIPTION = """## Children's Sticker Generator
64
 
65
  Generate fun and playful stickers for children using AI.
 
66
  """
67
+
68
+ if not torch.cuda.is_available():
69
+ DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
70
 
71
  MAX_SEED = np.iinfo(np.int32).max
72
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
73
+
74
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
75
 
76
  # Initialize the DiffusionPipeline
77
+ pipe = DiffusionPipeline.from_pretrained(
78
+ "SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
79
+ torch_dtype=torch.float16,
80
+ use_safetensors=True,
81
+ variant="fp16"
82
+ ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
85
  def mm_to_pixels(mm, dpi=300):
 
120
  seed = random.randint(0, MAX_SEED)
121
  return seed
122
 
123
+ @spaces.GPU(enable_queue=True)
124
  def generate(
125
  prompt: str,
126
  negative_prompt: str = "",
 
128
  style: str = DEFAULT_STYLE_NAME,
129
  seed: int = 0,
130
  size: str = "75mm",
131
+ guidance_scale: float = 3,
132
  randomize_seed: bool = False,
133
  background: str = "transparent",
134
  progress=gr.Progress(track_tqdm=True),
135
  ):
136
+ if check_text(prompt, negative_prompt):
137
+ raise ValueError("Prompt contains restricted words.")
138
+
 
139
  # Ensure prompt is 2-3 words long
140
+ prompt = " ".join(re.findall(r'\w+', prompt)[:3])
141
 
142
  # Apply style
143
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
144
  seed = int(randomize_seed_fn(seed, randomize_seed))
145
+ generator = torch.Generator().manual_seed(seed)
146
 
147
+ # Ensure we have only white or transparent background options
148
  width, height = size_map.get(size, (1024, 1024))
149
 
150
  if not use_negative_prompt:
151
+ negative_prompt = "" # type: ignore
152
 
153
  options = {
154
  "prompt": prompt,
 
156
  "width": width,
157
  "height": height,
158
  "guidance_scale": guidance_scale,
159
+ "num_inference_steps": 25,
160
  "generator": generator,
161
+ "num_images_per_prompt": 6, # Max 6 images
162
  "output_type": "pil",
163
  }
164
 
 
166
  images = pipe(**options).images
167
  image_paths = [save_image(img, background) for img in images]
168
 
169
+ return image_paths, seed
170
 
171
  examples = [
172
  "cute bunny",
 
198
  )
199
  run_button = gr.Button("Run")
200
  result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
 
201
  with gr.Accordion("Advanced options", open=False):
202
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
203
  negative_prompt = gr.Text(
 
233
  )
234
  guidance_scale = gr.Slider(
235
  label="Guidance Scale",
236
+ minimum=0.1,
237
  maximum=20.0,
238
  step=0.1,
239
+ value=6,
240
  )
241
 
242
  gr.Examples(
243
  examples=examples,
244
  inputs=prompt,
245
+ outputs=[result, seed],
246
  fn=generate,
247
  cache_examples=CACHE_EXAMPLES,
248
  )
 
265
  randomize_seed,
266
  background_selection,
267
  ],
268
+ outputs=[result, seed],
269
  api_name="run",
270
  )
271
 
272
  if __name__ == "__main__":
273
+ demo.queue(max_size=20).launch()