File size: 9,247 Bytes
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9174f6
 
6124669
 
c9174f6
6124669
 
 
 
 
 
 
c9174f6
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e4890f
 
6124669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import os
import sys

import torch.nn.functional as F
import torch
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.cm
from PIL import Image

import streamlit as st
from streamlit_drawable_canvas import st_canvas


PACKAGE_PARENT = '..'
WISE_DIR = '../wise/'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, WISE_DIR)))



from effects.gauss2d_xy_separated import Gauss2DEffect
from effects.minimal_pipeline import MinimalPipelineEffect
from helpers import torch_to_np, np_to_torch
from effects import get_default_settings
from demo_config import HUGGING_FACE

st.set_page_config(page_title="Editing Demo", layout="wide")

# @st.cache(hash_funcs={OilPaintEffect: id})
@st.cache(hash_funcs={MinimalPipelineEffect: id})
def local_edits_create_effect():
    effect, preset, param_set = get_default_settings("minimal_pipeline")
    effect.enable_checkpoints()
    effect.cuda()
    return effect, param_set


effect, param_set = local_edits_create_effect()

@st.experimental_memo
def gen_param_strength_fig():
    cmap = matplotlib.cm.get_cmap('plasma')
    # cmap show
    gradient = np.linspace(0, 1, 256)
    gradient = np.vstack((gradient, gradient))
    fig, ax = plt.subplots(figsize=(3, 0.1))
    fig.patch.set_alpha(0.0)
    ax.set_title("parameter strength", fontsize=6.5, loc="left")
    ax.imshow(gradient, aspect='auto', cmap=cmap)
    ax.set_axis_off()
    return fig, cmap

cmap_fig, cmap = gen_param_strength_fig()

st.session_state["canvas_key"] = "canvas"
try:
    vp = st.session_state["result_vp"]
    org_cuda = st.session_state["effect_input"]
except KeyError as e:
    print("init run, certain keys not found. If this happens once its ok.")

if st.session_state["action"] != "switch_page_from_local_edits":
    st.session_state.local_edit_action = "init"

st.session_state["action"] = "switch_page_from_local_edits" # on switchback, remember effect input

if "mask_edit_counter" not in st.session_state:
    st.session_state["mask_edit_counter"] = 1
if "initial_drawing" not in st.session_state:
    st.session_state["initial_drawing"] = {"random": st.session_state["mask_edit_counter"], "background": "#eee"}

def on_slider_change():
    if st.session_state.local_edit_action == "init":
        st.stop()
    st.session_state.local_edit_action = "slider"

def on_param_change():
    st.session_state.local_edit_action = "param_change"

active_param = st.sidebar.selectbox("active parameter: ", param_set + ["smooth"], index=2, on_change=on_param_change)

st.sidebar.text("Drawing options")
if active_param != "smooth":
    plus_or_minus = st.sidebar.slider("Increase or decrease param map: ", -1.0, 1.0, 0.8, 0.05,
                                      on_change=on_slider_change)
else:
    sigma = st.sidebar.slider("Sigma: ", 0.1, 10.0, 0.5, 0.1, on_change=on_slider_change)

stroke_width = st.sidebar.slider("Stroke width: ", 1, 50, 20, on_change=on_slider_change)
drawing_mode = st.sidebar.selectbox(
    "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"), on_change=on_slider_change,
)

st.sidebar.text("Viewing options")
if active_param != "smooth":
    overlay = st.sidebar.slider("show parameter overlay: ", 0.0, 1.0, 0.8, 0.02, on_change=on_slider_change)
    st.sidebar.pyplot(cmap_fig, bbox_inches='tight', pad_inches=0)

st.sidebar.text("Update:")
realtime_update = st.sidebar.checkbox("Update in realtime", True)
clear_after_draw = st.sidebar.checkbox("Clear Canvas after each Stroke", False)
invert_selection = st.sidebar.checkbox("Invert Selection", False)


@st.experimental_memo
def greyscale_org(_org_cuda, content_id): #content_id is used for hashing
    if HUGGING_FACE:
        wsize = 450
        img_org_height, img_org_width = _org_cuda.shape[-2:]
        wpercent = (wsize / float(img_org_width))
        hsize = int((float(img_org_height) * float(wpercent)))
    else:
        longest_edge = 670
        img_org_height, img_org_width = _org_cuda.shape[-2:]
        max_width_height = max(img_org_width, img_org_height)
        hsize = int((float(longest_edge) * float(float(img_org_height) / max_width_height)))
        wsize = int((float(longest_edge) * float(float(img_org_width) / max_width_height)))

    org_img = F.interpolate(_org_cuda, (hsize, wsize), mode="bilinear")
    org_img = torch.mean(org_img, dim=1, keepdim=True) / 2.0
    org_img = torch_to_np(org_img)[..., np.newaxis].repeat(3, axis=2)
    return org_img, hsize, wsize

def generate_param_mask(vp):
    greyscale_img, hsize, wsize = greyscale_org(org_cuda, st.session_state["Content_id"])
    if active_param != "smooth":
        scaled_vp = F.interpolate(vp, (hsize, wsize))[:, effect.vpd.name2idx[active_param]]
        param_cmapped = cmap((scaled_vp + 0.5).cpu().numpy())[...,:3][0]
        greyscale_img = greyscale_img * (1 - overlay) + param_cmapped * overlay
    return Image.fromarray((greyscale_img * 255).astype(np.uint8))

def compute_results(_vp):
    if "cached_canvas" in st.session_state and st.session_state["cached_canvas"].image_data is not None:
        canvas_result = st.session_state["cached_canvas"]
        abc = np_to_torch(canvas_result.image_data.astype(np.float32)).sum(dim=1, keepdim=True).cuda()

        if invert_selection:
            abc = abc * (- 1.0) + 1.0

        img_org_width = org_cuda.shape[-1]
        img_org_height = org_cuda.shape[-2]
        res_data = F.interpolate(abc, (img_org_height, img_org_width)).squeeze(1)

        if active_param != "smooth":
            _vp[:, effect.vpd.name2idx[active_param]] += plus_or_minus * res_data
            _vp.clamp_(-0.5, 0.5)
        else:
            gauss2dx = Gauss2DEffect(dxdy=[1.0, 0.0], dim_kernsize=5)
            gauss2dy = Gauss2DEffect(dxdy=[0.0, 1.0], dim_kernsize=5)

            vp_smoothed = gauss2dx(_vp, torch.tensor(sigma).cuda())
            vp_smoothed = gauss2dy(vp_smoothed, torch.tensor(sigma).cuda())

            print(res_data.shape)
            print(_vp.shape)
            print(vp_smoothed.shape)
            _vp = torch.lerp(_vp, vp_smoothed, res_data.unsqueeze(1))

    with torch.no_grad():
        result_cuda = effect(org_cuda, _vp)

    _, hsize, wsize = greyscale_org(org_cuda, st.session_state["Content_id"])
    result_cuda = F.interpolate(result_cuda, (hsize, wsize), mode="bilinear")

    return Image.fromarray((torch_to_np(result_cuda) * 255.0).astype(np.uint8)), _vp

coll1, coll2 = st.columns(2)
coll1.header("Draw Mask:")
coll2.header("Live Result")

# there is no way of removing the canvas history/state without rerunning the whole program.
# therefore, giving the canvas a initial_drawing that differs from the canvas state will clear the background
def mark_canvas_for_redraw():
    print("mark for redraw")
    st.session_state["mask_edit_counter"] += 1 # change state of initial drawing
    initial_drawing = {"random": st.session_state["mask_edit_counter"], "background": "#eee"}
    st.session_state["initial_drawing"] = initial_drawing


with coll1:
    print("edit action", st.session_state.local_edit_action)
    if clear_after_draw and st.session_state.local_edit_action not in ("slider", "param_change", "init"):
        if st.session_state.local_edit_action == "redraw":
            st.session_state.local_edit_action = "draw"
            mark_canvas_for_redraw()
        else:
            st.session_state.local_edit_action = "redraw"

    mask = generate_param_mask(st.session_state["result_vp"])
    st.session_state["last_mask"] = mask

    # Create a canvas component
    canvas_result = st_canvas(
        fill_color="rgba(0, 0, 0, 1)",
        stroke_width=stroke_width,
        background_image=mask,
        update_streamlit=realtime_update,
        width=mask.width,
        height=mask.height,
        initial_drawing=st.session_state["initial_drawing"],
        drawing_mode=drawing_mode,
        key=st.session_state.canvas_key,
    ) 

    if canvas_result.json_data is  None:
        print("stops")
        st.stop()

    st.session_state["cached_canvas"] = canvas_result

    print("compute result")
    img_res, vp = compute_results(vp)
    st.session_state["last_result"] = img_res
    st.session_state["result_vp"] = vp

    st.markdown("### Mask: " + active_param)

if st.session_state.local_edit_action in ("slider", "param_change", "init"):
    print("set redraw")
    st.session_state.local_edit_action = "redraw"

if "objects" in canvas_result.json_data and canvas_result.json_data["objects"] != []:
    print(st.session_state["user"], " edited local param canvas")       

print("plot masks")
texts = []
preview_masks = []
img = st.session_state["last_mask"]
for i, p in enumerate(param_set):
    idx = effect.vpd.name2idx[p]
    iii = F.interpolate(vp[:, idx:idx + 1] + 0.5, (int(img.height * 0.2), int(img.width * 0.2)))
    texts.append(p[:15])
    preview_masks.append(torch_to_np(iii))

coll2.image(img_res)  # , use_column_width="auto")
ppp = st.columns(len(param_set))
for i, (txt, im) in enumerate(zip(texts, preview_masks)):
    ppp[i].text(txt)
    ppp[i].image(im, clamp=True)

print("....")