Oranblock committed on
Commit d33031f
1 Parent(s): 322f509

Update app.py

Files changed (1)
  1. app.py +100 -227
app.py CHANGED
@@ -1,5 +1,3 @@
- #!/usr/bin/env python
-
  import os
  import random
  import uuid
@@ -11,69 +9,13 @@ from PIL import Image
  import spaces
  import torch
  from diffusers import DiffusionPipeline
+ import face_recognition # More robust face detection library
  from typing import Tuple

- # Setup rules for bad words (ensure the prompts are kid-friendly)
- bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
- default_negative = os.getenv("default_negative","")
-
- def check_text(prompt, negative=""):
-     for i in bad_words:
-         if i in prompt:
-             return True
-     return False
-
- # Kid-friendly styles
- style_list = [
-     {
-         "name": "Cartoon",
-         "prompt": "colorful cartoon {prompt}. vibrant, playful, friendly, suitable for children, highly detailed, bright colors",
-         "negative_prompt": "scary, dark, violent, ugly, realistic",
-     },
-     {
-         "name": "Children's Illustration",
-         "prompt": "children's illustration {prompt}. cute, colorful, fun, simple shapes, smooth lines, highly detailed, joyful",
-         "negative_prompt": "scary, dark, violent, deformed, ugly",
-     },
-     {
-         "name": "Sticker",
-         "prompt": "children's sticker of {prompt}. bright colors, playful, high resolution, cartoonish",
-         "negative_prompt": "scary, dark, violent, ugly, low resolution",
-     },
-     {
-         "name": "Fantasy",
-         "prompt": "fantasy world for children with {prompt}. magical, vibrant, friendly, beautiful, colorful",
-         "negative_prompt": "dark, scary, violent, ugly, realistic",
-     },
-     {
-         "name": "(No style)",
-         "prompt": "{prompt}",
-         "negative_prompt": "",
-     },
- ]
-
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
- STYLE_NAMES = list(styles.keys())
- DEFAULT_STYLE_NAME = "Sticker"
-
- def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
-     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-     return p.replace("{prompt}", positive), n + negative
-
- DESCRIPTION = """## Children's Sticker Generator
-
- Generate fun and playful stickers for children using AI.
- """
-
- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
-
- MAX_SEED = np.iinfo(np.int32).max
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
-
+ # Check if GPU is available; fallback to CPU if needed
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- # Initialize the DiffusionPipeline
+ # Initialize the AI model for sticker generation
  pipe = DiffusionPipeline.from_pretrained(
      "SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
      torch_dtype=torch.float16,
@@ -81,193 +23,124 @@ pipe = DiffusionPipeline.from_pretrained(
      variant="fp16"
  ).to(device)

- # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
- def mm_to_pixels(mm, dpi=300):
-     """Convert mm to pixels and make the dimensions divisible by 8."""
-     pixels = int((mm / 25.4) * dpi)
-     return pixels - (pixels % 8) # Adjust to the nearest lower multiple of 8
-
- # Default sizes for 75mm and 35mm, rounded to nearest multiple of 8
- size_map = {
-     "75mm": (mm_to_pixels(75), mm_to_pixels(75)), # 75mm in pixels at 300dpi
-     "35mm": (mm_to_pixels(35), mm_to_pixels(35)), # 35mm in pixels at 300dpi
- }
-
- # Function to post-process images (transparent or white background)
- def save_image(img, background="transparent"):
-     img = img.convert("RGBA")
-     data = img.getdata()
-     new_data = []
-
-     if background == "transparent":
-         for item in data:
-             # Replace white with transparent
-             if item[0] == 255 and item[1] == 255 and item[2] == 255:
-                 new_data.append((255, 255, 255, 0)) # Transparent
-             else:
-                 new_data.append(item)
-     elif background == "white":
-         for item in data:
-             new_data.append(item) # Keep as white
-
-     img.putdata(new_data)
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- @spaces.GPU(enable_queue=True)
- def generate(
-     prompt: str,
-     negative_prompt: str = "",
-     use_negative_prompt: bool = False,
-     style: str = DEFAULT_STYLE_NAME,
-     seed: int = 0,
-     size: str = "75mm",
-     guidance_scale: float = 3,
-     randomize_seed: bool = False,
-     background: str = "transparent",
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if check_text(prompt, negative_prompt):
-         raise ValueError("Prompt contains restricted words.")
-
-     # Ensure prompt is 2-3 words long
-     prompt = " ".join(re.findall(r'\w+', prompt)[:3])
-
-     # Apply style
-     prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-     generator = torch.Generator().manual_seed(seed)
-
-     # Ensure we have only white or transparent background options
-     width, height = size_map.get(size, (1024, 1024))
-
-     if not use_negative_prompt:
-         negative_prompt = "" # type: ignore
-
+ def face_to_sticker(image_path: str) -> Tuple[str, str]:
+     """Detect the face using face_recognition and convert it to a sticker format."""
+     img = face_recognition.load_image_file(image_path)
+     face_locations = face_recognition.face_locations(img)
+
+     if not face_locations:
+         return None, "No face detected. Please upload a clear image with a visible face."
+
+     # Extract the first detected face and return as an image for sticker creation
+     top, right, bottom, left = face_locations[0]
+     face_img = img[top:bottom, left:right]
+     face_img = Image.fromarray(face_img).resize((256, 256)) # Resize face to sticker size
+
+     face_img_path = f"{uuid.uuid4()}.png"
+     face_img.save(face_img_path)
+     return face_img_path, "Face successfully converted to a sticker."
+
+ def generate_prompt(clothing: str, pose: str, mood: str) -> str:
+     """Generate a descriptive prompt based on user-selected clothing, pose, and mood."""
+     prompt = f"sticker of a person wearing {clothing} clothes, in a {pose} pose, looking {mood}."
+     return prompt
+
+ def generate_stickers(prompt: str, face_image: str, guidance_scale: float = 7.5, randomize_seed: bool = False):
+     """Generate stickers using the diffusion model with the given prompt and face."""
+
+     # Adjust seed for variability
+     seed = random.randint(0, MAX_SEED) if randomize_seed else 42
+     generator = torch.Generator(device).manual_seed(seed)
+
+     # Prepare AI model options
      options = {
          "prompt": prompt,
-         "negative_prompt": negative_prompt,
-         "width": width,
-         "height": height,
+         "width": 512,
+         "height": 512,
          "guidance_scale": guidance_scale,
-         "num_inference_steps": 25,
+         "num_inference_steps": 50,
          "generator": generator,
-         "num_images_per_prompt": 6, # Max 6 images
-         "output_type": "pil",
      }
-
-     # Generate images with the pipeline
+
+     # Load the face as an input condition for the sticker (optional, if supported by the model)
+     # If your model supports conditioning on a specific face, load the face image here
+     # options['image'] = Image.open(face_image)
+
      images = pipe(**options).images
-     image_paths = [save_image(img, background) for img in images]
+     image_paths = [save_image(img) for img in images]

      return image_paths, seed

- examples = [
-     "cute bunny",
-     "happy cat",
-     "funny dog",
- ]
-
- css = '''
- .gradio-container{max-width: 700px !important}
- h1{text-align:center}
- '''
-
- # Define the Gradio UI for the sticker generator
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-     gr.Markdown(DESCRIPTION)
-     gr.DuplicateButton(
-         value="Duplicate Space for private use",
-         elem_id="duplicate-button",
-         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
-     )
-     with gr.Group():
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Enter your prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter 2-3 word prompt (e.g., cute bunny)",
-                 container=False,
-             )
-             run_button = gr.Button("Run")
-         result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
-     with gr.Accordion("Advanced options", open=False):
-         use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
-         negative_prompt = gr.Text(
-             label="Negative prompt",
-             max_lines=1,
-             placeholder="Enter a negative prompt",
-             value="(scary, violent, dark, ugly)",
-             visible=True,
-         )
-         seed = gr.Slider(
-             label="Seed",
-             minimum=0,
-             maximum=MAX_SEED,
-             step=1,
-             value=0,
-             visible=True
-         )
-         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-         size_selection = gr.Radio(
-             choices=["75mm", "35mm"],
-             value="75mm",
-             label="Sticker Size",
-         )
-         style_selection = gr.Radio(
-             choices=STYLE_NAMES,
-             value=DEFAULT_STYLE_NAME,
-             label="Image Style",
-         )
-         background_selection = gr.Radio(
-             choices=["transparent", "white"],
-             value="transparent",
-             label="Background Color",
-         )
-         guidance_scale = gr.Slider(
-             label="Guidance Scale",
-             minimum=0.1,
-             maximum=20.0,
-             step=0.1,
-             value=6,
-         )
-
-     gr.Examples(
-         examples=examples,
-         inputs=prompt,
-         outputs=[result, seed],
-         fn=generate,
-         cache_examples=CACHE_EXAMPLES,
+
+ def save_image(img: Image.Image) -> str:
+     """Save an image to a file and return the path."""
+     unique_name = f"{uuid.uuid4()}.png"
+     img.save(unique_name)
+     return unique_name
+
+ def stick_me_workflow(image, clothing, pose, mood, randomize_seed: bool):
+     """Workflow to generate stickers based on user-uploaded image and options."""
+     # Convert the uploaded image to a face sticker
+     face_path, message = face_to_sticker(image)
+
+     if face_path is None:
+         return message # Return error message if face detection fails
+
+     # Generate a descriptive prompt based on user selections
+     prompt = generate_prompt(clothing, pose, mood)
+
+     # Generate stickers using the diffusion model with the extracted face and prompt
+     stickers, seed = generate_stickers(prompt, face_path, randomize_seed=randomize_seed)
+     return stickers
+
+ def on_fallback_to_cpu():
+     """Notify users when the app is running on CPU (due to GPU quota being exceeded)."""
+     if not torch.cuda.is_available():
+         return "Warning: GPU quota exceeded. Running on CPU, which will be significantly slower."
+     return ""
+
+ # Gradio interface setup
+ with gr.Blocks() as demo:
+     gr.Markdown("# Sticker Generator with 'Stick Me' Feature")
+
+     # GPU Quota Handling
+     gpu_warning = gr.Markdown(on_fallback_to_cpu(), visible=not torch.cuda.is_available())
+
+     # New Stick Me Option
+     with gr.Row():
+         face_input = gr.Image(label="Upload Your Image for 'Stick Me'", type="filepath")
+         clothing = gr.Dropdown(["Casual", "Formal", "Sports"], label="Choose Clothing")
+         pose = gr.Dropdown(["Standing", "Sitting", "Running"], label="Choose Pose")
+         mood = gr.Dropdown(["Happy", "Serious", "Excited"], label="Choose Mood")
+         randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
+     stick_me_button = gr.Button("Generate Stick Me Stickers")
+
+     stick_me_result = gr.Gallery(label="Your Stick Me Stickers")
+
+     stick_me_button.click(
+         fn=stick_me_workflow,
+         inputs=[face_input, clothing, pose, mood, randomize_seed],
+         outputs=[stick_me_result]
      )
+
+     gr.Markdown("# Generate Regular Stickers")
+
+     prompt = gr.Textbox(label="Enter a Prompt for Sticker Creation", placeholder="Cute bunny", max_lines=1)
+     generate_button = gr.Button("Generate Stickers")
+     result = gr.Gallery(label="Generated Stickers")

-     gr.on(
-         triggers=[
-             prompt.submit,
-             negative_prompt.submit,
-             run_button.click,
-         ],
-         fn=generate,
-         inputs=[
-             prompt,
-             negative_prompt,
-             use_negative_prompt,
-             style_selection,
-             seed,
-             size_selection,
-             guidance_scale,
-             randomize_seed,
-             background_selection,
-         ],
-         outputs=[result, seed],
-         api_name="run",
+     generate_button.click(
+         fn=generate_stickers,
+         inputs=[prompt],
+         outputs=[result]
      )

- if __name__ == "__main__":
-     demo.queue(max_size=20).launch()
+ demo.launch()
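
Note on the commented-out face conditioning: the new generate_stickers() only hints at it (options['image'] = Image.open(face_image)), and a plain text-to-image DiffusionPipeline call does not accept an image argument, so the crop produced by face_to_sticker() does not yet influence the output. One way such conditioning could be wired is diffusers' image-to-image path. The snippet below is a hedged sketch reusing the same checkpoint, not part of this commit; generate_with_face, the 512x512 resize, and strength=0.6 are illustrative choices.

# Hypothetical sketch (not in the commit): condition generation on the cropped face
# via diffusers' image-to-image pipeline, reusing the RealVisXL Turbo checkpoint.
import torch
from PIL import Image
from diffusers import AutoPipelineForImage2Image

img2img = AutoPipelineForImage2Image.from_pretrained(
    "SG161222/RealVisXL_V3.0_Turbo",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

def generate_with_face(prompt: str, face_image_path: str, seed: int = 42):
    # face_image_path is the crop saved by face_to_sticker()
    face = Image.open(face_image_path).convert("RGB").resize((512, 512))
    generator = torch.Generator("cuda").manual_seed(seed)
    return img2img(
        prompt=prompt,
        image=face,            # start from the face crop instead of pure noise
        strength=0.6,          # lower keeps more of the face, higher follows the prompt more
        guidance_scale=7.5,
        generator=generator,
    ).images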