Oranblock committed on
Commit 7d9826c
1 Parent(s): 9f5148a

Update app.py

Files changed (1)
  1. app.py +73 -120
app.py CHANGED
@@ -4,7 +4,7 @@ import os
  import random
  import uuid
  import json
-
  import gradio as gr
  import numpy as np
  from PIL import Image
@@ -13,25 +13,18 @@ import torch
  from diffusers import DiffusionPipeline
  from typing import Tuple

- # Check for the Model Base..//
-
- bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary"]'))
- bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
  default_negative = os.getenv("default_negative","")

  def check_text(prompt, negative=""):
      for i in bad_words:
          if i in prompt:
              return True
-     for i in bad_words_negative:
-         if i in negative:
-             return True
      return False

- # Updated to child-friendly styles
-
  style_list = [
-
      {
          "name": "Cartoon",
          "prompt": "colorful cartoon {prompt}. vibrant, playful, friendly, suitable for children, highly detailed, bright colors",
@@ -42,19 +35,16 @@ style_list = [
          "prompt": "children's illustration {prompt}. cute, colorful, fun, simple shapes, smooth lines, highly detailed, joyful",
          "negative_prompt": "scary, dark, violent, deformed, ugly",
      },
-
      {
          "name": "Sticker",
          "prompt": "children's sticker of {prompt}. bright colors, playful, high resolution, cartoonish",
          "negative_prompt": "scary, dark, violent, ugly, low resolution",
      },
-
      {
          "name": "Fantasy",
          "prompt": "fantasy world for children with {prompt}. magical, vibrant, friendly, beautiful, colorful",
          "negative_prompt": "dark, scary, violent, ugly, realistic",
      },
-
      {
          "name": "(No style)",
          "prompt": "{prompt}",
@@ -68,8 +58,6 @@ DEFAULT_STYLE_NAME = "Sticker"

  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
      p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-     if not negative:
-         negative = ""
      return p.replace("{prompt}", positive), n + negative

  DESCRIPTION = """## Children's Sticker Generator
@@ -82,43 +70,39 @@ if not torch.cuda.is_available():

  MAX_SEED = np.iinfo(np.int32).max
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- NUM_IMAGES_PER_PROMPT = 1
-
- if torch.cuda.is_available():
-     pipe = DiffusionPipeline.from_pretrained(
-         "SG161222/RealVisXL_V3.0_Turbo",
-         torch_dtype=torch.float16,
-         use_safetensors=True,
-         add_watermarker=False,
-         variant="fp16"
-     )
-     pipe2 = DiffusionPipeline.from_pretrained(
-         "SG161222/RealVisXL_V2.02_Turbo",
-         torch_dtype=torch.float16,
-         use_safetensors=True,
-         add_watermarker=False,
-         variant="fp16"
-     )
-     if ENABLE_CPU_OFFLOAD:
-         pipe.enable_model_cpu_offload()
-         pipe2.enable_model_cpu_offload()
-     else:
-         pipe.to(device)
-         pipe2.to(device)
-         print("Loaded on Device!")
-
-     if USE_TORCH_COMPILE:
-         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-         pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
-         print("Model Compiled!")
-
- def save_image(img):
      unique_name = str(uuid.uuid4()) + ".png"
      img.save(unique_name)
      return unique_name
@@ -135,23 +119,28 @@ def generate(
      use_negative_prompt: bool = False,
      style: str = DEFAULT_STYLE_NAME,
      seed: int = 0,
-     width: int = 512,
-     height: int = 512,
      guidance_scale: float = 3,
      randomize_seed: bool = False,
-     use_resolution_binning: bool = True,
      progress=gr.Progress(track_tqdm=True),
  ):
      if check_text(prompt, negative_prompt):
          raise ValueError("Prompt contains restricted words.")

      prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
      seed = int(randomize_seed_fn(seed, randomize_seed))
      generator = torch.Generator().manual_seed(seed)

      if not use_negative_prompt:
          negative_prompt = ""  # type: ignore
-     negative_prompt += default_negative

      options = {
          "prompt": prompt,
@@ -162,19 +151,19 @@ def generate(
          "num_inference_steps": 25,
          "generator": generator,
          "num_images_per_prompt": NUM_IMAGES_PER_PROMPT,
-         "use_resolution_binning": use_resolution_binning,
          "output_type": "pil",
      }

-     images = pipe(**options).images + pipe2(**options).images
-
-     image_paths = [save_image(img) for img in images]
      return image_paths, seed

  examples = [
-     "A cute cartoon bunny holding a carrot in a colorful garden",
-     "A playful dragon flying through the clouds, bright and friendly",
-     "A magical unicorn standing on a rainbow with sparkles",
  ]

  css = '''
@@ -182,6 +171,7 @@ css = '''
  h1{text-align:center}
  '''

  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
      gr.Markdown(DESCRIPTION)
      gr.DuplicateButton(
@@ -195,36 +185,20 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
          label="Enter your prompt",
          show_label=False,
          max_lines=1,
-         placeholder="Enter a fun, child-friendly idea (e.g., cute bunny with a rainbow)",
          container=False,
      )
      run_button = gr.Button("Run")
-     result = gr.Gallery(label="Generated Stickers", columns=1, preview=True)
      with gr.Accordion("Advanced options", open=False):
          use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
          negative_prompt = gr.Text(
              label="Negative prompt",
              max_lines=1,
              placeholder="Enter a negative prompt",
-             value="(scary, violent, deformed, ugly, dark)",
              visible=True,
          )
-         with gr.Row():
-             num_inference_steps = gr.Slider(
-                 label="Steps",
-                 minimum=10,
-                 maximum=60,
-                 step=1,
-                 value=25,
-             )
-         with gr.Row():
-             num_images_per_prompt = gr.Slider(
-                 label="Images",
-                 minimum=1,
-                 maximum=5,
-                 step=1,
-                 value=2,
-             )
          seed = gr.Slider(
              label="Seed",
              minimum=0,
@@ -234,38 +208,24 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
              visible=True
          )
          randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-         with gr.Row(visible=True):
-             width = gr.Slider(
-                 label="Width",
-                 minimum=512,
-                 maximum=1024,
-                 step=8,
-                 value=512,
-             )
-             height = gr.Slider(
-                 label="Height",
-                 minimum=512,
-                 maximum=1024,
-                 step=8,
-                 value=512,
-             )
-         with gr.Row():
-             guidance_scale = gr.Slider(
-                 label="Guidance Scale",
-                 minimum=0.1,
-                 maximum=20.0,
-                 step=0.1,
-                 value=7,
-             )
-         with gr.Row(visible=True):
-             style_selection = gr.Radio(
-                 show_label=True,
-                 container=True,
-                 interactive=True,
-                 choices=STYLE_NAMES,
-                 value=DEFAULT_STYLE_NAME,
-                 label="Sticker Style",
-             )

      gr.Examples(
          examples=examples,
          inputs=prompt,
@@ -274,13 +234,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
          cache_examples=CACHE_EXAMPLES,
      )

-     use_negative_prompt.change(
-         fn=lambda x: gr.update(visible=x),
-         inputs=use_negative_prompt,
-         outputs=negative_prompt,
-         api_name=False,
-     )
-
      gr.on(
          triggers=[
              prompt.submit,
@@ -294,10 +247,10 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
              use_negative_prompt,
              style_selection,
              seed,
-             width,
-             height,
              guidance_scale,
              randomize_seed,
          ],
          outputs=[result, seed],
          api_name="run",

  import random
  import uuid
  import json
+ import re
  import gradio as gr
  import numpy as np
  from PIL import Image
 
  from diffusers import DiffusionPipeline
  from typing import Tuple

+ # Setup rules for bad words (ensure the prompts are kid-friendly)
+ bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
  default_negative = os.getenv("default_negative","")

  def check_text(prompt, negative=""):
      for i in bad_words:
          if i in prompt:
              return True
      return False
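The filter above is a plain, case-sensitive substring check, so an entry such as "scary" is caught even inside longer words, while "Scary" with a capital letter slips through. A stricter, case-insensitive whole-word variant could use the `re` module this commit already imports; this is an illustrative sketch, not part of the commit:

```python
import re

def contains_bad_word(text: str, words) -> bool:
    # Whole-word, case-insensitive match, e.g. "A Scary ghost!" -> True
    return any(re.search(rf"\b{re.escape(w)}\b", text, re.IGNORECASE) for w in words)
```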

+ # Kid-friendly styles
  style_list = [
      {
          "name": "Cartoon",
          "prompt": "colorful cartoon {prompt}. vibrant, playful, friendly, suitable for children, highly detailed, bright colors",
 
          "prompt": "children's illustration {prompt}. cute, colorful, fun, simple shapes, smooth lines, highly detailed, joyful",
          "negative_prompt": "scary, dark, violent, deformed, ugly",
      },
      {
          "name": "Sticker",
          "prompt": "children's sticker of {prompt}. bright colors, playful, high resolution, cartoonish",
          "negative_prompt": "scary, dark, violent, ugly, low resolution",
      },
      {
          "name": "Fantasy",
          "prompt": "fantasy world for children with {prompt}. magical, vibrant, friendly, beautiful, colorful",
          "negative_prompt": "dark, scary, violent, ugly, realistic",
      },
      {
          "name": "(No style)",
          "prompt": "{prompt}",
 
  def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
      p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
      return p.replace("{prompt}", positive), n + negative
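Note that `apply_style` splices the caller's negative prompt directly onto the style's own negative prompt (`n + negative`), so with the Sticker style and a user negative of "blurry" the result ends in "...low resolutionblurry". If a separator were wanted, a comma-joined variant over the same `styles` dict might look like this (hypothetical, not what the commit does):

```python
def apply_style_joined(style_name: str, positive: str, negative: str = ""):
    # Same lookup as apply_style, but join the two negative prompts with ", "
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    combined = ", ".join(part for part in (n, negative) if part)
    return p.replace("{prompt}", positive), combined
```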

  DESCRIPTION = """## Children's Sticker Generator
 
  MAX_SEED = np.iinfo(np.int32).max
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+ NUM_IMAGES_PER_PROMPT = 6  # Set maximum images per prompt
+
+ # Convert mm to pixels for a specific DPI (300)
+ def mm_to_pixels(mm, dpi=300):
+     return int((mm / 25.4) * dpi)
+
+ # Default sizes for 75mm and 35mm
+ size_map = {
+     "75mm": (mm_to_pixels(75), mm_to_pixels(75)),  # 75mm in pixels at 300dpi
+     "35mm": (mm_to_pixels(35), mm_to_pixels(35)),  # 35mm in pixels at 300dpi
+ }
+
+ # Function to post-process images (transparent or white background)
+ def save_image(img, background="transparent"):
+     img = img.convert("RGBA")
+     data = img.getdata()
+     new_data = []
+
+     if background == "transparent":
+         for item in data:
+             # Replace white with transparent
+             if item[0] == 255 and item[1] == 255 and item[2] == 255:
+                 new_data.append((255, 255, 255, 0))  # Transparent
+             else:
+                 new_data.append(item)
+     elif background == "white":
+         for item in data:
+             new_data.append(item)  # Keep as white
+
+     img.putdata(new_data)
      unique_name = str(uuid.uuid4()) + ".png"
      img.save(unique_name)
      return unique_name
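As a quick check of the conversion above: int((75 / 25.4) * 300) = 885 and int((35 / 25.4) * 300) = 413, so "75mm" resolves to 885×885 px and "35mm" to 413×413 px at 300 DPI. Also note that the transparency pass only clears pixels that are exactly (255, 255, 255); anti-aliased, near-white edges keep a faint fringe. A tolerance-based variant would look roughly like the sketch below (an illustration, not part of this commit):

```python
from PIL import Image

def white_to_transparent(img: Image.Image, tol: int = 10) -> Image.Image:
    # Treat any pixel within `tol` of pure white as background and make it transparent
    img = img.convert("RGBA")
    img.putdata([
        (r, g, b, 0) if r >= 255 - tol and g >= 255 - tol and b >= 255 - tol else (r, g, b, a)
        for (r, g, b, a) in img.getdata()
    ])
    return img
```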
 
      use_negative_prompt: bool = False,
      style: str = DEFAULT_STYLE_NAME,
      seed: int = 0,
+     size: str = "75mm",
      guidance_scale: float = 3,
      randomize_seed: bool = False,
+     background: str = "transparent",
      progress=gr.Progress(track_tqdm=True),
  ):
      if check_text(prompt, negative_prompt):
          raise ValueError("Prompt contains restricted words.")

+     # Ensure prompt is 2-3 words long
+     prompt = " ".join(re.findall(r'\w+', prompt)[:3])
+
+     # Apply style
      prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
      seed = int(randomize_seed_fn(seed, randomize_seed))
      generator = torch.Generator().manual_seed(seed)

+     # Ensure we have only white or transparent background options
+     width, height = size_map.get(size, (1024, 1024))
+
      if not use_negative_prompt:
          negative_prompt = ""  # type: ignore

      options = {
          "prompt": prompt,
 
          "num_inference_steps": 25,
          "generator": generator,
          "num_images_per_prompt": NUM_IMAGES_PER_PROMPT,
          "output_type": "pil",
      }

+     # Generate images with the pipeline
+     images = pipe(**options).images
+     image_paths = [save_image(img, background) for img in images]
+
      return image_paths, seed
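As far as this diff shows, the block that constructed `pipe` (and `pipe2`) has been removed while `generate` still calls `pipe(**options)`, so unless the pipeline is created elsewhere in the file this raises a `NameError` at generation time. A minimal loader consistent with the removed code (single pipeline, same checkpoint; a sketch, not part of the commit) would be:

```python
import torch
from diffusers import DiffusionPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same checkpoint as the removed block; fp16 weights are only useful on GPU
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V3.0_Turbo",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
).to(device)
```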

  examples = [
+     "cute bunny",
+     "happy cat",
+     "funny dog",
  ]
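The shorter examples match the new prompt handling in `generate`, which keeps only the first three word tokens:

```python
import re

prompt = "A cute cartoon bunny holding a carrot"
short = " ".join(re.findall(r"\w+", prompt)[:3])
print(short)  # -> "A cute cartoon"
```

Anything past the third word is silently dropped before the style template is applied.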

  css = '''
 
  h1{text-align:center}
  '''

+ # Define the Gradio UI for the sticker generator
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
      gr.Markdown(DESCRIPTION)
      gr.DuplicateButton(
 
          label="Enter your prompt",
          show_label=False,
          max_lines=1,
+         placeholder="Enter 2-3 word prompt (e.g., cute bunny)",
          container=False,
      )
      run_button = gr.Button("Run")
+     result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
      with gr.Accordion("Advanced options", open=False):
          use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
          negative_prompt = gr.Text(
              label="Negative prompt",
              max_lines=1,
              placeholder="Enter a negative prompt",
+             value="(scary, violent, dark, ugly)",
              visible=True,
          )
          seed = gr.Slider(
              label="Seed",
              minimum=0,
 
              visible=True
          )
          randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         size_selection = gr.Radio(
+             choices=["75mm", "35mm"],
+             value="75mm",
+             label="Sticker Size",
+         )
+         background_selection = gr.Radio(
+             choices=["transparent", "white"],
+             value="transparent",
+             label="Background Color",
+         )
+         guidance_scale = gr.Slider(
+             label="Guidance Scale",
+             minimum=0.1,
+             maximum=20.0,
+             step=0.1,
+             value=6,
          )
+
      gr.Examples(
          examples=examples,
          inputs=prompt,
 
          cache_examples=CACHE_EXAMPLES,
      )

      gr.on(
          triggers=[
              prompt.submit,
 
              use_negative_prompt,
              style_selection,
              seed,
+             size_selection,
              guidance_scale,
              randomize_seed,
+             background_selection,
          ],
          outputs=[result, seed],
          api_name="run",
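`style_selection` is still wired into the `gr.on` inputs here, but the `gr.Radio` that defined it is removed earlier in this diff and, as far as the shown hunks go, nothing re-adds it, which would also fail with a `NameError`. If it is indeed gone, a control mirroring the removed one would need to live inside the Blocks/Accordion context (hypothetical re-addition, reusing the app's own `STYLE_NAMES` and `DEFAULT_STYLE_NAME`):

```python
import gradio as gr

# STYLE_NAMES and DEFAULT_STYLE_NAME are defined earlier in app.py
style_selection = gr.Radio(
    choices=STYLE_NAMES,
    value=DEFAULT_STYLE_NAME,
    label="Sticker Style",
)
```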