Oranblock committed on
Commit
722bb4a
1 Parent(s): aee9842

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -100
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import random
3
  import uuid
@@ -9,13 +11,69 @@ from PIL import Image
9
  import spaces
10
  import torch
11
  from diffusers import DiffusionPipeline
12
- import face_recognition # More robust face detection library
13
  from typing import Tuple
14
 
15
- # Check if GPU is available; fallback to CPU if needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
 
18
- # Initialize the AI model for sticker generation
19
  pipe = DiffusionPipeline.from_pretrained(
20
  "SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
21
  torch_dtype=torch.float16,
@@ -23,124 +81,193 @@ pipe = DiffusionPipeline.from_pretrained(
23
  variant="fp16"
24
  ).to(device)
25
 
 
 
 
 
 
26
 
27
- def face_to_sticker(image_path: str) -> Tuple[str, str]:
28
- """Detect the face using face_recognition and convert it to a sticker format."""
29
- img = face_recognition.load_image_file(image_path)
30
- face_locations = face_recognition.face_locations(img)
 
31
 
32
- if not face_locations:
33
- return None, "No face detected. Please upload a clear image with a visible face."
34
-
35
- # Extract the first detected face and return as an image for sticker creation
36
- top, right, bottom, left = face_locations[0]
37
- face_img = img[top:bottom, left:right]
38
- face_img = Image.fromarray(face_img).resize((256, 256)) # Resize face to sticker size
39
-
40
- face_img_path = f"{uuid.uuid4()}.png"
41
- face_img.save(face_img_path)
42
- return face_img_path, "Face successfully converted to a sticker."
43
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- def generate_prompt(clothing: str, pose: str, mood: str) -> str:
46
- """Generate a descriptive prompt based on user-selected clothing, pose, and mood."""
47
- prompt = f"sticker of a person wearing {clothing} clothes, in a {pose} pose, looking {mood}."
48
- return prompt
49
 
 
 
 
 
50
 
51
- def generate_stickers(prompt: str, face_image: str, guidance_scale: float = 7.5, randomize_seed: bool = False):
52
- """Generate stickers using the diffusion model with the given prompt and face."""
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # Adjust seed for variability
55
- seed = random.randint(0, MAX_SEED) if randomize_seed else 42
56
- generator = torch.Generator(device).manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
57
 
58
- # Prepare AI model options
59
  options = {
60
  "prompt": prompt,
61
- "width": 512,
62
- "height": 512,
 
63
  "guidance_scale": guidance_scale,
64
- "num_inference_steps": 50,
65
  "generator": generator,
 
 
66
  }
67
-
68
- # Load the face as an input condition for the sticker (optional, if supported by the model)
69
- # If your model supports conditioning on a specific face, load the face image here
70
- # options['image'] = Image.open(face_image)
71
-
72
  images = pipe(**options).images
73
- image_paths = [save_image(img) for img in images]
74
 
75
  return image_paths, seed
76
 
 
 
 
 
 
77
 
78
- def save_image(img: Image.Image) -> str:
79
- """Save an image to a file and return the path."""
80
- unique_name = f"{uuid.uuid4()}.png"
81
- img.save(unique_name)
82
- return unique_name
83
-
84
-
85
- def stick_me_workflow(image, clothing, pose, mood, randomize_seed: bool):
86
- """Workflow to generate stickers based on user-uploaded image and options."""
87
- # Convert the uploaded image to a face sticker
88
- face_path, message = face_to_sticker(image)
89
-
90
- if face_path is None:
91
- return message # Return error message if face detection fails
92
-
93
- # Generate a descriptive prompt based on user selections
94
- prompt = generate_prompt(clothing, pose, mood)
95
-
96
- # Generate stickers using the diffusion model with the extracted face and prompt
97
- stickers, seed = generate_stickers(prompt, face_path, randomize_seed=randomize_seed)
98
- return stickers
99
-
100
-
101
- def on_fallback_to_cpu():
102
- """Notify users when the app is running on CPU (due to GPU quota being exceeded)."""
103
- if not torch.cuda.is_available():
104
- return "Warning: GPU quota exceeded. Running on CPU, which will be significantly slower."
105
- return ""
106
 
107
-
108
- # Gradio interface setup
109
-
110
- with gr.Blocks() as demo:
111
- gr.Markdown("# Sticker Generator with 'Stick Me' Feature")
112
-
113
- # GPU Quota Handling
114
- gpu_warning = gr.Markdown(on_fallback_to_cpu(), visible=not torch.cuda.is_available())
115
-
116
- # New Stick Me Option
117
- with gr.Row():
118
- face_input = gr.Image(label="Upload Your Image for 'Stick Me'", type="filepath")
119
- clothing = gr.Dropdown(["Casual", "Formal", "Sports"], label="Choose Clothing")
120
- pose = gr.Dropdown(["Standing", "Sitting", "Running"], label="Choose Pose")
121
- mood = gr.Dropdown(["Happy", "Serious", "Excited"], label="Choose Mood")
122
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
123
-
124
- stick_me_button = gr.Button("Generate Stick Me Stickers")
125
-
126
- stick_me_result = gr.Gallery(label="Your Stick Me Stickers")
127
-
128
- stick_me_button.click(
129
- fn=stick_me_workflow,
130
- inputs=[face_input, clothing, pose, mood, randomize_seed],
131
- outputs=[stick_me_result]
132
  )
133
-
134
- gr.Markdown("# Generate Regular Stickers")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- prompt = gr.Textbox(label="Enter a Prompt for Sticker Creation", placeholder="Cute bunny", max_lines=1)
137
- generate_button = gr.Button("Generate Stickers")
138
- result = gr.Gallery(label="Generated Stickers")
 
 
 
 
139
 
140
- generate_button.click(
141
- fn=generate_stickers,
142
- inputs=[prompt],
143
- outputs=[result]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
 
146
- demo.launch()
 
 
1
+ #!/usr/bin/env python
2
+
3
  import os
4
  import random
5
  import uuid
 
11
  import spaces
12
  import torch
13
  from diffusers import DiffusionPipeline
 
14
  from typing import Tuple
15
 
16
+ # Setup rules for bad words (ensure the prompts are kid-friendly)
17
+ bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
18
+ default_negative = os.getenv("default_negative","")
19
+
20
def check_text(prompt, negative=""):
    """Return True when *prompt* contains any restricted word.

    Matching is a case-insensitive substring test against the module-level
    ``bad_words`` list, so "Blood" and "BLOOD" are rejected just like "blood".

    Args:
        prompt: User-supplied prompt text to screen.
        negative: Accepted for call-site compatibility; it is not inspected.

    Returns:
        True if a restricted word occurs in the prompt, otherwise False.
    """
    lowered = prompt.lower()
    return any(word.lower() in lowered for word in bad_words)
25
+
26
+ # Kid-friendly styles
27
+ style_list = [
28
+ {
29
+ "name": "Cartoon",
30
+ "prompt": "colorful cartoon {prompt}. vibrant, playful, friendly, suitable for children, highly detailed, bright colors",
31
+ "negative_prompt": "scary, dark, violent, ugly, realistic",
32
+ },
33
+ {
34
+ "name": "Children's Illustration",
35
+ "prompt": "children's illustration {prompt}. cute, colorful, fun, simple shapes, smooth lines, highly detailed, joyful",
36
+ "negative_prompt": "scary, dark, violent, deformed, ugly",
37
+ },
38
+ {
39
+ "name": "Sticker",
40
+ "prompt": "children's sticker of {prompt}. bright colors, playful, high resolution, cartoonish",
41
+ "negative_prompt": "scary, dark, violent, ugly, low resolution",
42
+ },
43
+ {
44
+ "name": "Fantasy",
45
+ "prompt": "fantasy world for children with {prompt}. magical, vibrant, friendly, beautiful, colorful",
46
+ "negative_prompt": "dark, scary, violent, ugly, realistic",
47
+ },
48
+ {
49
+ "name": "(No style)",
50
+ "prompt": "{prompt}",
51
+ "negative_prompt": "",
52
+ },
53
+ ]
54
+
55
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
56
+ STYLE_NAMES = list(styles.keys())
57
+ DEFAULT_STYLE_NAME = "Sticker"
58
+
59
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Expand a prompt with the chosen style template.

    Args:
        style_name: Key into ``styles``; unknown names fall back to
            ``DEFAULT_STYLE_NAME``.
        positive: Text substituted for the ``{prompt}`` placeholder.
        negative: Extra negative prompt appended to the style's own one.

    Returns:
        A ``(prompt, negative_prompt)`` tuple.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    # Join with a separator so the last style term and the first user term
    # do not fuse into a single token (e.g. "low resolution(scary, ...").
    negative_parts = [part for part in (style_negative, negative) if part]
    return template.replace("{prompt}", positive), ", ".join(negative_parts)
62
+
63
+ DESCRIPTION = """## Children's Sticker Generator
64
+
65
+ Generate fun and playful stickers for children using AI.
66
+ """
67
+
68
+ if not torch.cuda.is_available():
69
+ DESCRIPTION += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
70
+
71
+ MAX_SEED = np.iinfo(np.int32).max
72
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
73
+
74
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
75
 
76
+ # Initialize the DiffusionPipeline
77
  pipe = DiffusionPipeline.from_pretrained(
78
  "SG161222/RealVisXL_V3.0_Turbo", # or any model of your choice
79
  torch_dtype=torch.float16,
 
81
  variant="fp16"
82
  ).to(device)
83
 
84
+ # Convert mm to pixels for a specific DPI (300) and ensure divisible by 8
85
def mm_to_pixels(mm, dpi=300, multiple=8):
    """Convert a length in millimetres to pixels, rounded down to a multiple.

    Args:
        mm: Length in millimetres.
        dpi: Dots per inch used for the conversion (25.4 mm per inch).
        multiple: The result is floored to a multiple of this value;
            the default of 8 keeps the previous behaviour.

    Returns:
        The pixel count as an int, floored to a multiple of ``multiple``.
    """
    pixels = int((mm / 25.4) * dpi)
    # Floor (not round) so the result never exceeds the requested size.
    return pixels // multiple * multiple
89
 
90
+ # Default sizes for 75mm and 35mm, rounded to nearest multiple of 8
91
+ size_map = {
92
+ "75mm": (mm_to_pixels(75), mm_to_pixels(75)), # 75mm in pixels at 300dpi
93
+ "35mm": (mm_to_pixels(35), mm_to_pixels(35)), # 35mm in pixels at 300dpi
94
+ }
95
 
96
+ # Function to post-process images (transparent or white background)
97
def save_image(img, background="transparent"):
    """Save *img* as a uniquely named PNG and return the file name.

    The image is converted to RGBA first. With ``background="transparent"``
    every pure-white pixel (255, 255, 255) has its alpha set to 0; any other
    value, including ``"white"``, leaves the pixels untouched.

    Args:
        img: A PIL image (anything exposing ``convert``/``getdata``/``putdata``/``save``).
        background: ``"transparent"`` or ``"white"``.

    Returns:
        The relative file name (``<uuid4>.png``) the image was written to.
    """
    img = img.convert("RGBA")

    if background == "transparent":
        # Only exact white is keyed out; the "white" option needs no pixel
        # rewrite at all, so the per-pixel copy is skipped in that case.
        img.putdata([
            (255, 255, 255, 0) if pixel[:3] == (255, 255, 255) else pixel
            for pixel in img.getdata()
        ])

    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name
117
 
118
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in [0, MAX_SEED] if requested, else *seed*."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
122
 
123
@spaces.GPU(enable_queue=True)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    style: str = DEFAULT_STYLE_NAME,
    seed: int = 0,
    size: str = "75mm",
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    background: str = "transparent",
    progress=gr.Progress(track_tqdm=True),
):
    """Generate sticker images for a short prompt and save them to disk.

    Args:
        prompt: User prompt; only its first three words are used.
        negative_prompt: Extra negative prompt, used only when
            ``use_negative_prompt`` is true.
        use_negative_prompt: When false, the negative prompt (including the
            style's own) is cleared before calling the pipeline.
        style: Style name passed to ``apply_style``.
        seed: Seed used when ``randomize_seed`` is false.
        size: Key into ``size_map``; unknown keys fall back to 1024x1024.
        guidance_scale: Forwarded to the pipeline.
        randomize_seed: When true, a random seed replaces ``seed``.
        background: Forwarded to ``save_image``.
        progress: Gradio progress tracker.

    Returns:
        A ``(image_paths, seed)`` tuple: saved file names and the seed used.

    Raises:
        ValueError: If ``check_text`` flags the prompt.
    """
    # `re` is not among the imports visible at the top of this file, so it is
    # imported here to keep the word-splitting below from raising NameError.
    import re

    if check_text(prompt, negative_prompt):
        raise ValueError("Prompt contains restricted words.")

    # Keep at most the first three words of the prompt.
    prompt = " ".join(re.findall(r'\w+', prompt)[:3])

    # Apply style
    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator().manual_seed(seed)

    # Resolve the sticker size to pixel dimensions.
    width, height = size_map.get(size, (1024, 1024))

    if not use_negative_prompt:
        negative_prompt = ""  # type: ignore

    options = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "width": width,
        "height": height,
        "guidance_scale": guidance_scale,
        "num_inference_steps": 25,
        "generator": generator,
        "num_images_per_prompt": 6,  # Max 6 images
        "output_type": "pil",
    }

    # Generate images with the pipeline
    images = pipe(**options).images
    image_paths = [save_image(img, background) for img in images]

    return image_paths, seed
170
 
171
+ examples = [
172
+ "cute bunny",
173
+ "happy cat",
174
+ "funny dog",
175
+ ]
176
 
177
+ css = '''
178
+ .gradio-container{max-width: 700px !important}
179
+ h1{text-align:center}
180
+ '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ # Define the Gradio UI for the sticker generator
183
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
184
+ gr.Markdown(DESCRIPTION)
185
+ gr.DuplicateButton(
186
+ value="Duplicate Space for private use",
187
+ elem_id="duplicate-button",
188
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  )
190
+ with gr.Group():
191
+ with gr.Row():
192
+ prompt = gr.Text(
193
+ label="Enter your prompt",
194
+ show_label=False,
195
+ max_lines=1,
196
+ placeholder="Enter 2-3 word prompt (e.g., cute bunny)",
197
+ container=False,
198
+ )
199
+ run_button = gr.Button("Run")
200
+ result = gr.Gallery(label="Generated Stickers", columns=2, preview=True)
201
+ with gr.Accordion("Advanced options", open=False):
202
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
203
+ negative_prompt = gr.Text(
204
+ label="Negative prompt",
205
+ max_lines=1,
206
+ placeholder="Enter a negative prompt",
207
+ value="(scary, violent, dark, ugly)",
208
+ visible=True,
209
+ )
210
+ seed = gr.Slider(
211
+ label="Seed",
212
+ minimum=0,
213
+ maximum=MAX_SEED,
214
+ step=1,
215
+ value=0,
216
+ visible=True
217
+ )
218
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
219
+ size_selection = gr.Radio(
220
+ choices=["75mm", "35mm"],
221
+ value="75mm",
222
+ label="Sticker Size",
223
+ )
224
+ style_selection = gr.Radio(
225
+ choices=STYLE_NAMES,
226
+ value=DEFAULT_STYLE_NAME,
227
+ label="Image Style",
228
+ )
229
+ background_selection = gr.Radio(
230
+ choices=["transparent", "white"],
231
+ value="transparent",
232
+ label="Background Color",
233
+ )
234
+ guidance_scale = gr.Slider(
235
+ label="Guidance Scale",
236
+ minimum=0.1,
237
+ maximum=20.0,
238
+ step=0.1,
239
+ value=6,
240
+ )
241
 
242
+ gr.Examples(
243
+ examples=examples,
244
+ inputs=prompt,
245
+ outputs=[result, seed],
246
+ fn=generate,
247
+ cache_examples=CACHE_EXAMPLES,
248
+ )
249
 
250
+ gr.on(
251
+ triggers=[
252
+ prompt.submit,
253
+ negative_prompt.submit,
254
+ run_button.click,
255
+ ],
256
+ fn=generate,
257
+ inputs=[
258
+ prompt,
259
+ negative_prompt,
260
+ use_negative_prompt,
261
+ style_selection,
262
+ seed,
263
+ size_selection,
264
+ guidance_scale,
265
+ randomize_seed,
266
+ background_selection,
267
+ ],
268
+ outputs=[result, seed],
269
+ api_name="run",
270
  )
271
 
272
+ if __name__ == "__main__":
273
+ demo.queue(max_size=20).launch()