Spaces:
Running
Running
import random | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from DPMInteractive import g_st, g_et, g_num, g_res | |
from DPMInteractive import init_change, shrink_change, conv_change | |
from DPMInteractive import cond_prob_init_change, cond_prob_alpha_change, cond_prob_cond_change | |
from DPMInteractive import forward_init_change, forward_seq_apply | |
from DPMInteractive import backward_seq_apply, fit_and_backward_apply | |
from DPMInteractive import contraction_init_change, contraction_alpha_change, change_two_inputs_seed, contraction_apply | |
from DPMInteractive import fixed_point_init_change, fixed_point_apply_iterate | |
from DPMInteractive import forward_plot_part, backward_plot_part, fit_plot_part, fixed_plot_part | |
from RenderMarkdown import md_introduction_block, md_transform_block, md_likelihood_block, md_posterior_block | |
from RenderMarkdown import md_forward_process_block, md_backward_process_block, md_fit_posterior_block | |
from RenderMarkdown import md_posterior_transform_block, md_deconvolution_block, md_cond_kl_block, md_approx_gauss_block | |
from RenderMarkdown import md_non_expanding_block, md_stationary_block, md_reference_block, md_about_block | |
from Misc import g_css, js_head, js_load | |
def gr_empty_space(size=1):
    """Insert a Markdown spacer consisting of ``size`` space characters.

    The spacer carries the ``bgc`` CSS class.
    """
    return gr.Markdown(" " * size, elem_classes="bgc")
def gr_number(label=None, minimum=None, maximum=None, value=None, step=1.0, precision=0, min_width=160):
    """Create a ``gr.Number`` using this app's default step, precision and minimum width.

    Every argument is forwarded unchanged to ``gr.Number``.
    """
    options = dict(
        label=label,
        minimum=minimum,
        maximum=maximum,
        value=value,
        step=step,
        precision=precision,
        min_width=min_width,
    )
    return gr.Number(**options)
def gr_val(val):
    """Return an invisible ``gr.Number`` preset to ``val``.

    It is used to pass a constant argument into an event handler.
    """
    hidden_constant = gr.Number(value=val, visible=False)
    return hidden_constant
def apply_listener(apply_button, apply_func, plot_func, reseted_state, apply_inputs, apply_outputs,
                   plot_inputs, plot_outputs):
    """Wire ``apply_button`` to a compute step followed by three plot steps.

    Clicking the button runs this chain:
      1. Disable the button and set ``reseted_state`` to ``None``.
      2. Call ``apply_func`` with ``apply_inputs`` and write the results to ``apply_outputs``.
      3. Call ``plot_func`` with ``plot_inputs`` plus a part index (0, then 1, then 2),
         and write the results to ``plot_outputs``.
      4. Re-enable the button.

    Args:
        apply_button: the ``gr.Button`` that triggers the chain.
        apply_func: compute callback.
        plot_func: plotting callback; receives the part index as its last argument.
        reseted_state: component set to ``None`` when the button is clicked.
        apply_inputs, apply_outputs: components for ``apply_func``.
        plot_inputs, plot_outputs: components for ``plot_func``.

    Returns:
        The last event dependency of the chain, so callers can attach further ``.then`` steps.
    """
    def enable_button(value):
        # Rebuild the button with its current label, clickable again.
        return gr.Button(value=value, interactive=True)

    def disable_button(value):
        # Lock the button while work is in progress and clear ``reseted_state``.
        return gr.Button(value=value, interactive=False), None

    listener = apply_button.click(disable_button, [apply_button], [apply_button, reseted_state])
    listener = listener.then(apply_func, apply_inputs, apply_outputs, show_progress="minimal")
    # The part index reaches plot_func through a hidden Number component (see gr_val).
    for part in range(3):
        listener = listener.then(plot_func, plot_inputs + [gr_val(part)], plot_outputs,
                                 show_progress="minimal")
    listener = listener.then(enable_button, [apply_button], apply_button)
    return listener
def transform_block():
    """Build "Demo 1".

    The demo shows an input pdf next to its pdf after a linear transform, after
    adding noise, and after both.

    It creates the seed box, the two alpha sliders and four plots, and wires a
    change listener to each control.

    Returns:
        dict with keys ``method``, ``inputs`` and ``outputs`` describing the initial render.
    """
    samples = gr.State(value=None)
    samples_pdf = gr.State(value=None)

    with gr.Accordion(label="Demo 1 - Random Variable Transform In DPM", elem_classes="first_demo",
                      elem_id="demo_1"):
        with gr.Group(elem_classes="normal"):
            with gr.Row():
                seed_box = gr.Number(label="random seed", value=100, minimum=0, step=1)
                shrink_slider = gr.Slider(label="alpha of linear transform", value=0.7,
                                          minimum=0.3, maximum=0.999, step=0.001)
                noise_slider = gr.Slider(label="alpha of add noises", value=0.995,
                                         minimum=0.3, maximum=0.999, step=0.001)
                gr_empty_space(10)
            gr_empty_space(5)
            with gr.Row():
                input_pdf_plot = gr.Plot(label="input random variable's pdf", show_label=False)
                shrunk_pdf_plot = gr.Plot(label="pdf after linear transform", show_label=False)
                noised_pdf_plot = gr.Plot(label="pdf after add noises", show_label=False)
                combined_pdf_plot = gr.Plot(label="pdf after linear transform and add noises",
                                            show_label=False)
            gr_empty_space(5)

    init_inputs = [seed_box, shrink_slider, noise_slider]
    init_outputs = [input_pdf_plot, samples, samples_pdf, shrunk_pdf_plot, noised_pdf_plot,
                    combined_pdf_plot]
    seed_box.change(init_change, init_inputs, init_outputs, show_progress="minimal")

    # Both sliders recompute from the cached samples using both alpha values.
    alpha_inputs = [samples, samples_pdf, shrink_slider, noise_slider]
    shrink_slider.change(shrink_change, alpha_inputs, [shrunk_pdf_plot, combined_pdf_plot],
                         show_progress="minimal")
    noise_slider.change(conv_change, alpha_inputs, [noised_pdf_plot, combined_pdf_plot],
                        show_progress="minimal")

    return {"method": init_change, "inputs": init_inputs, "outputs": init_outputs}
def cond_prob_block():
    """Build "Demo 2".

    The demo shows the likelihood and posterior of the transform, evaluated at a
    fixed condition value.

    Returns:
        dict with keys ``method``, ``inputs`` and ``outputs`` describing the initial render.
    """
    x_state, x_pdf_state = gr.State(value=None), gr.State(value=None)
    z_state, xcz_pdf_state = gr.State(value=None), gr.State(value=None)

    with gr.Accordion(label="Demo 2 - Likelihood and Posterior of Transform", elem_classes="first_demo",
                      elem_id="demo_2"):
        with gr.Group(elem_classes="normal"):
            with gr.Row():
                seed = gr_number(label="random seed", minimum=0, maximum=1E6, value=100,
                                 step=1, precision=0, min_width=80)
                alpha = gr_number(label="alpha", minimum=0.001, maximum=0.999, value=0.98,
                                  step=0.001, precision=3, min_width=80)
                cond_val = gr.Slider(label="fixed condition value", value=0.2, minimum=g_st,
                                     maximum=g_et, step=0.1)
                gr_empty_space(5)
            gr_empty_space(5)
            with gr.Row():
                plot_kw = dict(min_width=80, show_label=False)
                inp_plot = gr.Plot(label="input variable's pdf", **plot_kw)
                out_plot = gr.Plot(label="output variable's pdf", **plot_kw)
                forward_cond_plot = gr.Plot(label="forward conditional pdf", **plot_kw)
                backward_cond_plot = gr.Plot(label="backward conditional pdf", **plot_kw)
                fixed_cond_plot = gr.Plot(label="backward fixed conditional pdf", **plot_kw)

    # The two posterior plots are refreshed by every listener below.
    posterior_plots = [backward_cond_plot, fixed_cond_plot]

    init_inputs = [seed, alpha, cond_val]
    init_outputs = [x_state, x_pdf_state, z_state, xcz_pdf_state, inp_plot, out_plot,
                    forward_cond_plot] + posterior_plots
    seed.change(cond_prob_init_change, init_inputs, init_outputs, show_progress="minimal")

    alpha.change(cond_prob_alpha_change,
                 [x_state, x_pdf_state, alpha, cond_val],
                 [z_state, xcz_pdf_state, out_plot, forward_cond_plot] + posterior_plots,
                 show_progress="minimal")

    cond_val.change(cond_prob_cond_change,
                    [x_state, x_pdf_state, z_state, xcz_pdf_state, alpha, cond_val],
                    posterior_plots,
                    show_progress="minimal")

    return {"method": cond_prob_init_change, "inputs": init_inputs, "outputs": init_outputs}
def forward_block(seq_info_state):
    """Build "Demo 3.1", which iteratively transforms an input distribution toward a normal one.

    Args:
        seq_info_state: ``gr.State`` that receives the forward-sequence result from the
            apply button. It is also the state cleared when the button is clicked.

    Returns:
        dict with keys ``method``, ``inputs`` and ``outputs`` describing the initial render.
    """
    samples = gr.State(value=None)
    samples_pdf = gr.State(value=None)
    plot_cache = gr.State(value=None)

    with gr.Accordion(label="Demo 3.1 - Transform To Normal Distribution Iteratively",
                      elem_classes="first_demo", elem_id="demo_3_1"):
        with gr.Group(elem_classes="normal"):
            with gr.Row():
                seed = gr_number(label="random seed", minimum=0, maximum=1E6, value=100,
                                 step=1, precision=0, min_width=80)
                st_alpha = gr_number(label="start alpha", minimum=0.001, maximum=0.999, value=0.98,
                                     step=0.001, precision=3, min_width=80)
                et_alpha = gr_number(label="end alpha", minimum=0.001, maximum=0.999, value=0.98,
                                     step=0.001, precision=3, min_width=80)
                step = gr.Slider(label="step", value=7, minimum=1, maximum=15, step=1, min_width=80)
                apply_button = gr.Button(value="apply", min_width=80)
            node_plot = gr.Plot(label="latent variable's pdf", show_label=False)
            with gr.Accordion("posterior pdf", elem_classes="second"):
                cond_plot = gr.Plot(show_label=False)

    apply_listener(apply_button, forward_seq_apply, forward_plot_part, seq_info_state,
                   apply_inputs=[samples, samples_pdf, st_alpha, et_alpha, step],
                   apply_outputs=[seq_info_state, plot_cache],
                   plot_inputs=[plot_cache],
                   plot_outputs=[node_plot, cond_plot])

    init_outputs = [samples, samples_pdf, node_plot, cond_plot]
    seed.change(forward_init_change, inputs=[seed], outputs=init_outputs, show_progress="minimal")
    return {"method": forward_init_change, "inputs": [seed], "outputs": init_outputs}
def backward_block(seq_info_state):
    """Build "Demo 3.2", which recovers the input distribution from the normal one iteratively.

    Args:
        seq_info_state: ``gr.State`` holding the forward-sequence result. It is only read here.
    """
    plot_state = gr.State(value=None)
    # The click handler clears this dummy state, so the shared seq_info_state
    # is left untouched by this demo's button.
    placeholder = gr.State(value=None)
    title = "Demo 3.2 - Recover From Normal Distribution Iteratively"
    with gr.Accordion(label=title, elem_classes="first_demo", elem_id="demo_3_2"):
        with gr.Group(elem_classes="normal"):
            with gr.Row():
                is_forward_pdf = gr.Checkbox(label="forward pdf", value=True)
                is_backward_pdf = gr.Checkbox(label="backward pdf", value=True)
                # Label typo fixed: "nose" -> "noise".
                noise_seed = gr_number("noise random seed", 0, 1E6, 200, 1, 0, min_width=80)
                noise_ratio = gr_number("noise ratio", 0, 1, 0.0, 0.1, 1, min_width=80)
                apply_button = gr.Button(value="apply")
            node_plot = gr.Plot(label="each variable's pdf", show_label=False)
    inputs = [seq_info_state, is_forward_pdf, is_backward_pdf, noise_seed, noise_ratio]
    outputs = [node_plot, plot_state]
    apply_listener(apply_button, backward_seq_apply, backward_plot_part,
                   placeholder, inputs, outputs, [plot_state], [node_plot])
    return
def fit_posterior_block(seq_info_state):
    """Build "Demo 3.3", which fits the posterior with a conditional Gaussian model.

    Args:
        seq_info_state: ``gr.State`` holding the forward-sequence result. It is only read here.
    """
    plot_state = gr.State(value=None)
    # The click handler clears this dummy state, so the shared seq_info_state
    # is left untouched by this demo's button.
    placeholder = gr.State(value=None)
    title = "Demo 3.3 - Fitting Posterior with Conditional Gaussian Model"
    with gr.Accordion(label=title, elem_classes="first_demo", elem_id="demo_3_3"):
        with gr.Group(elem_classes="normal"):
            with gr.Row():
                info = "show forward pdf"
                is_forward_pdf = gr.Checkbox(label="forward pdf", info=info, value=True)
                info = "show origin backward pdf"
                is_backward_pdf = gr.Checkbox(label="backward pdf", info=info, value=False)
                # Info typo fixed: "conditonal" -> "conditional".
                info = "show backward pdf after fitting posterior with conditional Gaussian"
                is_show_pos = gr.Checkbox(label="fitted posterior", info=info, value=True)
                apply_button = gr.Button(value="apply")
            node_plot = gr.Plot(label="each variable's pdf", show_label=False)
            with gr.Accordion("fitted posterior's pdf", elem_classes="second"):
                cond_plot = gr.Plot(show_label=False)
    inputs = [seq_info_state, is_forward_pdf, is_backward_pdf]
    outputs = [node_plot, cond_plot, plot_state]
    # is_show_pos goes only to the plotting step, so toggling it affects only
    # what the plot step receives, not the fitting step.
    apply_listener(apply_button, fit_and_backward_apply, fit_plot_part,
                   placeholder, inputs, outputs, [plot_state, is_show_pos], [node_plot, cond_plot])
    return
def contraction_block():
    """Build "Demo 4.1" and "Demo 4.2".

    Demo 4.1 shows that the posterior transform is a contraction. Demo 4.2 shows
    that it has a converging point. Both demos share the input-sample states
    created here, and Demo 4.2's iteration also reads ``xcz_pdf_state``, which
    Demo 4.1's listeners produce.

    Returns:
        (ctr_init_param, fixed_init_param): dicts with keys ``method``, ``inputs`` and
        ``outputs`` describing the initial render of each demo.
    """
    x_state = gr.State(value=None)
    x_pdf_state = gr.State(value=None)
    z_state = gr.State(value=None)
    xcz_pdf_state = gr.State(value=None)
    zt_state = gr.State(value=None)
    zt_pdf_state = gr.State(value=None)
    plot_state = gr.State(value=None)
    placeholder = gr.State(value=None)

    ctr_title = "Demo 4.1 - Posterior Transform is a Contraction Mapping"
    with gr.Accordion(label=ctr_title, elem_classes="first_demo", elem_id="demo_4_1"):
        with gr.Row(elem_classes="normal"):
            with gr.Column(scale=3):
                with gr.Group():
                    with gr.Row():
                        ctr_init_seed = gr_number("random seed", 0, 1E6, 100, 1, 0, min_width=80)
                        ctr_alpha = gr_number("alpha", 0.001, 0.999, 0.95, 0.001, 3, min_width=80)
                        # Output-only display (it appears only in listener outputs below).
                        # The range is [0, 1] so that the initial value 1.0 lies inside it;
                        # the former bounds were [0, 0]. Presumably the eigenvalue magnitude
                        # never exceeds 1 — confirm against contraction_*_change.
                        lambda_2 = gr_number("second largest eigenvalue", 0, 1, 1.0, 0.0001, 4, min_width=80)
                    with gr.Row():
                        inp_plot = gr.Plot(label="input variable pdf", min_width=80, show_label=False)
                        pos_plot = gr.Plot(label="posterior pdf", min_width=80, show_label=False)
                        out_plot = gr.Plot(label="output variable pdf", min_width=80, show_label=False)
            with gr.Column(scale=2):
                with gr.Group():
                    with gr.Row():
                        change_inputs_seed = gr.Button(value="change inputs seed")
                        two_inputs_seed = gr_number("two inputs random seed", 0, 1E9, 100, 1, 0)
                    inp_out_plot = gr.Plot(label="input and output pdf of inverse transform", show_label=False)

    fixed_title = "Demo 4.2 - Posterior Transform Have a Converging Point"
    with gr.Accordion(label=fixed_title, elem_classes="first_demo", elem_id="demo_4_2"):
        with gr.Group(elem_classes="normal"):
            with gr.Row():
                fixed_point_seed = gr_number("input seed", 0, 1E6, 200, 1, 0, min_width=80)
                iterate_number = gr_number("iterate number", 0, 1E6, 500, 1, 0, min_width=80)
                is_show_pow = gr.Checkbox(label="show power matrix", value=True)
                fixed_iterate_btn = gr.Button(value="apply iteration transform")
                gr_empty_space(5)
            gr_empty_space(5)
            fixed_point_plot = gr.Plot(label="result of iteration of inverse transform", show_label=False)
            with gr.Accordion("power matrix of posterior", elem_classes="second"):
                power_mat_plot = gr.Plot(show_label=False)

    # --- Demo 4.1 listeners ---
    ctr_init_inputs = [ctr_init_seed, ctr_alpha, two_inputs_seed]
    ctr_init_outputs = [inp_plot, x_state, x_pdf_state, pos_plot, out_plot, z_state, xcz_pdf_state,
                        inp_out_plot, lambda_2]
    ctr_init_seed.change(contraction_init_change, ctr_init_inputs, ctr_init_outputs, show_progress="minimal")
    ctr_alpha_inputs = [x_state, x_pdf_state, ctr_alpha, two_inputs_seed]
    ctr_alpha_outputs = [pos_plot, out_plot, z_state, xcz_pdf_state, inp_out_plot, lambda_2]
    ctr_alpha.change(contraction_alpha_change, ctr_alpha_inputs, ctr_alpha_outputs, show_progress="minimal")
    ctr_apply_inputs, ctr_apply_outputs = [x_state, x_pdf_state, xcz_pdf_state, two_inputs_seed], [inp_out_plot]
    two_inputs_seed.change(contraction_apply, ctr_apply_inputs, ctr_apply_outputs, show_progress="minimal")
    # The button writes a new seed value, which in turn fires two_inputs_seed.change above.
    change_inputs_seed.click(change_two_inputs_seed, None, two_inputs_seed, show_progress="minimal")

    # --- Demo 4.2 listeners ---
    fixed_init_inputs = [fixed_point_seed, x_state, x_pdf_state]
    fixed_init_outputs = [fixed_point_plot, zt_state, zt_pdf_state, power_mat_plot]
    fixed_point_seed.change(fixed_point_init_change, fixed_init_inputs, fixed_init_outputs, show_progress="minimal")
    iterate_inputs = [x_state, x_pdf_state, zt_state, zt_pdf_state, xcz_pdf_state, iterate_number, is_show_pow]
    iterate_outputs = [fixed_point_plot, power_mat_plot, plot_state]
    plot_outputs = [fixed_point_plot, power_mat_plot]
    apply_listener(fixed_iterate_btn, fixed_point_apply_iterate, fixed_plot_part, placeholder,
                   iterate_inputs, iterate_outputs, [plot_state], plot_outputs)

    ctr_init_param = dict(method=contraction_init_change, inputs=ctr_init_inputs, outputs=ctr_init_outputs)
    fixed_init_param = dict(method=fixed_point_init_change, inputs=fixed_init_inputs, outputs=fixed_init_outputs)
    return ctr_init_param, fixed_init_param
def md_header_block():
    """Render the page-title banner as a centred HTML heading inside Markdown."""
    banner = """ </br> </br>
<center>
<h1 style="display:block">
Understanding Diffusion Probability Model<span style='color: orange'> Interactively </span>
</h1>
</center>
</br> </br> </br>"""
    gr.Markdown(banner)
def run_app():
    """Assemble the full page (markdown sections plus all interactive demos) and launch it."""
    with gr.Blocks(css=g_css, head=js_head, js=js_load) as demo:
        seq_info_state = gr.State(value=None)
        # this is needed for offline render markdown in order to import katex.css
        gr.Markdown("$$ $$", visible=False)

        md_header_block()
        md_introduction_block()
        md_transform_block()
        trans_param = transform_block()
        md_likelihood_block()
        md_posterior_block()
        cond_param = cond_prob_block()
        md_forward_process_block()
        fore_param = forward_block(seq_info_state)
        md_backward_process_block()
        backward_block(seq_info_state)
        md_fit_posterior_block()
        fit_posterior_block(seq_info_state)
        md_posterior_transform_block()
        ctr_param, fixed_param = contraction_block()
        md_deconvolution_block()
        md_cond_kl_block()
        md_approx_gauss_block()
        md_non_expanding_block()
        md_stationary_block()
        md_reference_block()
        md_about_block()
        gr.Markdown("<div><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br></div>", visible=True)

        # running initiation consecutively because of the bug of multithreading rendering mathtext in matplotlib
        init_chain = None
        for param in (trans_param, cond_param, fore_param, ctr_param, fixed_param):
            spec = (param["method"], param["inputs"], param["outputs"])
            if init_chain is None:
                init_chain = demo.load(*spec, show_progress="minimal")
            else:
                init_chain = init_chain.then(*spec, show_progress="minimal")

    # NOTE(review): allowed_paths=["/"] lets Gradio serve any file on the host filesystem;
    # consider narrowing it to the directories the rendered markdown actually needs.
    demo.launch(allowed_paths=["/"])
def gtx():
    """Launch a markdown-only page (a subset of the text sections, no interactive demos)."""
    sections = (
        md_introduction_block,
        md_transform_block,
        md_likelihood_block,
        md_posterior_block,
        md_forward_process_block,
        md_backward_process_block,
        md_fit_posterior_block,
        md_posterior_transform_block,
        md_deconvolution_block,
        md_reference_block,
        md_about_block,
    )
    with gr.Blocks(css=g_css, head=js_head, js=js_load) as demo:
        # Hidden formula so that katex.css is pulled in for offline markdown rendering.
        gr.Markdown("$$ $$", visible=False)
        for render_section in sections:
            render_section()
    demo.queue()
    demo.launch(allowed_paths=["/"])
if __name__ == "__main__":
    run_app()
    # To launch only the markdown sections instead, call gtx() in place of run_app().