Oranblock commited on
Commit
84ce0b5
1 Parent(s): 24e02ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -13,8 +13,9 @@ import torch
13
  from diffusers import DiffusionPipeline
14
  from typing import Tuple
15
 
16
- # Check if GPU is available; fallback to CPU if needed
17
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
18
 
19
  # Setup rules for bad words (ensure the prompts are kid-friendly)
20
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
@@ -68,21 +69,21 @@ DESCRIPTION = """## Children's Sticker Generator
68
  Generate fun and playful stickers for children using AI.
69
  """
70
 
71
- if not torch.cuda.is_available():
72
- DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
73
-
74
  MAX_SEED = np.iinfo(np.int32).max
75
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
76
 
77
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
78
 
79
- # Initialize the DiffusionPipeline
80
- pipe = DiffusionPipeline.from_pretrained(
81
- "SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
82
- torch_dtype=torch.float16,
83
- use_safetensors=True,
84
- variant="fp16"
85
- ).to(device)
86
 
87
  # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
88
  def mm_to_pixels(mm, dpi=300):
@@ -134,8 +135,15 @@ def generate(
134
  guidance_scale: float = 3,
135
  randomize_seed: bool = False,
136
  background: str = "transparent",
 
137
  progress=gr.Progress(track_tqdm=True),
138
  ):
 
 
 
 
 
 
139
  if check_text(prompt, negative_prompt):
140
  raise ValueError("Prompt contains restricted words.")
141
 
@@ -145,7 +153,7 @@ def generate(
145
  # Apply style
146
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
147
  seed = int(randomize_seed_fn(seed, randomize_seed))
148
- generator = torch.Generator().manual_seed(seed)
149
 
150
  # Ensure we have only white or transparent background options
151
  width, height = size_map.get(size, (1024, 1024))
@@ -241,6 +249,11 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
241
  step=0.1,
242
  value=15.7,
243
  )
 
 
 
 
 
244
 
245
  gr.Examples(
246
  examples=examples,
@@ -267,6 +280,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
267
  guidance_scale,
268
  randomize_seed,
269
  background_selection,
 
270
  ],
271
  outputs=[result, seed],
272
  api_name="run",
 
13
  from diffusers import DiffusionPipeline
14
  from typing import Tuple
15
 
16
+ # Initialize device to None
17
+ device = None
18
+ pipe = None
19
 
20
  # Setup rules for bad words (ensure the prompts are kid-friendly)
21
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
 
69
  Generate fun and playful stickers for children using AI.
70
  """
71
 
 
 
 
72
  MAX_SEED = np.iinfo(np.int32).max
73
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
74
 
75
+ def initialize_pipeline(device_type):
76
+ global device, pipe
77
+ device = torch.device(device_type)
78
+ pipe = DiffusionPipeline.from_pretrained(
79
+ "SG161222/RealVisXL_V3.0_Turbo",
80
+ torch_dtype=torch.float32 if device_type == "cpu" else torch.float16,
81
+ use_safetensors=True,
82
+ variant="fp32" if device_type == "cpu" else "fp16"
83
+ ).to(device)
84
 
85
+ # Initialize with CPU by default
86
+ initialize_pipeline("cpu")
 
 
 
 
 
87
 
88
  # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
89
  def mm_to_pixels(mm, dpi=300):
 
135
  guidance_scale: float = 3,
136
  randomize_seed: bool = False,
137
  background: str = "transparent",
138
+ device_type: str = "cpu",
139
  progress=gr.Progress(track_tqdm=True),
140
  ):
141
+ global device, pipe
142
+
143
+ # Switch device if necessary
144
+ if device.type != device_type:
145
+ initialize_pipeline(device_type)
146
+
147
  if check_text(prompt, negative_prompt):
148
  raise ValueError("Prompt contains restricted words.")
149
 
 
153
  # Apply style
154
  prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
155
  seed = int(randomize_seed_fn(seed, randomize_seed))
156
+ generator = torch.Generator(device=device).manual_seed(seed)
157
 
158
  # Ensure we have only white or transparent background options
159
  width, height = size_map.get(size, (1024, 1024))
 
249
  step=0.1,
250
  value=15.7,
251
  )
252
+ device_selection = gr.Radio(
253
+ choices=["cpu", "cuda"],
254
+ value="cpu",
255
+ label="Device",
256
+ )
257
 
258
  gr.Examples(
259
  examples=examples,
 
280
  guidance_scale,
281
  randomize_seed,
282
  background_selection,
283
+ device_selection,
284
  ],
285
  outputs=[result, seed],
286
  api_name="run",