sberbank-ai commited on
Commit
346b427
·
1 Parent(s): 58c3bbc

feat: Add app file

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
+ import gradio as gr
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ from scgan.config import Config
8
+ from scgan.generate_images import ImgGenerator
9
+
10
+
11
+ def download_weights(repo_id):
12
+ char_map_path = hf_hub_download(repo_id, "char_map.pkl")
13
+ weights_path = hf_hub_download(repo_id, "model_checkpoint_epoch_200.pth.tar")
14
+ return char_map_path, weights_path
15
+
16
+
17
+ def get_text_from_image(img):
18
+ COLOR_MIN = np.array([0, 0, 0],np.uint8)
19
+ COLOR_MAX = np.array([250,250,160],np.uint8)
20
+
21
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
22
+ text_mask = cv2.inRange(img, COLOR_MIN, COLOR_MAX).astype(bool)
23
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
24
+
25
+ bg = np.ones(img.shape, dtype=np.uint8) * 255
26
+ bg[text_mask] = img[text_mask]
27
+ return bg
28
+
29
+
30
+ def predict(text):
31
+ imgs, texts = GENERATOR.generate(word_list=[text])
32
+ image_on_white = get_text_from_image(imgs[0])
33
+ return image_on_white
34
+
35
+
36
+ CHAR_MAP_PATH, WEIGHTS_PATH = download_weights("sberbank-ai/scrabblegan-peter")
37
+
38
+ GENERATOR = ImgGenerator(
39
+ checkpt_path=WEIGHTS_PATH,
40
+ config=Config,
41
+ char_map_path=CHAR_MAP_PATH
42
+ )
43
+
44
+ gr.Interface(
45
+ predict,
46
+ inputs=gr.Textbox(label="Type your text to generate it on an image"),
47
+ outputs=gr.Image(label="Generated image"),
48
+ title="Peter handwritten image generation",
49
+ ).launch()