Oranblock commited on
Commit
20b41d0
1 Parent(s): 7bfe827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -78
app.py CHANGED
@@ -8,24 +8,17 @@ import re
8
  import gradio as gr
9
  import numpy as np
10
  from PIL import Image
11
- import spaces
12
  import torch
13
  from diffusers import DiffusionPipeline
14
  from typing import Tuple
15
 
16
- # Initialize device to None
17
- device = None
18
- pipe = None
19
 
20
  # Setup rules for bad words (ensure the prompts are kid-friendly)
21
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
22
  default_negative = os.getenv("default_negative","")
23
 
24
- # Remove the GPU-specific import and decorator
25
- # import spaces
26
- # @spaces.GPU(enable_queue=True)
27
-
28
- # Update the check_text function to be more informative
29
  def check_text(prompt, negative=""):
30
  restricted_words = []
31
  for word in bad_words:
@@ -33,39 +26,6 @@ def check_text(prompt, negative=""):
33
  restricted_words.append(word)
34
  return restricted_words
35
 
36
-
37
- restricted_words = check_text(prompt, negative_prompt)
38
- if restricted_words:
39
- return [], seed, f"Prompt contains restricted words: {', '.join(restricted_words)}"
40
-
41
- # Ensure prompt is 2-3 words long
42
- prompt = " ".join(prompt.split()[:3])
43
-
44
- # Apply style
45
- prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
46
- seed = int(randomize_seed_fn(seed, randomize_seed))
47
- generator = torch.manual_seed(seed)
48
-
49
- width, height = size_map.get(size, (512, 512))
50
-
51
- if not use_negative_prompt:
52
- negative_prompt = ""
53
-
54
- options = {
55
- "prompt": prompt,
56
- "negative_prompt": negative_prompt,
57
- "width": width,
58
- "height": height,
59
- "guidance_scale": guidance_scale,
60
- "num_inference_steps": 20, # Reduced from 25
61
- "generator": generator,
62
- "num_images_per_prompt": 2, # Reduced from 6
63
- "output_type": "pil",
64
- }
65
-
66
-
67
-
68
-
69
  # Kid-friendly styles
70
  style_list = [
71
  {
@@ -108,21 +68,18 @@ DESCRIPTION = """## Children's Sticker Generator
108
  Generate fun and playful stickers for children using AI.
109
  """
110
 
 
 
 
111
  MAX_SEED = np.iinfo(np.int32).max
112
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
113
 
114
- def initialize_pipeline(device_type):
115
- global device, pipe
116
- device = torch.device(device_type)
117
- pipe = DiffusionPipeline.from_pretrained(
118
- "SG161222/RealVisXL_V3.0_Turbo",
119
- torch_dtype=torch.float32 if device_type == "cpu" else torch.float16,
120
- use_safetensors=True,
121
- ).to(device)
122
-
123
- # Initialize with CPU by default
124
- initialize_pipeline("cpu")
125
-
126
 
127
  # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
128
  def mm_to_pixels(mm, dpi=300):
@@ -163,8 +120,6 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
163
  seed = random.randint(0, MAX_SEED)
164
  return seed
165
 
166
- @spaces.GPU(enable_queue=True)
167
- # Update the generate function
168
  def generate(
169
  prompt: str,
170
  negative_prompt: str = "",
@@ -177,28 +132,22 @@ def generate(
177
  background: str = "transparent",
178
  progress=gr.Progress(track_tqdm=True),
179
  ):
180
- global device, pipe
181
-
182
- # Switch device if necessary
183
- if device.type != device_type:
184
- initialize_pipeline(device_type)
185
 
186
- if check_text(prompt, negative_prompt):
187
- raise ValueError("Prompt contains restricted words.")
188
-
189
  # Ensure prompt is 2-3 words long
190
- prompt = " ".join(re.findall(r'\w+', prompt)[:3])
191
 
192
  # Apply style
193
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
194
  seed = int(randomize_seed_fn(seed, randomize_seed))
195
  generator = torch.Generator(device=device).manual_seed(seed)
196
 
197
- # Ensure we have only white or transparent background options
198
- width, height = size_map.get(size, (1024, 1024))
199
 
200
  if not use_negative_prompt:
201
- negative_prompt = "" # type: ignore
202
 
203
  options = {
204
  "prompt": prompt,
@@ -206,9 +155,9 @@ def generate(
206
  "width": width,
207
  "height": height,
208
  "guidance_scale": guidance_scale,
209
- "num_inference_steps": 25,
210
  "generator": generator,
211
- "num_images_per_prompt": 6, # Max 6 images
212
  "output_type": "pil",
213
  }
214
 
@@ -216,8 +165,7 @@ def generate(
216
  images = pipe(**options).images
217
  image_paths = [save_image(img, background) for img in images]
218
 
219
- return image_paths, seed, None # Added None for potential error message
220
-
221
 
222
  examples = [
223
  "cute bunny",
@@ -229,9 +177,15 @@ css = '''
229
  .gradio-container{max-width: 700px !important}
230
  h1{text-align:center}
231
  '''
232
- # Update the Gradio interface
 
233
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
234
  gr.Markdown(DESCRIPTION)
 
 
 
 
 
235
  with gr.Group():
236
  with gr.Row():
237
  prompt = gr.Text(
@@ -284,11 +238,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
284
  step=0.1,
285
  value=15.7,
286
  )
287
- device_selection = gr.Radio(
288
- choices=["cpu", "cuda"],
289
- value="cpu",
290
- label="Device",
291
- )
292
 
293
  gr.Examples(
294
  examples=examples,
 
8
  import gradio as gr
9
  import numpy as np
10
  from PIL import Image
 
11
  import torch
12
  from diffusers import DiffusionPipeline
13
  from typing import Tuple
14
 
15
+ # Check if GPU is available; fallback to CPU if needed
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
17
 
18
  # Setup rules for bad words (ensure the prompts are kid-friendly)
19
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
20
  default_negative = os.getenv("default_negative","")
21
 
 
 
 
 
 
22
  def check_text(prompt, negative=""):
23
  restricted_words = []
24
  for word in bad_words:
 
26
  restricted_words.append(word)
27
  return restricted_words
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Kid-friendly styles
30
  style_list = [
31
  {
 
68
  Generate fun and playful stickers for children using AI.
69
  """
70
 
71
+ if not torch.cuda.is_available():
72
+ DESCRIPTION += "\n<p>⚠️Running on CPU, This may be slower.</p>"
73
+
74
  MAX_SEED = np.iinfo(np.int32).max
75
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
76
 
77
+ # Initialize the DiffusionPipeline
78
+ pipe = DiffusionPipeline.from_pretrained(
79
+ "SG161222/RealVisXL_V3.0_Turbo",
80
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
81
+ use_safetensors=True,
82
+ ).to(device)
 
 
 
 
 
 
83
 
84
  # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
85
  def mm_to_pixels(mm, dpi=300):
 
120
  seed = random.randint(0, MAX_SEED)
121
  return seed
122
 
 
 
123
  def generate(
124
  prompt: str,
125
  negative_prompt: str = "",
 
132
  background: str = "transparent",
133
  progress=gr.Progress(track_tqdm=True),
134
  ):
135
+ restricted_words = check_text(prompt, negative_prompt)
136
+ if restricted_words:
137
+ return [], seed, f"Prompt contains restricted words: {', '.join(restricted_words)}"
 
 
138
 
 
 
 
139
  # Ensure prompt is 2-3 words long
140
+ prompt = " ".join(prompt.split()[:3])
141
 
142
  # Apply style
143
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
144
  seed = int(randomize_seed_fn(seed, randomize_seed))
145
  generator = torch.Generator(device=device).manual_seed(seed)
146
 
147
+ width, height = size_map.get(size, (512, 512))
 
148
 
149
  if not use_negative_prompt:
150
+ negative_prompt = ""
151
 
152
  options = {
153
  "prompt": prompt,
 
155
  "width": width,
156
  "height": height,
157
  "guidance_scale": guidance_scale,
158
+ "num_inference_steps": 20,
159
  "generator": generator,
160
+ "num_images_per_prompt": 2,
161
  "output_type": "pil",
162
  }
163
 
 
165
  images = pipe(**options).images
166
  image_paths = [save_image(img, background) for img in images]
167
 
168
+ return image_paths, seed, None
 
169
 
170
  examples = [
171
  "cute bunny",
 
177
  .gradio-container{max-width: 700px !important}
178
  h1{text-align:center}
179
  '''
180
+
181
+ # Define the Gradio UI for the sticker generator
182
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
183
  gr.Markdown(DESCRIPTION)
184
+ gr.DuplicateButton(
185
+ value="Duplicate Space for private use",
186
+ elem_id="duplicate-button",
187
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
188
+ )
189
  with gr.Group():
190
  with gr.Row():
191
  prompt = gr.Text(
 
238
  step=0.1,
239
  value=15.7,
240
  )
 
 
 
 
 
241
 
242
  gr.Examples(
243
  examples=examples,