fantos commited on
Commit
8d2510b
Β·
verified Β·
1 Parent(s): 4ae2f5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -24
app.py CHANGED
@@ -10,9 +10,9 @@ from huggingface_hub import hf_hub_download
10
  import gradio as gr
11
  import torch
12
  from diffusers import FluxPipeline
 
13
  from PIL import Image
14
 
15
-
16
  # Setup and initialization code
17
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
18
  # Use PERSISTENT_DIR environment variable for Spaces
@@ -29,6 +29,27 @@ torch.backends.cuda.matmul.allow_tf32 = True
29
  if not path.exists(gallery_path):
30
  os.makedirs(gallery_path, exist_ok=True)
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class timer:
33
  def __init__(self, method_name="timed process"):
34
  self.method = method_name
@@ -48,6 +69,9 @@ pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8
48
  pipe.fuse_lora(lora_scale=0.125)
49
  pipe.to(device="cuda", dtype=torch.bfloat16)
50
 
 
 
 
51
  css = """
52
  footer {display: none !important}
53
  .gradio-container {
@@ -77,7 +101,6 @@ footer {display: none !important}
77
  -webkit-background-clip: text;
78
  -webkit-text-fill-color: transparent;
79
  }
80
- /* Gallery specific styles */
81
  #gallery {
82
  width: 100% !important;
83
  max-width: 100% !important;
@@ -101,7 +124,6 @@ footer {display: none !important}
101
  width: 100% !important;
102
  box-sizing: border-box !important;
103
  }
104
- /* Force gallery items to maintain aspect ratio */
105
  .gallery-item {
106
  width: 100% !important;
107
  aspect-ratio: 1 !important;
@@ -118,12 +140,10 @@ footer {display: none !important}
118
  .gallery-item img:hover {
119
  transform: scale(1.05);
120
  }
121
- /* Force output image container to full width */
122
  .output-image {
123
  width: 100% !important;
124
  max-width: 100% !important;
125
  }
126
- /* Force container widths */
127
  .contain > div {
128
  width: 100% !important;
129
  max-width: 100% !important;
@@ -132,7 +152,6 @@ footer {display: none !important}
132
  width: 100% !important;
133
  max-width: 100% !important;
134
  }
135
- /* Remove any horizontal scrolling */
136
  .gallery-container::-webkit-scrollbar {
137
  display: none !important;
138
  }
@@ -140,7 +159,6 @@ footer {display: none !important}
140
  -ms-overflow-style: none !important;
141
  scrollbar-width: none !important;
142
  }
143
- /* Ensure consistent sizing for gallery wrapper */
144
  #gallery > div {
145
  width: 100% !important;
146
  max-width: 100% !important;
@@ -150,10 +168,10 @@ footer {display: none !important}
150
  max-width: 100% !important;
151
  }
152
  """
 
153
  def save_image(image):
154
  """Save the generated image and return the path"""
155
  try:
156
- # Ensure gallery directory exists
157
  if not os.path.exists(gallery_path):
158
  try:
159
  os.makedirs(gallery_path, exist_ok=True)
@@ -161,7 +179,6 @@ def save_image(image):
161
  print(f"Failed to create gallery directory: {str(e)}")
162
  return None
163
 
164
- # Generate unique filename with timestamp and random suffix
165
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
166
  random_suffix = os.urandom(4).hex()
167
  filename = f"generated_{timestamp}_{random_suffix}.png"
@@ -187,24 +204,19 @@ def save_image(image):
187
  print(f"Error in save_image: {str(e)}")
188
  return None
189
 
190
-
191
  def load_gallery():
192
  """Load all images from the gallery directory"""
193
  try:
194
- # Ensure gallery directory exists
195
  os.makedirs(gallery_path, exist_ok=True)
196
 
197
- # Get all image files and sort by modification time
198
  image_files = []
199
  for f in os.listdir(gallery_path):
200
  if f.lower().endswith(('.png', '.jpg', '.jpeg')):
201
  full_path = os.path.join(gallery_path, f)
202
  image_files.append((full_path, os.path.getmtime(full_path)))
203
 
204
- # Sort by modification time (newest first)
205
  image_files.sort(key=lambda x: x[1], reverse=True)
206
 
207
- # Return only the file paths
208
  return [f[0] for f in image_files]
209
  except Exception as e:
210
  print(f"Error loading gallery: {str(e)}")
@@ -214,6 +226,12 @@ def load_gallery():
214
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
215
  gr.HTML('<div class="title">AI Image Generator</div>')
216
  gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
 
 
 
 
 
 
217
 
218
  with gr.Row():
219
  with gr.Column(scale=3):
@@ -299,14 +317,12 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
299
  """)
300
 
301
  with gr.Column(scale=4, elem_classes=["fixed-width"]):
302
- # Current generated image
303
  output = gr.Image(
304
  label="Generated Image",
305
  elem_id="output-image",
306
  elem_classes=["output-image", "fixed-width"]
307
  )
308
 
309
- # Gallery of generated images
310
  gallery = gr.Gallery(
311
  label="Generated Images Gallery",
312
  show_label=True,
@@ -318,16 +334,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
318
  elem_classes=["gallery-container", "fixed-width"]
319
  )
320
 
321
- # Load existing gallery images on startup
322
  gallery.value = load_gallery()
323
 
324
  @spaces.GPU
325
  def process_and_save_image(height, width, steps, scales, prompt, seed):
326
- global pipe
 
 
 
 
 
327
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
328
  try:
329
  generated_image = pipe(
330
- prompt=[prompt],
331
  generator=torch.Generator().manual_seed(int(seed)),
332
  num_inference_steps=int(steps),
333
  guidance_scale=float(scales),
@@ -336,18 +356,15 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
336
  max_sequence_length=256
337
  ).images[0]
338
 
339
- # Save the generated image
340
  saved_path = save_image(generated_image)
341
  if saved_path is None:
342
  print("Warning: Failed to save generated image")
343
 
344
- # Return both the generated image and updated gallery
345
  return generated_image, load_gallery()
346
  except Exception as e:
347
  print(f"Error in image generation: {str(e)}")
348
  return None, load_gallery()
349
 
350
- # Connect the generation button to both the image output and gallery update
351
  def update_seed():
352
  return get_random_seed()
353
 
@@ -357,13 +374,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
357
  outputs=[output, gallery]
358
  )
359
 
360
- # Add randomize seed button functionality
361
  randomize_seed.click(
362
  update_seed,
363
  outputs=[seed]
364
  )
365
 
366
- # Also randomize seed after each generation
367
  generate_btn.click(
368
  update_seed,
369
  outputs=[seed]
 
10
  import gradio as gr
11
  import torch
12
  from diffusers import FluxPipeline
13
+ from diffusers.pipelines.stable_diffusion import safety_checker
14
  from PIL import Image
15
 
 
16
  # Setup and initialization code
17
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
18
  # Use PERSISTENT_DIR environment variable for Spaces
 
29
  if not path.exists(gallery_path):
30
  os.makedirs(gallery_path, exist_ok=True)
31
 
32
+ def filter_prompt(prompt):
33
+ # λΆ€μ μ ˆν•œ ν‚€μ›Œλ“œ λͺ©λ‘
34
+ inappropriate_keywords = [
35
+ # μŒλž€/성적 ν‚€μ›Œλ“œ
36
+ "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
37
+ "erotic", "sensual", "seductive", "provocative", "intimate",
38
+ # 폭λ ₯적 ν‚€μ›Œλ“œ
39
+ "violence", "gore", "blood", "death", "kill", "murder", "torture",
40
+ # 기타 λΆ€μ μ ˆν•œ ν‚€μ›Œλ“œ
41
+ "drug", "suicide", "abuse", "hate", "discrimination"
42
+ ]
43
+
44
+ prompt_lower = prompt.lower()
45
+
46
+ # λΆ€μ μ ˆν•œ ν‚€μ›Œλ“œ 체크
47
+ for keyword in inappropriate_keywords:
48
+ if keyword in prompt_lower:
49
+ return False, "λΆ€μ μ ˆν•œ λ‚΄μš©μ΄ ν¬ν•¨λœ ν”„λ‘¬ν”„νŠΈμž…λ‹ˆλ‹€."
50
+
51
+ return True, prompt
52
+
53
  class timer:
54
  def __init__(self, method_name="timed process"):
55
  self.method = method_name
 
69
  pipe.fuse_lora(lora_scale=0.125)
70
  pipe.to(device="cuda", dtype=torch.bfloat16)
71
 
72
+ # Add safety checker
73
+ pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
74
+
75
  css = """
76
  footer {display: none !important}
77
  .gradio-container {
 
101
  -webkit-background-clip: text;
102
  -webkit-text-fill-color: transparent;
103
  }
 
104
  #gallery {
105
  width: 100% !important;
106
  max-width: 100% !important;
 
124
  width: 100% !important;
125
  box-sizing: border-box !important;
126
  }
 
127
  .gallery-item {
128
  width: 100% !important;
129
  aspect-ratio: 1 !important;
 
140
  .gallery-item img:hover {
141
  transform: scale(1.05);
142
  }
 
143
  .output-image {
144
  width: 100% !important;
145
  max-width: 100% !important;
146
  }
 
147
  .contain > div {
148
  width: 100% !important;
149
  max-width: 100% !important;
 
152
  width: 100% !important;
153
  max-width: 100% !important;
154
  }
 
155
  .gallery-container::-webkit-scrollbar {
156
  display: none !important;
157
  }
 
159
  -ms-overflow-style: none !important;
160
  scrollbar-width: none !important;
161
  }
 
162
  #gallery > div {
163
  width: 100% !important;
164
  max-width: 100% !important;
 
168
  max-width: 100% !important;
169
  }
170
  """
171
+
172
  def save_image(image):
173
  """Save the generated image and return the path"""
174
  try:
 
175
  if not os.path.exists(gallery_path):
176
  try:
177
  os.makedirs(gallery_path, exist_ok=True)
 
179
  print(f"Failed to create gallery directory: {str(e)}")
180
  return None
181
 
 
182
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
183
  random_suffix = os.urandom(4).hex()
184
  filename = f"generated_{timestamp}_{random_suffix}.png"
 
204
  print(f"Error in save_image: {str(e)}")
205
  return None
206
 
 
207
  def load_gallery():
208
  """Load all images from the gallery directory"""
209
  try:
 
210
  os.makedirs(gallery_path, exist_ok=True)
211
 
 
212
  image_files = []
213
  for f in os.listdir(gallery_path):
214
  if f.lower().endswith(('.png', '.jpg', '.jpeg')):
215
  full_path = os.path.join(gallery_path, f)
216
  image_files.append((full_path, os.path.getmtime(full_path)))
217
 
 
218
  image_files.sort(key=lambda x: x[1], reverse=True)
219
 
 
220
  return [f[0] for f in image_files]
221
  except Exception as e:
222
  print(f"Error loading gallery: {str(e)}")
 
226
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
227
  gr.HTML('<div class="title">AI Image Generator</div>')
228
  gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
229
+
230
+ gr.HTML("""
231
+ <div style="color: red; margin-bottom: 1em; text-align: center; padding: 10px; background: rgba(255,0,0,0.1); border-radius: 8px;">
232
+ ⚠️ μŒλž€ν•˜κ±°λ‚˜ λΆ€μ μ ˆν•œ λ‚΄μš©μ˜ μ΄λ―Έμ§€λŠ” 생성할 수 μ—†μŠ΅λ‹ˆλ‹€.
233
+ </div>
234
+ """)
235
 
236
  with gr.Row():
237
  with gr.Column(scale=3):
 
317
  """)
318
 
319
  with gr.Column(scale=4, elem_classes=["fixed-width"]):
 
320
  output = gr.Image(
321
  label="Generated Image",
322
  elem_id="output-image",
323
  elem_classes=["output-image", "fixed-width"]
324
  )
325
 
 
326
  gallery = gr.Gallery(
327
  label="Generated Images Gallery",
328
  show_label=True,
 
334
  elem_classes=["gallery-container", "fixed-width"]
335
  )
336
 
 
337
  gallery.value = load_gallery()
338
 
339
  @spaces.GPU
340
  def process_and_save_image(height, width, steps, scales, prompt, seed):
341
+ # ν”„λ‘¬ν”„νŠΈ 필터링
342
+ is_safe, filtered_prompt = filter_prompt(prompt)
343
+ if not is_safe:
344
+ gr.Warning("λΆ€μ μ ˆν•œ λ‚΄μš©μ΄ ν¬ν•¨λœ ν”„λ‘¬ν”„νŠΈμž…λ‹ˆλ‹€.")
345
+ return None, load_gallery()
346
+
347
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
348
  try:
349
  generated_image = pipe(
350
+ prompt=[filtered_prompt],
351
  generator=torch.Generator().manual_seed(int(seed)),
352
  num_inference_steps=int(steps),
353
  guidance_scale=float(scales),
 
356
  max_sequence_length=256
357
  ).images[0]
358
 
 
359
  saved_path = save_image(generated_image)
360
  if saved_path is None:
361
  print("Warning: Failed to save generated image")
362
 
 
363
  return generated_image, load_gallery()
364
  except Exception as e:
365
  print(f"Error in image generation: {str(e)}")
366
  return None, load_gallery()
367
 
 
368
  def update_seed():
369
  return get_random_seed()
370
 
 
374
  outputs=[output, gallery]
375
  )
376
 
 
377
  randomize_seed.click(
378
  update_seed,
379
  outputs=[seed]
380
  )
381
 
 
382
  generate_btn.click(
383
  update_seed,
384
  outputs=[seed]