Update app.py
Browse files
app.py
CHANGED
@@ -12,11 +12,15 @@ import spaces
|
|
12 |
import torch
|
13 |
from diffusers import DiffusionPipeline
|
14 |
from typing import Tuple
|
|
|
15 |
|
16 |
# Setup rules for bad words (ensure the prompts are kid-friendly)
|
17 |
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
|
18 |
default_negative = os.getenv("default_negative","")
|
19 |
|
|
|
|
|
|
|
20 |
def check_text(prompt, negative=""):
|
21 |
for i in bad_words:
|
22 |
if i in prompt:
|
@@ -120,6 +124,14 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
120 |
seed = random.randint(0, MAX_SEED)
|
121 |
return seed
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
@spaces.GPU(enable_queue=True)
|
124 |
def generate(
|
125 |
prompt: str,
|
@@ -133,6 +145,9 @@ def generate(
|
|
133 |
background: str = "transparent",
|
134 |
progress=gr.Progress(track_tqdm=True),
|
135 |
):
|
|
|
|
|
|
|
136 |
if check_text(prompt, negative_prompt):
|
137 |
raise ValueError("Prompt contains restricted words.")
|
138 |
|
@@ -169,18 +184,23 @@ def generate(
|
|
169 |
return image_paths, seed
|
170 |
|
171 |
examples = [
|
172 |
-
"
|
173 |
-
"
|
174 |
-
"
|
|
|
|
|
|
|
175 |
]
|
176 |
|
177 |
-
css =
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
181 |
|
182 |
# Define the Gradio UI for the sticker generator
|
183 |
-
with gr.Blocks(
|
184 |
gr.Markdown(DESCRIPTION)
|
185 |
gr.DuplicateButton(
|
186 |
value="Duplicate Space for private use",
|
|
|
12 |
import torch
|
13 |
from diffusers import DiffusionPipeline
|
14 |
from typing import Tuple
|
15 |
+
from transformers import pipeline
|
16 |
|
17 |
# Setup rules for bad words (ensure the prompts are kid-friendly)
|
18 |
bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
|
19 |
default_negative = os.getenv("default_negative","")
|
20 |
|
21 |
+
# Add the translation pipeline
|
22 |
+
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
|
23 |
+
|
24 |
def check_text(prompt, negative=""):
|
25 |
for i in bad_words:
|
26 |
if i in prompt:
|
|
|
124 |
seed = random.randint(0, MAX_SEED)
|
125 |
return seed
|
126 |
|
127 |
+
def translate_if_korean(text):
|
128 |
+
# Check if the text contains Korean characters
|
129 |
+
if re.search("[\uac00-\ud7a3]", text):
|
130 |
+
# Translate Korean to English
|
131 |
+
translation = translator(text, max_length=512)
|
132 |
+
return translation[0]['translation_text']
|
133 |
+
return text
|
134 |
+
|
135 |
@spaces.GPU(enable_queue=True)
|
136 |
def generate(
|
137 |
prompt: str,
|
|
|
145 |
background: str = "transparent",
|
146 |
progress=gr.Progress(track_tqdm=True),
|
147 |
):
|
148 |
+
# Translate prompt if it's in Korean
|
149 |
+
prompt = translate_if_korean(prompt)
|
150 |
+
|
151 |
if check_text(prompt, negative_prompt):
|
152 |
raise ValueError("Prompt contains restricted words.")
|
153 |
|
|
|
184 |
return image_paths, seed
|
185 |
|
186 |
examples = [
|
187 |
+
"๊ท์ฌ์ด ๊ณ ์์ด",
|
188 |
+
"ํ๋ณตํ ํ ๋ผ",
|
189 |
+
"์๊ณ ์๋ ๊ฐ์์ง",
|
190 |
+
"์ถค์ถ๋ ๋๊ณ ๋",
|
191 |
+
"์ ๋๋ ์๊ธฐ ๋
์๋ฆฌ",
|
192 |
+
"์ฆ๊ฑฐ์ด ์๊ธฐ ์ฌ์",
|
193 |
]
|
194 |
|
195 |
+
css = """
|
196 |
+
footer {
|
197 |
+
visibility: hidden;
|
198 |
+
}
|
199 |
+
"""
|
200 |
+
|
201 |
|
202 |
# Define the Gradio UI for the sticker generator
|
203 |
+
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
|
204 |
gr.Markdown(DESCRIPTION)
|
205 |
gr.DuplicateButton(
|
206 |
value="Duplicate Space for private use",
|