Oranblock commited on
Commit
7bfe827
1 Parent(s): 128e9ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -16
app.py CHANGED
@@ -21,11 +21,50 @@ pipe = None
21
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
22
  default_negative = os.getenv("default_negative","")
23
 
 
 
 
 
 
24
  def check_text(prompt, negative=""):
25
- for i in bad_words:
26
- if i in prompt:
27
- return True
28
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Kid-friendly styles
31
  style_list = [
@@ -125,6 +164,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
125
  return seed
126
 
127
  @spaces.GPU(enable_queue=True)
 
128
  def generate(
129
  prompt: str,
130
  negative_prompt: str = "",
@@ -135,7 +175,6 @@ def generate(
135
  guidance_scale: float = 3,
136
  randomize_seed: bool = False,
137
  background: str = "transparent",
138
- device_type: str = "cpu",
139
  progress=gr.Progress(track_tqdm=True),
140
  ):
141
  global device, pipe
@@ -177,7 +216,8 @@ def generate(
177
  images = pipe(**options).images
178
  image_paths = [save_image(img, background) for img in images]
179
 
180
- return image_paths, seed
 
181
 
182
  examples = [
183
  "cute bunny",
@@ -189,15 +229,9 @@ css = '''
189
  .gradio-container{max-width: 700px !important}
190
  h1{text-align:center}
191
  '''
192
-
193
- # Define the Gradio UI for the sticker generator
194
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
195
  gr.Markdown(DESCRIPTION)
196
- gr.DuplicateButton(
197
- value="Duplicate Space for private use",
198
- elem_id="duplicate-button",
199
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
200
- )
201
  with gr.Group():
202
  with gr.Row():
203
  prompt = gr.Text(
@@ -209,6 +243,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
209
  )
210
  run_button = gr.Button("Run")
211
  result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
 
212
  with gr.Accordion("Advanced options", open=False):
213
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
214
  negative_prompt = gr.Text(
@@ -258,7 +293,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
258
  gr.Examples(
259
  examples=examples,
260
  inputs=prompt,
261
- outputs=[result, seed],
262
  fn=generate,
263
  cache_examples=CACHE_EXAMPLES,
264
  )
@@ -280,9 +315,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
280
  guidance_scale,
281
  randomize_seed,
282
  background_selection,
283
- device_selection,
284
  ],
285
- outputs=[result, seed],
286
  api_name="run",
287
  )
288
 
 
21
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
22
  default_negative = os.getenv("default_negative","")
23
 
24
+ # Remove the GPU-specific import and decorator
25
+ # import spaces
26
+ # @spaces.GPU(enable_queue=True)
27
+
28
+ # Update the check_text function to be more informative
29
  def check_text(prompt, negative=""):
30
+ restricted_words = []
31
+ for word in bad_words:
32
+ if word in prompt.lower() or word in negative.lower():
33
+ restricted_words.append(word)
34
+ return restricted_words
35
+
36
+
37
+ restricted_words = check_text(prompt, negative_prompt)
38
+ if restricted_words:
39
+ return [], seed, f"Prompt contains restricted words: {', '.join(restricted_words)}"
40
+
41
+ # Ensure prompt is 2-3 words long
42
+ prompt = " ".join(prompt.split()[:3])
43
+
44
+ # Apply style
45
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
46
+ seed = int(randomize_seed_fn(seed, randomize_seed))
47
+ generator = torch.manual_seed(seed)
48
+
49
+ width, height = size_map.get(size, (512, 512))
50
+
51
+ if not use_negative_prompt:
52
+ negative_prompt = ""
53
+
54
+ options = {
55
+ "prompt": prompt,
56
+ "negative_prompt": negative_prompt,
57
+ "width": width,
58
+ "height": height,
59
+ "guidance_scale": guidance_scale,
60
+ "num_inference_steps": 20, # Reduced from 25
61
+ "generator": generator,
62
+ "num_images_per_prompt": 2, # Reduced from 6
63
+ "output_type": "pil",
64
+ }
65
+
66
+
67
+
68
 
69
  # Kid-friendly styles
70
  style_list = [
 
164
  return seed
165
 
166
  @spaces.GPU(enable_queue=True)
167
+ # Update the generate function
168
  def generate(
169
  prompt: str,
170
  negative_prompt: str = "",
 
175
  guidance_scale: float = 3,
176
  randomize_seed: bool = False,
177
  background: str = "transparent",
 
178
  progress=gr.Progress(track_tqdm=True),
179
  ):
180
  global device, pipe
 
216
  images = pipe(**options).images
217
  image_paths = [save_image(img, background) for img in images]
218
 
219
+ return image_paths, seed, None # Added None for potential error message
220
+
221
 
222
  examples = [
223
  "cute bunny",
 
229
  .gradio-container{max-width: 700px !important}
230
  h1{text-align:center}
231
  '''
232
+ # Update the Gradio interface
 
233
  with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
234
  gr.Markdown(DESCRIPTION)
 
 
 
 
 
235
  with gr.Group():
236
  with gr.Row():
237
  prompt = gr.Text(
 
243
  )
244
  run_button = gr.Button("Run")
245
  result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
246
+ error_output = gr.Textbox(label="Error", visible=False)
247
  with gr.Accordion("Advanced options", open=False):
248
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
249
  negative_prompt = gr.Text(
 
293
  gr.Examples(
294
  examples=examples,
295
  inputs=prompt,
296
+ outputs=[result, seed, error_output],
297
  fn=generate,
298
  cache_examples=CACHE_EXAMPLES,
299
  )
 
315
  guidance_scale,
316
  randomize_seed,
317
  background_selection,
 
318
  ],
319
+ outputs=[result, seed, error_output],
320
  api_name="run",
321
  )
322