openfree commited on
Commit
c931415
โ€ข
1 Parent(s): b829850

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -8
app.py CHANGED
@@ -12,11 +12,15 @@ import spaces
12
  import torch
13
  from diffusers import DiffusionPipeline
14
  from typing import Tuple
 
15
 
16
  # Setup rules for bad words (ensure the prompts are kid-friendly)
17
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
18
  default_negative = os.getenv("default_negative","")
19
 
 
 
 
20
  def check_text(prompt, negative=""):
21
  for i in bad_words:
22
  if i in prompt:
@@ -120,6 +124,14 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
120
  seed = random.randint(0, MAX_SEED)
121
  return seed
122
 
 
 
 
 
 
 
 
 
123
  @spaces.GPU(enable_queue=True)
124
  def generate(
125
  prompt: str,
@@ -133,6 +145,9 @@ def generate(
133
  background: str = "transparent",
134
  progress=gr.Progress(track_tqdm=True),
135
  ):
 
 
 
136
  if check_text(prompt, negative_prompt):
137
  raise ValueError("Prompt contains restricted words.")
138
 
@@ -169,18 +184,23 @@ def generate(
169
  return image_paths, seed
170
 
171
  examples = [
172
- "cute bunny",
173
- "happy cat",
174
- "funny dog",
 
 
 
175
  ]
176
 
177
- css = '''
178
- .gradio-container{max-width: 700px !important}
179
- h1{text-align:center}
180
- '''
 
 
181
 
182
  # Define the Gradio UI for the sticker generator
183
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
184
  gr.Markdown(DESCRIPTION)
185
  gr.DuplicateButton(
186
  value="Duplicate Space for private use",
 
12
  import torch
13
  from diffusers import DiffusionPipeline
14
  from typing import Tuple
15
+ from transformers import pipeline
16
 
17
  # Setup rules for bad words (ensure the prompts are kid-friendly)
18
  bad_words = json.loads(os.getenv('BAD_WORDS', '["violence", "blood", "scary", "death", "ghost"]'))
19
  default_negative = os.getenv("default_negative","")
20
 
21
+ # Add the translation pipeline
22
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
23
+
24
  def check_text(prompt, negative=""):
25
  for i in bad_words:
26
  if i in prompt:
 
124
  seed = random.randint(0, MAX_SEED)
125
  return seed
126
 
127
+ def translate_if_korean(text):
128
+ # Check if the text contains Korean characters
129
+ if re.search("[\uac00-\ud7a3]", text):
130
+ # Translate Korean to English
131
+ translation = translator(text, max_length=512)
132
+ return translation[0]['translation_text']
133
+ return text
134
+
135
  @spaces.GPU(enable_queue=True)
136
  def generate(
137
  prompt: str,
 
145
  background: str = "transparent",
146
  progress=gr.Progress(track_tqdm=True),
147
  ):
148
+ # Translate prompt if it's in Korean
149
+ prompt = translate_if_korean(prompt)
150
+
151
  if check_text(prompt, negative_prompt):
152
  raise ValueError("Prompt contains restricted words.")
153
 
 
184
  return image_paths, seed
185
 
186
  examples = [
187
+ "๊ท€์—ฌ์šด ๊ณ ์–‘์ด",
188
+ "ํ–‰๋ณตํ•œ ํ† ๋ผ",
189
+ "์›ƒ๊ณ ์žˆ๋Š” ๊ฐ•์•„์ง€",
190
+ "์ถค์ถ”๋Š” ๋Œ๊ณ ๋ž˜",
191
+ "์‹ ๋‚˜๋Š” ์•„๊ธฐ ๋…์ˆ˜๋ฆฌ",
192
+ "์ฆ๊ฑฐ์šด ์•„๊ธฐ ์‚ฌ์ž",
193
  ]
194
 
195
+ css = """
196
+ footer {
197
+ visibility: hidden;
198
+ }
199
+ """
200
+
201
 
202
  # Define the Gradio UI for the sticker generator
203
+ with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
204
  gr.Markdown(DESCRIPTION)
205
  gr.DuplicateButton(
206
  value="Duplicate Space for private use",