Max Reimann commited on
Commit
11a70dd
β€’
1 Parent(s): dc6a058

add page for xdog prediction

Browse files
images/apdrawing/img_1585.png ADDED

Git LFS Details

  • SHA256: 956e58eeda1522d3969b44a5c05fe25ecca26afa622c62111d991817d5cbc432
  • Pointer size: 131 Bytes
  • Size of remote file: 511 kB
images/apdrawing/img_1592.png ADDED

Git LFS Details

  • SHA256: b017936b260d7005ab746410bee0d6623024d6813ab3f4b3911dfff3e3852258
  • Pointer size: 131 Bytes
  • Size of remote file: 435 kB
images/apdrawing/img_1594.png ADDED

Git LFS Details

  • SHA256: 4f20624d3e94eaf6631390d9e2e75986a319122c33a3feab285dc79eb2c8a87a
  • Pointer size: 131 Bytes
  • Size of remote file: 518 kB
images/apdrawing/img_1600.png ADDED

Git LFS Details

  • SHA256: 85417c6ed60327ae0e0ecfa754c3a1158e91ad2fef21a56e518ff8f8f25960b5
  • Pointer size: 131 Bytes
  • Size of remote file: 390 kB
images/apdrawing/img_1607.png ADDED

Git LFS Details

  • SHA256: 1333717935e5f823c1c59e1569086133fd231e7a3dc4f7c378576ce522d742a9
  • Pointer size: 131 Bytes
  • Size of remote file: 513 kB
images/apdrawing/img_1616.png ADDED

Git LFS Details

  • SHA256: 90a1c17a599b00c92a9c77a40dd26bdd34c5e47e206344f2522be898276f2e79
  • Pointer size: 131 Bytes
  • Size of remote file: 502 kB
pages/3_πŸ§‘_Predict_Portrait_xDoG.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ from io import BytesIO
4
+ from pathlib import Path
5
+ import os
6
+ import shutil
7
+ import sys
8
+ import time
9
+
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+ import torch
13
+ import streamlit as st
14
+ from st_click_detector import click_detector
15
+
16
+ from matplotlib import pyplot as plt
17
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
18
+ from torchvision.transforms import ToPILImage, Compose, ToTensor, Normalize
19
+ from PIL import Image
20
+
21
+ from huggingface_hub import hf_hub_download
22
+
23
+
24
+ PACKAGE_PARENT = '..'
25
+ WISE_DIR = '../wise/'
26
+ SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
27
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
28
+ sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, WISE_DIR)))
29
+
30
+
31
+ from local_ppn.options.test_options import TestOptions
32
+ from local_ppn.models import create_model
33
+
34
+
35
+
36
+ class CustomOpts(TestOptions):
37
+
38
+ def remove_options(self, parser, options):
39
+ for option in options:
40
+ for action in parser._actions:
41
+ print(action)
42
+ if vars(action)['option_strings'][0] == option:
43
+ parser._handle_conflict_resolve(None,[(option,action)])
44
+ break
45
+
46
+ def initialize(self, parser):
47
+ parser = super(CustomOpts, self).initialize(parser)
48
+ self.remove_options(parser, ["--dataroot"])
49
+ return parser
50
+
51
+ def print_options(self, opt):
52
+ pass
53
+
54
+ def add_predefined_images():
55
+ images = []
56
+ for f in os.listdir(os.path.join(SCRIPT_DIR, PACKAGE_PARENT, 'images','apdrawing')):
57
+ if not f.endswith('.png'):
58
+ continue
59
+ AB = Image.open(os.path.join(SCRIPT_DIR, PACKAGE_PARENT, 'images','apdrawing', f)).convert('RGB')
60
+ # split AB image into A and B
61
+ w, h = AB.size
62
+ w2 = int(w / 2)
63
+ A = AB.crop((0, 0, w2, h))
64
+ B = AB.crop((w2, 0, w, h))
65
+ images.append(A)
66
+ return images
67
+
68
+ @st.experimental_singleton
69
+ def make_model(_unused=None):
70
+ model_path = hf_hub_download(repo_id="MaxReimann/WISE-APDrawing-XDoG", filename="apdrawing_xdog_ppn_conv.pth")
71
+ os.makedirs(os.path.join(SCRIPT_DIR, PACKAGE_PARENT, "trained_models", "ours_apdrawing"), exist_ok=True)
72
+ shutil.copy2(model_path, os.path.join(SCRIPT_DIR, PACKAGE_PARENT, "trained_models", "ours_apdrawing", "latest_net_G.pth"))
73
+
74
+ opt = CustomOpts().parse() # get test options
75
+ # hard-code some parameters for test
76
+ opt.num_threads = 0 # test code only supports num_threads = 0
77
+ opt.batch_size = 1 # test code only supports batch_size = 1
78
+ # opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
79
+ opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
80
+ opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
81
+ opt.dataroot ="null"
82
+ opt.direction = "BtoA"
83
+ opt.model = "pix2pix"
84
+ opt.ppnG = "our_xdog"
85
+ opt.name = "ours_apdrawing"
86
+ opt.netG = "resnet_9blocks"
87
+ opt.no_dropout = True
88
+ opt.norm = "batch"
89
+ opt.load_size = 576
90
+ opt.crop_size = 512
91
+ opt.eval = False
92
+ model = create_model(opt) # create a model given opt.model and other options
93
+ model.setup(opt) # regular setup: load and print networks; create schedulers
94
+ if opt.eval:
95
+ model.eval()
96
+
97
+
98
+ return model, opt
99
+
100
+ def predict(image):
101
+ model, opt = make_model()
102
+ t = Compose([
103
+ ToTensor(),
104
+ Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
105
+ ])
106
+ inp = image.resize((opt.crop_size, opt.crop_size), resample=Image.BICUBIC)
107
+ inp = t(inp).unsqueeze(0).cuda()
108
+ x = model.netG.module.ppn_part_forward(inp)
109
+
110
+ output = model.netG.module.conv_part_forward(x)
111
+ out_img = ToPILImage()(output.squeeze(0))
112
+ return out_img
113
+
114
+
115
+
116
+ st.title("xDoG+CNN Portrait Drawing ")
117
+
118
+ images = add_predefined_images()
119
+
120
+ html_code = '<div class="column" style="display: flex; flex-wrap: wrap; padding: 0 4px;">'
121
+ for i, image in enumerate(images):
122
+ buffered = BytesIO()
123
+ image.save(buffered, format="JPEG")
124
+ encoded = base64.b64encode(buffered.getvalue()).decode()
125
+ html_code += f"<a href='#' id='{i}' style='padding: 0px 5px'><img height='120px' style='margin-top: 8px;' src='data:image/jpeg;base64,{encoded}'></a>"
126
+ html_code += "</div>"
127
+ clicked = click_detector(html_code)
128
+
129
+ uploaded_im = st.file_uploader(f"OR: Load portrait:", type=["png", "jpg"], )
130
+ if uploaded_im is not None:
131
+ img = Image.open(uploaded_im)
132
+ img = img.convert('RGB')
133
+ buffered = BytesIO()
134
+ img.save(buffered, format="JPEG")
135
+
136
+
137
+ clicked_img = None
138
+ if clicked:
139
+ clicked_img = images[int(clicked)]
140
+
141
+ sel_img = img if uploaded_im is not None else clicked_img
142
+ if sel_img:
143
+ result_container = st.container()
144
+ coll1, coll2 = result_container.columns([3,2])
145
+ coll1.header("Result")
146
+ coll2.header("Global Edits")
147
+
148
+ model, opt = make_model()
149
+ t = Compose([
150
+ ToTensor(),
151
+ Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
152
+ ])
153
+ inp = sel_img.resize((opt.crop_size, opt.crop_size), resample=Image.BICUBIC)
154
+ inp = t(inp).unsqueeze(0).cuda()
155
+ # vp = model.netG.module.ppn_part_forward(inp)
156
+
157
+ vp = model.netG.module.predict_parameters(inp)
158
+ inp = (inp * 0.5) + 0.5
159
+
160
+ effect = model.netG.module.apply_visual_effect.effect
161
+
162
+ with coll2:
163
+ # ("blackness", "contour", "strokeWidth", "details", "saturation", "contrast", "brightness")
164
+ show_params_names = ["strokeWidth", "blackness", "contours"]
165
+ display_means = []
166
+ params_mapping = {"strokeWidth": ['strokeWidth'], 'blackness': ["blackness"], "contours": [ "details", "contour"]}
167
+ def create_slider(name):
168
+ params = params_mapping[name] if name in params_mapping else [name]
169
+ means = [torch.mean(vp[:, effect.vpd.name2idx[n]]).item() for n in params]
170
+ display_mean = float(np.average(means) + 0.5)
171
+ display_means.append(display_mean)
172
+ slider = st.slider(f"Mean {name}: ", 0.0, 1.0, value=display_mean, step=0.05)
173
+ for i, param_name in enumerate(params):
174
+ vp[:, effect.vpd.name2idx[param_name]] += slider - (means[i]+ 0.5)
175
+ # vp.clamp_(-0.5, 0.5)
176
+ # pass
177
+
178
+ for name in show_params_names:
179
+ create_slider(name)
180
+
181
+ x = model.netG.module.apply_visual_effect(inp, vp)
182
+ x = (x - 0.5) / 0.5
183
+
184
+ only_x_dog = st.checkbox('only xdog', value=False, help='if checked, use only ppn+xdog, else use ppn+xdog+post-processing cnn')
185
+ if only_x_dog:
186
+ output = x[:,0].repeat(1,3,1,1)
187
+ print('shape output', output.shape)
188
+ else:
189
+ output = model.netG.module.conv_part_forward(x)
190
+
191
+ out_img = ToPILImage()(output.squeeze(0))
192
+ output = out_img.resize((320,320), resample=Image.BICUBIC)
193
+ with coll1:
194
+ st.image(output)
pages/{3_πŸ“–_Readme.py β†’ 4_πŸ“–_Readme.py} RENAMED
File without changes
requirements.txt CHANGED
@@ -10,4 +10,5 @@ streamlit==1.10.0
10
  streamlit_drawable_canvas==0.8.0
11
  streamlit_extras==0.1.5
12
  st_click_detector
13
- scipy
 
 
10
  streamlit_drawable_canvas==0.8.0
11
  streamlit_extras==0.1.5
12
  st_click_detector
13
+ scipy
14
+ huggingface_hub