# NOTE: page-scrape residue ("Spaces: Running Running"); not part of the program.
import copy
import multiprocessing
import os
import random
import time
from multiprocessing import Pool, Queue, Manager

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.signal
from scipy.optimize import curve_fit
from scipy.spatial.distance import jensenshannon
from scipy.stats import norm
# Global matplotlib settings applied once at import time.
plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams['figure.max_open_warning'] = 10
matplotlib.rcParams['interactive'] = False

# Shared discretised domain: [g_st, g_et] split into g_num intervals of
# width g_res (so arrays built on it have g_num + 1 samples).
g_st, g_et, g_num = -2.3, 2.3, 460
g_res = (g_et - g_st) / g_num

# Width and height of one subplot, used to build every ``figsize``.
g_fw, g_fh = 3, 3.2
################################################################################### | |
# common function | |
################################################################################### | |
def rs(str_label):
    """Return ``str_label`` with ``z_{0}`` / ``z_0`` replaced by ``x``.

    Used so that the zeroth variable of a sequence is displayed as ``x``.
    """
    for pattern in ("z_{0}", "z_0"):
        str_label = str_label.replace(pattern, "x")
    return str_label
def set_axis(axis, x_range, y_range, x_label, y_label):
    """Apply the shared styling (limits, labels, tick layout) to ``axis``.

    Args:
        axis: matplotlib axes to configure.
        x_range: ``(min, max)`` for the x axis, or ``None`` to leave the x
            limits and the minor ticks untouched.
        y_range: ``(min, max)`` for the y axis, or ``None``.
        x_label: label of the x axis, or ``None``.
        y_label: label of the y axis, or ``None``.
    """
    # NOTE: this rewrites the process-wide rcParams on every call.
    matplotlib.rcParams.update({'font.size': 10, "axes.linewidth": 0.5, "lines.linewidth": 0.7, "figure.dpi": 100})

    if x_range is not None:
        axis.set_xlim(*x_range)
    if y_range is not None:
        axis.set_ylim(*y_range)
    if x_label is not None:
        axis.set_xlabel(x_label)
    if y_label is not None:
        axis.set_ylabel(y_label)

    # Major ticks on every integer.
    axis.xaxis.set_major_locator(plt.MultipleLocator(1))

    # Minor ticks need the range; without it there is nothing to lay out.
    if x_range is None:
        return

    # Snap both ends down to a multiple of 0.2 and place a minor tick every 0.2.
    minor_step = 0.2
    st = x_range[0] // minor_step * minor_step
    et = x_range[1] // minor_step * minor_step
    # round() instead of int(): (et - st) / 0.2 is an integer only up to
    # floating-point error, and truncation could drop one tick.
    count = int(round((et - st) / minor_step))
    axis.set_xticks(np.linspace(st, et, count + 1), minor=True)
    return
def plot_pdf(x, x_pdf, max_y=3.2, title=None, titlesize=10,
             label=None, xlabel="domain", ylabel="pdf", style="solid", color="blue"):
    """Create a one-axes figure of size ``(g_fw, g_fh)`` and draw a pdf curve.

    All arguments are forwarded unchanged to ``axis_pdf``.

    Returns:
        The newly created matplotlib figure.
    """
    fig = plt.figure(figsize=(g_fw, g_fh))
    ax = fig.add_subplot(111)
    axis_pdf(ax, x, x_pdf, max_y=max_y, title=title, titlesize=titlesize,
             label=label, xlabel=xlabel, ylabel=ylabel, style=style, color=color)
    return fig
def plot_2d_pdf(x, y, pdf, cond_val=None, label=None, title=None, titlesize=10, xlabel="x", ylabel="y"):
    """Create a one-axes figure of size ``(g_fw, g_fh)`` and draw a 2-D pdf.

    All arguments are forwarded unchanged to ``axis_2d_pdf``.

    Returns:
        The newly created matplotlib figure.
    """
    fig = plt.figure(figsize=(g_fw, g_fh))
    ax = fig.add_subplot(111)
    axis_2d_pdf(ax, x, y, pdf, cond_val=cond_val, title=title, titlesize=titlesize,
                label=label, xlabel=xlabel, ylabel=ylabel)
    return fig
def axis_pdf(ax, x, x_pdf, max_y=3.2, title=None, titlesize=10,
             label=None, xlabel="domain", ylabel="pdf", style="solid", color="blue"):
    """Draw the curve ``x_pdf`` over ``x`` on an existing axes.

    Args:
        ax: target matplotlib axes.
        x: sample positions; the first and last entries define the x limits.
        x_pdf: curve values at ``x``.
        max_y: upper y limit (the lower limit is 0).
        title: optional axes title, drawn with font size ``titlesize``.
        label: optional legend label of the curve.
        xlabel, ylabel: axis labels.
        style: matplotlib line style.
        color: matplotlib line color.
    """
    set_axis(ax, (x[0], x[-1]), (0, max_y), xlabel, ylabel)
    ax.plot(x, x_pdf, label=label, color=color, linestyle=style)
    if title is not None:
        ax.set_title(title, fontsize=titlesize)

    # Only build a legend when at least one artist on the axes is labelled;
    # calling legend() with nothing to show just emits a warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    return
def axis_2d_pdf(ax, x, y, pdf, cond_val=None, title=None, titlesize=10, label=None, xlabel="x", ylabel="y"):
    """Draw a filled contour plot of a 2-D pdf on an existing axes.

    Args:
        ax: target matplotlib axes.
        x: coordinates along the horizontal axis; ends define the x limits.
        y: coordinates along the vertical axis; ends define the y limits.
        pdf: 2-D array passed to ``contourf`` together with ``x`` and ``y``.
        cond_val: if given, a vertical orange line is drawn at this x value.
        title: optional axes title, drawn with font size ``titlesize``.
        label: accepted for interface compatibility only. ``contourf`` does
            not take a ``label`` keyword (it only warns about it), so the
            value is not forwarded.
        xlabel, ylabel: axis labels.
    """
    set_axis(ax, (x[0], x[-1]), (y[0], y[-1]), xlabel, ylabel)
    ax.contourf(x, y, pdf)
    if title is not None:
        ax.set_title(title, fontsize=titlesize)
    if cond_val is not None:
        ax.plot([cond_val, cond_val], [y[-1], y[0]], color="orange")

    # Nothing drawn here carries a label; only refresh the legend when the
    # axes already holds labelled artists, to avoid a "no handles" warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    return
def add_random_noise(x_pdf, noise_ratio, seed, st, et, num, res):
    """Blend ``x_pdf`` with a random pdf and renormalise the result.

    Args:
        x_pdf: pdf sampled on the grid ``linspace(st, et, num + 1)``.
        noise_ratio: weight of the random pdf; ``x_pdf`` gets ``1 - noise_ratio``.
        seed: seed passed to ``init_x_pdf`` to generate the random pdf.
        st, et, num: grid definition passed to ``init_x_pdf``.
        res: grid resolution used for the normalisation.

    Returns:
        The blended pdf, scaled so that ``(pdf * res).sum() == 1``.
    """
    _, noise_pdf = init_x_pdf(st, et, num, seed=seed)
    mixed_pdf = (1 - noise_ratio) * x_pdf + noise_ratio * noise_pdf
    return mixed_pdf / (res * mixed_pdf.sum())
def power_range(st, et, num, coeff=2):
    """Return ``num`` strictly increasing integers spread from ``st`` towards ``et``.

    The nodes follow ``st + ceil(t ** coeff * (et - st))`` for ``t`` evenly
    spaced in [0, 1]. Any node that does not exceed its predecessor is bumped
    to ``predecessor + 1``, so the tail may exceed ``et`` when there are more
    nodes than integers in the range.

    Returns:
        A list of plain Python ints.
    """
    fractions = np.linspace(0, 1, num=num) ** coeff
    nodes = (st + np.ceil(fractions * (et - st))).astype(int)

    # Enforce nodes[i] >= nodes[i-1] + 1. Subtracting the position turns
    # that recurrence into a running maximum.
    offsets = np.arange(len(nodes))
    nodes = np.maximum.accumulate(nodes - offsets) + offsets
    return nodes.tolist()
def init_x_pdf(st, et, num, modal_count=16, shape_type=0, seed=200):
    """Generate a random Gaussian-mixture pdf on a regular grid.

    Args:
        st, et: first and last grid position.
        num: number of grid intervals; the grid has ``num + 1`` samples.
        modal_count: number of Gaussian components.
        shape_type: selects the ranges the component means and standard
            deviations are drawn from (0, 1, 2; anything else uses the
            fallback ranges).
        seed: seed of the private ``RandomState``; same seed, same pdf.

    Returns:
        ``(x, x_pdf)`` where ``x`` is the grid and ``x_pdf`` the mixture
        density, normalised so that ``(x_pdf * res).sum() == 1``.
    """
    rg = np.random.RandomState(int(seed))
    res = (et - st) / num

    # shape_type -> ((mean range), (std range))
    shape_ranges = {
        0: ((-1.1, 1.1), (0.03, 0.20)),
        1: ((-1.5, 1.5), (0.01, 0.09)),
        2: ((-1.5, 1.5), (0.05, 0.35)),
    }
    fallback = ((-0.8, 0.8), (0.05, 0.35))
    (mean_st, mean_et), (std_st, std_et) = shape_ranges.get(shape_type, fallback)

    # The draw order (mean, std, weight) fixes the pdf produced by a seed.
    mean = mean_st + rg.random(modal_count) * (mean_et - mean_st)
    std = std_st + rg.random(modal_count) * (std_et - std_st)
    weight = 1 + rg.random(modal_count) * 10
    weight = weight / weight.sum()

    x = np.linspace(st, et, num + 1, dtype=np.float64)
    # One row per component, summed into the mixture.
    components = weight[:, None] * norm.pdf(x[None, :], mean[:, None], std[:, None])
    # Small offset keeps the density strictly positive everywhere.
    x_pdf = components.sum(axis=0) + 1E-8
    x_pdf = x_pdf / (x_pdf * res).sum()  # normalized to 1
    return x, x_pdf
def forward_next_pdf(x, x_pdf, alpha, res):
    """Push a discretised pdf through ``y = sqrt(alpha) * x + sqrt(1 - alpha) * eps``.

    ``eps`` is standard normal. Densities are normalised so that their
    Riemann sum with step ``res`` equals 1. This can be read either as
    treating the variable as discrete, or as approximating the integral.

    Args:
        x: input domain (1-D grid); the output uses the same grid.
        x_pdf: pdf of the input variable on ``x``.
        alpha: transform coefficient.
        res: resolution of the grid.

    Returns:
        ``(y, y_pdf, xy_pdf, xcy_pdf, ycx_pdf)``: output grid, output pdf,
        joint pdf indexed ``[x, y]``, conditional q(x|y) and conditional
        q(y|x). When ``alpha`` is (numerically) 1 the transform is the
        identity and ``(x, x_pdf, None, None, None)`` is returned.
    """
    if np.isclose(alpha, 1.0):
        return x, x_pdf, None, None, None

    y = np.array(x, copy=True)
    mu = np.asarray(x) * np.sqrt(alpha)
    std = np.sqrt(1 - alpha)

    # Row i holds q(y | x = x[i]), each row normalised on the grid.
    # No epsilon is added to this kernel on purpose: it would distort the
    # posterior where the density is near zero.
    kernel = norm.pdf(y[None, :], mu[:, None], std)
    kernel = kernel / (kernel * res).sum(axis=1, keepdims=True)

    xy_pdf = x_pdf[:, None] * kernel
    xy_pdf = xy_pdf / (xy_pdf * res * res).sum()  # normalize to 1

    y_pdf = (xy_pdf * res).sum(axis=0)
    # The epsilons guard the divisions where a marginal vanishes.
    xcy_pdf = xy_pdf / (y_pdf[None, :] + 1E-10)
    ycx_pdf = xy_pdf / (x_pdf[:, None] + 1E-10)
    return y, y_pdf, xy_pdf, xcy_pdf, ycx_pdf
################################################################################### | |
# transform block function | |
################################################################################### | |
def shrink(x, x_pdf, alpha, st, res):
    """Return the pdf of ``y = sqrt(alpha) * x`` sampled on the grid of ``x``.

    For each output position ``y`` the inverse map ``x = y / sqrt(alpha)``
    gives the source position. The source density is read at the nearest
    grid node and scaled by ``1 / sqrt(alpha)``, the reciprocal of the
    derivative of the transform. Source positions outside the grid give 0.

    Args:
        x: input domain; ``y`` uses the same grid.
        x_pdf: pdf of the input variable on ``x``.
        alpha: transform coefficient, must be positive.
        st: first grid position.
        res: resolution of the grid.

    Raises:
        ValueError: if ``alpha`` is not positive (the inverse map is undefined).
    """
    if alpha <= 0:
        raise ValueError("alpha must be positive")

    sqrt_alpha = np.sqrt(alpha)
    y = np.asarray(x, dtype=np.float64)
    x_pdf = np.asarray(x_pdf)

    # Nearest grid node of the source position. Rounding (not truncating)
    # keeps the identity transform exact despite floating-point error in
    # (y - st) / res.
    src_idx = np.rint((y / sqrt_alpha - st) / res).astype(np.int64)
    inside = (src_idx >= 0) & (src_idx < len(x_pdf))

    shrink_pdf = np.zeros_like(x_pdf, dtype=np.float64)
    shrink_pdf[inside] = (1 / sqrt_alpha) * x_pdf[src_idx[inside]]
    return shrink_pdf
def conv(x, x_pdf, alpha, res):
    """Return the pdf of ``x + sqrt(1 - alpha) * eps`` with ``eps`` standard normal.

    Adding independent noise convolves the densities. Both densities are
    turned into discrete probabilities (times ``res``), convolved, and the
    result is turned back into a density (divided by ``res``).

    Args:
        x: grid the pdf is sampled on. The Gaussian kernel is sampled on the
            same grid, so the result is only centred correctly when the grid
            is symmetric about 0.
        x_pdf: pdf of the input variable on ``x``.
        alpha: noise coefficient; the noise variance is ``1 - alpha``.
        res: resolution of the grid.

    Returns:
        The convolved pdf, same length as ``x_pdf``. When ``alpha`` is
        (numerically) 1 no noise is added and a copy of ``x_pdf`` is returned;
        the zero-width Gaussian would otherwise evaluate to NaN.
    """
    if np.isclose(alpha, 1.0):
        return np.array(x_pdf, dtype=np.float64, copy=True)

    gauss_pdf = norm.pdf(x, 0, np.sqrt(1 - alpha))
    discrete_out = scipy.signal.convolve(x_pdf * res, gauss_pdf * res, "same")
    return discrete_out / res
def shrink_conv(x, x_pdf, shrink_alpha, conv_alpha, st, res):
    """Apply ``shrink`` and then ``conv`` to a pdf.

    Args:
        x: grid the pdf is sampled on.
        x_pdf: pdf of the input variable on ``x``.
        shrink_alpha: coefficient of the linear transform step.
        conv_alpha: coefficient of the noise step.
        st: first grid position.
        res: resolution of the grid.

    Returns:
        The pdf after both sub transforms.
    """
    # Linear transform first; adding independent noise is a convolution.
    scaled_pdf = shrink(x, x_pdf, shrink_alpha, st, res)
    return conv(x, scaled_pdf, conv_alpha, res)
def plot_init_pdf(seed, st, et, num):
    """Generate the input pdf for ``seed`` (shape type 0) and plot it.

    Returns:
        ``(fig, x, x_pdf)``: the figure, the grid and the pdf.
    """
    x, x_pdf = init_x_pdf(st, et, num, shape_type=0, seed=seed)
    fig = plot_pdf(x, x_pdf, label="x", title="input variable's pdf")
    # Shrink the title after the fact to match the other plots of the block.
    title_text = fig.axes[0].title
    title_text.set_size(9)
    return fig, x, x_pdf
def plot_shrink_pdf(x, x_pdf, alpha, st, res):
    """Plot the pdf after the linear transform ``shrink``.

    Returns:
        The figure, or ``None`` when ``x`` or ``x_pdf`` is missing.
    """
    if x is None or x_pdf is None:
        return None
    return plot_pdf(x, shrink(x, x_pdf, alpha, st, res),
                    label=r"$y=\sqrt{\alpha}x$", title="pdf after linear transform", titlesize=9)
def plot_conv_pdf(x, x_pdf, alpha, res):
    """Plot the pdf after the noise step ``conv``.

    Returns:
        The figure, or ``None`` when ``x`` or ``x_pdf`` is missing.
    """
    if x is None or x_pdf is None:
        return None
    return plot_pdf(x, conv(x, x_pdf, alpha, res),
                    label=r"$y=x+\sqrt{1-\alpha}\epsilon$", title="pdf after add noises", titlesize=9)
def plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, st, res):
    """Plot the pdf after both sub transforms (``shrink_conv``).

    Returns:
        The figure, or ``None`` when ``x`` or ``x_pdf`` is missing.
    """
    if x is None or x_pdf is None:
        return None
    out_pdf = shrink_conv(x, x_pdf, shrink_alpha, conv_alpha, st, res)
    return plot_pdf(x, out_pdf,
                    label=r"$y=\sqrt{\alpha_s}x + \sqrt{1-\alpha_e}\epsilon$",
                    title=r"pdf after two sub transforms", titlesize=9)
def init_change(seed, shrink_alpha, conv_alpha):
    """Rebuild all four plots of the transform block for a new seed.

    Returns:
        ``(init_fig, x, x_pdf, shrink_fig, conv_fig, shrink_conv_fig)``.
    """
    # The module-level grid (g_st, g_et, g_num, g_res) is only read here.
    init_fig, x, x_pdf = plot_init_pdf(seed, g_st, g_et, g_num)
    figs = (
        plot_shrink_pdf(x, x_pdf, shrink_alpha, g_st, g_res),
        plot_conv_pdf(x, x_pdf, conv_alpha, g_res),
        plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res),
    )
    return (init_fig, x, x_pdf) + figs
def shrink_change(x, x_pdf, shrink_alpha, conv_alpha):
    """Redraw the plots that depend on ``shrink_alpha``.

    Returns:
        ``(shrink_fig, shrink_conv_fig)``.
    """
    return (
        plot_shrink_pdf(x, x_pdf, shrink_alpha, g_st, g_res),
        plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res),
    )
def conv_change(x, x_pdf, shrink_alpha, conv_alpha):
    """Redraw the plots that depend on ``conv_alpha``.

    Returns:
        ``(conv_fig, shrink_conv_fig)``.
    """
    return (
        plot_conv_pdf(x, x_pdf, conv_alpha, g_res),
        plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res),
    )
################################################################################### | |
# cond prob block function | |
################################################################################### | |
def cond_prob_init_change(seed, alpha, cond_val):
    """Rebuild every plot of the conditional-probability block for a new seed.

    Returns:
        ``(x, x_pdf, z, xcz_pdf, fig_x, fig_z, fig_zcx, fig_xcz, fig_fix_xcz)``
        where ``xcz_pdf`` is the second value returned by
        ``cond_prob_alpha_change`` (the conditional q(x|z)).
    """
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=0, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    fig_x = plot_pdf(x, x_pdf, xlabel="x domain", ylabel="pdf", title="input variable's pdf", titlesize=9)

    z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz = cond_prob_alpha_change(x, x_pdf, alpha, cond_val)
    return x, x_pdf, z, xcz_pdf, fig_x, fig_z, fig_zcx, fig_xcz, fig_fix_xcz
def cond_prob_alpha_change(x, x_pdf, alpha, cond_val):
    """Recompute the forward transform for ``alpha`` and redraw its plots.

    NOTE: when ``alpha`` is numerically 1, ``forward_next_pdf`` returns
    ``None`` for the conditional pdfs and this function fails on them.

    Returns:
        ``(z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz)``.
    """
    z, z_pdf, _, xcz_pdf, zcx_pdf = forward_next_pdf(x, x_pdf, alpha, g_res)

    fig_z = plot_pdf(z, z_pdf, label=r"$z=\sqrt{\alpha}x + \sqrt{1-\alpha}\epsilon$",
                     title=r"output variable's pdf", titlesize=9, xlabel="z domain", ylabel="pdf")
    # The array is indexed [x, z]; it is transposed before being drawn.
    fig_zcx = plot_2d_pdf(x, z, zcx_pdf.transpose(), label="$q(z|x)$",
                          title=r"forward conditional pdf", titlesize=9,
                          xlabel="x domain(cond)", ylabel="z domain")

    fig_xcz, fig_fix_xcz = cond_prob_cond_change(x, x_pdf, z, xcz_pdf, alpha, cond_val)
    return z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz
def cond_prob_cond_change(x, x_pdf, z, xcz_pdf, alpha, cond_val):
    """Draw q(x|z) and its slice at the fixed condition ``z = cond_val``.

    Args:
        x: grid of the input variable.
        x_pdf: pdf of the input variable on ``x``.
        z: grid of the output variable.
        xcz_pdf: conditional q(x|z), indexed ``[x, z]``.
        alpha: transform coefficient, used for the reference Gaussian.
        cond_val: value of ``z`` the posterior is conditioned on.

    Returns:
        ``(fig_xcz, fig_fix_xcz)``: the 2-D conditional plot and the plot of
        the fixed-condition posterior together with q(x) and the Gaussian.
    """
    # Nearest grid column of cond_val, clamped to the grid. Truncating with
    # int() could select the neighbouring column because of floating-point
    # error, and an out-of-range value would wrap around or raise.
    cond_idx = int(round((cond_val - g_st) / g_res))
    cond_idx = min(max(cond_idx, 0), xcz_pdf.shape[1] - 1)
    cond_pdf = xcz_pdf[:, cond_idx]

    fig_xcz = plot_2d_pdf(x, z, xcz_pdf, cond_val, label="$q(x|z)$",
                          title="backward conditional pdf", xlabel="z domain(cond)", ylabel="x domain")
    fig_xcz.axes[0].title.set_size(9)

    # Gaussian with mean cond_val / sqrt(alpha) and variance (1 - alpha) / alpha.
    gauss = norm.pdf(x, cond_val / np.sqrt(alpha), np.sqrt((1 - alpha) / alpha))

    fig_fix_xcz = plt.figure(figsize=(g_fw, g_fh))
    ax = fig_fix_xcz.add_subplot(111)
    axis_pdf(ax, x, gauss, max_y=5, label="$gauss$", style="dashed", color="green")
    axis_pdf(ax, x, x_pdf, max_y=5, label="$q(x)$", style="dashed", color="blue")
    axis_pdf(ax, x, cond_pdf, max_y=5, label="$q(x|z=%s)$" % cond_val,
             title="posterior with fixed condition", titlesize=9, xlabel="x domain", color="orange")

    # Two legends: the dashed reference curves on the left, the posterior on
    # the right. add_artist keeps the first one when the second is created.
    handles, labels = ax.get_legend_handles_labels()
    ax.add_artist(ax.legend(handles[:2], labels[:2], handlelength=0.8, loc="upper left"))
    ax.add_artist(ax.legend(handles[2:], labels[2:], handlelength=0.8, loc="upper right"))
    return fig_xcz, fig_fix_xcz
################################################################################### | |
# forward block function | |
################################################################################### | |
def plot_first_pdf(x, x_pdf, ax):
    """Draw the original variable's pdf on ``ax`` with the forward-block styling."""
    title = r"origin var pdf"
    label = r"forward q(x)"
    axis_pdf(ax, x, x_pdf, title=title, label=label, xlabel=rs(r"x domain"), ylabel="pdf", color="blue")
    ax.legend(handlelength=1.2, labels=[label])
    return
def forward_init_change(seed):
    """Generate the input pdf for ``seed`` and draw it in a one-row figure.

    Returns:
        ``(x, x_pdf, fig, None)``; the figure has 8 axes of which only the
        first is drawn on.
    """
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)

    col_count = 8
    fig, axes = plt.subplots(nrows=1, ncols=col_count, figsize=(col_count * g_fw, 1 * g_fh))
    plot_first_pdf(x, x_pdf, axes.flatten()[0])
    return x, x_pdf, fig, None
def plot_forward_pdf(axes, seq_info, color, pidx=-1):
    """Draw the forward pdfs of a sequence, each on the axes of its step.

    Args:
        axes: flat array of axes, indexed by step number.
        seq_info: list of ``[cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]``
            entries as built by ``forward_seq_apply``.
        color: line color of the curves.
        pidx: when >= 0, only the ``pidx``-th third of the sequence is drawn
            (so the work can be split over several calls); -1 draws all.
    """
    count = len(seq_info)
    last_idx = count - 1
    if pidx >= 0:
        step = int(count/3+1)
        seq_info = seq_info[pidx*step:(pidx+1)*step]

    for info in seq_info:
        nz, nz_pdf = info[2], info[3]
        nidx, alpha = info[6], info[7]

        if nidx == 0:
            title, label = "origin var pdf", r"forward $q(x)$"
        else:
            title = rs(r"$q(z_{%d})\ \alpha=%0.3f$" % (nidx, alpha))
            label = r"forward $q(z_{%d})$" % nidx
        xlabel = rs(r"$z_{%d}\ domain$" % nidx)

        axis_pdf(axes[nidx], nz, nz_pdf, title=title, label=label, xlabel=xlabel, ylabel="pdf", color=color)
        axes[nidx].legend(handlelength=1.2)

        # The last step is additionally compared with the standard normal.
        if nidx == last_idx:
            axes[last_idx].plot(nz, norm.pdf(nz, 0, 1), label=r"$\mathcal{N}\/(0, 1)$", color="green")
            axes[last_idx].legend()
    return
def plot_backward_pdf(axes, fore_seq_info, back_seq_info, label_prefix, res, color, pidx=-1):
    """Draw restored (backward) pdfs and their divergence from the forward pdfs.

    Args:
        axes: flat array of axes, indexed by step number.
        fore_seq_info: forward sequence entries (see ``forward_seq_apply``).
        back_seq_info: matching backward entries from ``get_back_seq_info``.
        label_prefix: text prepended to every legend label.
        res: grid resolution; turns densities into probabilities for the
            Jensen-Shannon distance.
        color: line color of the curves.
        pidx: when >= 0, only one third of the sequence is drawn, taken in
            reverse order (``pidx`` 0 is the last third); -1 draws all.
    """
    count = len(fore_seq_info)
    if pidx >= 0:
        step = int(count/3 + 1)
        st, et = (2-pidx)*step, (2-pidx+1)*step  # reverse
        fore_seq_info = fore_seq_info[st:et]
        back_seq_info = back_seq_info[st:et]

    for fore_info, back_info in zip(fore_seq_info, back_seq_info):
        nz, fore_nz_pdf, nidx = fore_info[2], fore_info[3], fore_info[6]
        back_nz_pdf = back_info[3]

        div = jensenshannon(back_nz_pdf*res, fore_nz_pdf*res)
        # The end point gets a specific name.
        name = r"$\mathcal{N}\/(0,1)$" if nidx == count-1 else "revert"
        label = rs(label_prefix + name + " div=%0.2f" % div)
        xlabel = rs(r"$z_{%d}\ domain$" % nidx)

        axis_pdf(axes[nidx], nz, back_nz_pdf, label=label, xlabel=xlabel, ylabel="pdf", color=color)
        axes[nidx].legend(handlelength=1.2)
    return
def plot_backward_cond_pdf(axes, seq_info, reverse=True, pidx=-1):
    """Draw the backward conditional pdf of every step that has one.

    Args:
        axes: flat array of axes, indexed by step number.
        seq_info: list of ``[cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]``.
        reverse: with ``pidx`` >= 0, whether thirds are counted from the end
            of the sequence (``True``) or from the start (``False``).
        pidx: when >= 0, only one third of the sequence is drawn; -1 draws all.
    """
    if pidx >= 0:
        step = int(len(seq_info)/3+1)
        part = (2 - pidx) if reverse else pidx
        seq_info = seq_info[part*step:(part+1)*step]

    for cz, _, nz, _, bc_pdf, cidx, nidx, alpha in seq_info:
        if bc_pdf is None:
            # The first entry has no preceding step to condition on.
            continue
        title = rs(r"$q(z_{%d}|z_{%d})\ \alpha=%0.3f$" % (cidx, nidx, alpha))
        xlabel = rs(r"$z_{%d}$" % nidx)
        ylabel = rs(r"$z_{%d}$" % cidx)
        axis_2d_pdf(axes[nidx], cz, nz, bc_pdf, title=title, xlabel=xlabel, ylabel=ylabel)
    return
def get_back_seq_info(ez, ez_pdf, fore_seq_info, res):
    """Propagate an end-point pdf backwards through the conditional pdfs.

    Starting from ``(ez, ez_pdf)`` the sequence is walked from the last
    entry to the first. At each entry the stored conditional pdf ``bc_pdf``
    maps the pdf of the later variable to the pdf of the earlier one:
    ``cz_pdf = bc_pdf @ nz_pdf * res``.

    Args:
        ez: grid of the end-point variable.
        ez_pdf: pdf the backward pass starts from.
        fore_seq_info: list of ``[cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx,
            alpha]`` entries; it is deep-copied, not modified.
        res: grid resolution.

    Returns:
        A copy of ``fore_seq_info`` whose first four slots hold the restored
        grids and pdfs. Entries without a conditional pdf only get their
        ``nz`` / ``nz_pdf`` slots replaced.
    """
    back_seq_info = copy.deepcopy(fore_seq_info)
    nz, nz_pdf = ez, ez_pdf
    for info in reversed(back_seq_info):
        bc_pdf = info[4]
        if bc_pdf is None:
            info[2:4] = nz, nz_pdf
            continue
        cz, cz_pdf = nz, np.matmul(bc_pdf, nz_pdf[:, None]).flatten() * res
        info[:4] = cz, cz_pdf, nz, nz_pdf
        # The restored pdf is the input of the preceding step.
        nz, nz_pdf = cz, cz_pdf
    return back_seq_info
def forward_seq_apply(x, x_pdf, st_alpha, et_alpha, step):
    """Run a sequence of forward transforms and prepare the figures for it.

    ``step`` transforms are applied one after another, with ``alpha``
    spaced linearly from ``st_alpha`` to ``et_alpha``. Nothing is drawn
    here; drawing happens in ``forward_plot_part``.

    Returns:
        ``(seq_info, forward_plot_state)``. ``seq_info`` is a list of
        ``[cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]`` entries; the
        first one describes the input variable. ``forward_plot_state``
        bundles the two empty figures, their axes, ``seq_info``, the grid
        resolution and the curve color. ``(None, None)`` when ``x_pdf`` is
        ``None``.
    """
    if x_pdf is None:
        # Same number of values as the normal return below.
        return None, None

    alphas = np.linspace(st_alpha, et_alpha, step)

    col_count = 8
    row_count = int(np.ceil((step+1)/8))
    figsize = (col_count*g_fw, row_count*g_fh)
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=figsize)
    axes = axes.flatten()
    pos_fig, pos_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=figsize)
    pos_axes = pos_axes.flatten()

    seq_info = [[None, None, x, x_pdf, None, -1, 0, None]]
    cz, cz_pdf = x, x_pdf
    for cidx, alpha in enumerate(alphas):
        nz, nz_pdf, _, bc_pdf, _ = forward_next_pdf(cz, cz_pdf, alpha, g_res)
        seq_info.append([cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, cidx+1, alpha])
        # The output of this step is the input of the next one.
        cz, cz_pdf = nz, nz_pdf

    forward_plot_state = fig, axes, pos_fig, pos_axes, seq_info, g_res, "blue"
    return seq_info, forward_plot_state
def forward_plot_part(plot_state, pidx):
    """Draw the ``pidx``-th third of the forward sequence.

    Args:
        plot_state: the state tuple built by ``forward_seq_apply``.
        pidx: index of the third to draw.

    Returns:
        ``(fig, back_fig)``: the pdf figure and the conditional-pdf figure,
        or ``(None, None)`` when there is no state.
    """
    if plot_state is None:
        return None, None
    fig, axes, back_fig, pos_axes, seq_info, _, color = plot_state
    plot_forward_pdf(axes, seq_info, color, pidx)
    plot_backward_cond_pdf(pos_axes, seq_info, False, pidx)
    return fig, back_fig
def backward_seq_apply(fore_seq_info, is_forward_pdf, is_backward_pdf, noise_seed, noise_ratio):
    """Compute the backward sequences and prepare the figure they are drawn on.

    The backward pass starts from the standard normal pdf and, when
    ``noise_ratio`` > 0, also from that pdf blended with random noise.

    Args:
        fore_seq_info: forward sequence from ``forward_seq_apply``.
        is_forward_pdf: draw the forward pdfs right away.
        is_backward_pdf: compute the backward sequences.
        noise_seed: seed of the random noise pdf.
        noise_ratio: weight of the noise pdf in the noisy starting point.

    Returns:
        ``(fig, plot_state)`` where ``plot_state`` is consumed by
        ``backward_plot_part``; ``(None, None)`` without a forward sequence.
    """
    if fore_seq_info is None:
        return None, None

    col_count = 8
    step = len(fore_seq_info)-1
    row_count = int(np.ceil((step+1)/8))
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    axes = axes.flatten()

    if is_forward_pdf:
        plot_forward_pdf(axes, fore_seq_info, "blue")

    # End point: grid of the last step, standard normal on the input grid.
    x = fore_seq_info[0][2]
    ez, ez_pdf = fore_seq_info[-1][2], norm.pdf(x, 0, 1)

    std_back_seq_info, noise_back_seq_info = None, None
    if is_backward_pdf:
        std_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
        if noise_ratio > 0:
            ez_pdf = add_random_noise(ez_pdf, noise_ratio, noise_seed, g_st, g_et, g_num, g_res)
            noise_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)

    plot_state = fig, axes, fore_seq_info, std_back_seq_info, noise_back_seq_info, g_res
    return fig, plot_state
def backward_plot_part(plot_state, pidx=-1):
    """Draw one third (or all, for ``pidx`` = -1) of the backward sequences.

    Args:
        plot_state: the state tuple built by ``backward_seq_apply``.
        pidx: index of the third to draw, forwarded to ``plot_backward_pdf``.

    Returns:
        The figure, or ``None`` when there is no state.
    """
    if plot_state is None:
        return None
    fig, axes, fore_seq_info, std_back_seq_info, noise_back_seq_info, res = plot_state

    # (sequence, label prefix, color) for each backward pass that was computed.
    passes = (
        (std_back_seq_info, "std ", "green"),
        (noise_back_seq_info, "noise ", "red"),
    )
    for back_seq_info, prefix, color in passes:
        if back_seq_info is not None:
            plot_backward_pdf(axes, fore_seq_info, back_seq_info, prefix, res, color, pidx)
    return fig
def fit_pos_with_gauss(idx, x, bc_pdf, queue):
    """Replace every column of ``bc_pdf`` by its best-fitting Gaussian.

    Each column is fitted with ``norm.pdf`` (location and scale, starting
    from 0 and 1) and overwritten with the fitted curve.

    Args:
        idx: unused; kept for the call signature.
        x: grid the columns are sampled on.
        bc_pdf: 2-D array, modified in place.
        queue: unused; kept for the call signature.

    Returns:
        ``bc_pdf`` itself, after the in-place update.
    """
    for col in range(bc_pdf.shape[1]):
        fitted_params, _ = curve_fit(norm.pdf, x, bc_pdf[:, col], p0=[0, 1])
        bc_pdf[:, col] = norm.pdf(x, *fitted_params)
    return bc_pdf
def seq_fit_pos_with_gauss(fore_seq_info):
    """Return a copy of the sequence whose conditional pdfs are Gaussian fits.

    The input is deep-copied first, so ``fore_seq_info`` keeps its original
    conditional pdfs. Entries without a conditional pdf are left as they are.
    """
    fit_seq_info = copy.deepcopy(fore_seq_info)
    for ii, info in enumerate(fit_seq_info):
        x, bc_pdf = info[0], info[4]
        if bc_pdf is None:
            continue
        # Sequential on purpose; the fit of each step runs in this process.
        info[4] = fit_pos_with_gauss(ii, x, bc_pdf, None)
    return fit_seq_info
def fit_and_backward_apply(fore_seq_info, is_forward_pdf, is_backward_pdf):
    """Fit the conditional pdfs with Gaussians and run the backward pass on them.

    Args:
        fore_seq_info: forward sequence from ``forward_seq_apply``.
        is_forward_pdf: draw the forward pdfs right away.
        is_backward_pdf: also compute and draw the backward pass that uses
            the unfitted conditional pdfs.

    Returns:
        ``(fig, pos_fig, fit_plot_state)`` where ``fit_plot_state`` is
        consumed by ``fit_plot_part``; ``(None, None, None)`` without a
        forward sequence.
    """
    if fore_seq_info is None:
        return None, None, None

    col_count = 8
    step = len(fore_seq_info)-1
    row_count = int(np.ceil((step+1) / 8))
    figsize = (col_count*g_fw, row_count*g_fh)
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=figsize)
    axes = axes.flatten()
    pos_fig, pos_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=figsize)
    pos_axes = pos_axes.flatten()

    if is_forward_pdf:
        plot_forward_pdf(axes, fore_seq_info, "blue")

    # End point: grid of the last step, standard normal on the input grid.
    x = fore_seq_info[0][2]
    ez, ez_pdf = fore_seq_info[-1][2], norm.pdf(x, 0, 1)

    if is_backward_pdf:
        std_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
        plot_backward_pdf(axes, fore_seq_info, std_back_seq_info, "std ", g_res, "green")

    # The fitted backward pass is always computed; it is drawn by fit_plot_part.
    fit_back_seq_info = seq_fit_pos_with_gauss(fore_seq_info)
    fit_back_seq_info = get_back_seq_info(ez, ez_pdf, fit_back_seq_info, g_res)

    fit_plot_state = fig, axes, pos_fig, pos_axes, fore_seq_info, fit_back_seq_info, g_res
    return fig, pos_fig, fit_plot_state
def fit_plot_part(plot_state, is_show_pos, pidx=-1):
    """Draw one third (or all, for ``pidx`` = -1) of the fitted backward pass.

    Args:
        plot_state: the state tuple built by ``fit_and_backward_apply``.
        is_show_pos: also draw the fitted conditional pdfs.
        pidx: index of the third to draw.

    Returns:
        ``(fig, back_fig)``, or ``(None, None)`` when there is no state.
    """
    if plot_state is None:
        return None, None
    fig, axes, back_fig, back_axes, fore_seq_info, fit_back_seq_info, res = plot_state
    plot_backward_pdf(axes, fore_seq_info, fit_back_seq_info, "fit ", res, "orange", pidx)
    if is_show_pos:
        plot_backward_cond_pdf(back_axes, fit_back_seq_info, True, pidx)
    return fig, back_fig
################################################################################### | |
# contraction block function | |
################################################################################### | |
def contraction_init_change(seed, alpha, two_inputs_seed):
    """Rebuild every plot of the contraction block for a new seed.

    The shape type of the input pdf is itself drawn from ``seed``.

    Returns:
        ``(fig, x, x_pdf, fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2)``.
    """
    rg = np.random.RandomState(int(seed))
    shape_type = rg.randint(0, 4)

    x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=shape_type, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    # hijack may return an unnormalised pdf, so normalise again.
    x_pdf = x_pdf / (x_pdf * g_res).sum()  # normalized to 1

    fig = plot_pdf(x, x_pdf, title="input variable pdf", titlesize=9)
    fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2 = contraction_alpha_change(x, x_pdf, alpha, two_inputs_seed)
    return fig, x, x_pdf, fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2
def contraction_alpha_change(x, x_pdf, alpha, two_inputs_seed):
    """Recompute the posterior for ``alpha`` and redraw the contraction plots.

    NOTE: when ``alpha`` is numerically 1, ``forward_next_pdf`` returns
    ``None`` for the conditional pdfs and this function fails on them.

    Returns:
        ``(fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2)`` where
        ``lambda_2`` is the second largest eigenvalue magnitude of the
        column-normalised posterior matrix.
    """
    z, z_pdf, _, xcz_pdf, _ = forward_next_pdf(x, x_pdf, alpha, g_res)

    fig_z = plot_pdf(z, z_pdf, label=r"$z=\sqrt{\alpha}x + \sqrt{1-\alpha}\epsilon$",
                     title=r"output variable pdf", titlesize=9, xlabel="z domain", ylabel="pdf")
    fig_xcz = plot_2d_pdf(x, z, xcz_pdf, None, label="$q(x|z)$",
                          title=r"posterior pdf", titlesize=9, xlabel="z domain(cond)", ylabel="x domain")

    # Normalise every column to sum to 1 before taking the spectrum.
    xcz = xcz_pdf/xcz_pdf.sum(axis=0, keepdims=True)
    magnitudes = np.sort(np.absolute(np.linalg.eigvals(xcz)))[::-1]
    lambda_2 = magnitudes[1]

    fig_inp_out = contraction_apply(x, x_pdf, xcz_pdf, two_inputs_seed)
    return fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2
def change_two_inputs_seed():
    """Return a fresh random integer seed in ``[0, 10**6]``."""
    # An int bound is required: random.randint rejects the float 1E6 on
    # Python 3.12+ (TypeError) and only tolerated it with a deprecation
    # warning before.
    return random.randint(0, 10 ** 6)
def contraction_apply(x, x_pdf, bc_pdf, seed):
    """Map two random pdfs through the posterior and plot inputs and outputs.

    Both random inputs (component count, seed and shape type) are drawn
    from ``seed``. Each panel also shows the Jensen-Shannon distance between
    its two curves.

    Args:
        x: unused; kept for the call signature.
        x_pdf: unused; kept for the call signature.
        bc_pdf: posterior matrix applied to the inputs as
            ``bc_pdf @ z_pdf * g_res``.
        seed: seed for the choice of the two random inputs.

    Returns:
        A figure with two axes: the inputs on the left, the outputs on the right.
    """
    rg = np.random.RandomState(int(seed))
    modals = [1, 2, 8, 12, 16, 16, 16, 16, 16, 20]
    # The draw order below fixes which inputs a seed produces.
    count1, count2 = rg.choice(modals), rg.choice(modals)
    seed1, seed2 = rg.randint(0, 1E6, 2)
    shape1, shape2 = rg.randint(0, 4, 2)

    z1, z1_pdf = init_x_pdf(g_st, g_et, g_num, count1, shape_type=shape1, seed=seed1)
    z2, z2_pdf = init_x_pdf(g_st, g_et, g_num, count2, shape_type=shape2, seed=seed2)
    div_z = jensenshannon(z1_pdf*g_res, z2_pdf*g_res)

    x1_pdf = (np.matmul(bc_pdf, z1_pdf[:, None]) * g_res).flatten()
    x2_pdf = (np.matmul(bc_pdf, z2_pdf[:, None]) * g_res).flatten()
    div_x = jensenshannon(x1_pdf*g_res, x2_pdf*g_res)

    div_in_label, div_out_label = r"$div_{in}=%0.3f$" % div_z, r"$div_{out}=%0.3f$" % div_x

    def split_legend(ax):
        # Curves go to the upper-left legend, the divergence text to the
        # upper-right one; add_artist keeps both on the axes.
        handles, labels = ax.get_legend_handles_labels()
        ax.add_artist(ax.legend(handles[:2], labels[:2], handlelength=1.0, loc="upper left"))
        ax.add_artist(ax.legend(handles[2:], labels[2:], handlelength=0, loc="upper right"))

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(2*g_fw, 1*g_fh))

    axis_pdf(axes[0], z1, z1_pdf, max_y=3.8, label="input1",
             title="two random input", titlesize=9, xlabel="z domain", ylabel="pdf", color="orange")
    axis_pdf(axes[0], z2, z2_pdf, max_y=3.8, label="input2", xlabel="z domain", ylabel="pdf", color="green")
    # Empty artist that only carries the divergence text into the legend.
    axes[0].plot([], [], label=div_in_label, color="blue")
    split_legend(axes[0])

    axis_pdf(axes[1], z1, x1_pdf, max_y=3.8, label="output1",
             title="two output", titlesize=9, xlabel="x domain", ylabel="pdf", color="orange")
    axis_pdf(axes[1], z2, x2_pdf, max_y=3.8, label="output2", xlabel="x domain", ylabel="pdf", color="green")
    axes[1].plot([], [], label=div_out_label, color="blue")
    split_legend(axes[1])
    return fig
def fixed_point_init_change(seed, x, x_pdf):
    """Generate a random input pdf and plot it against the converging pdf.

    Args:
        seed: seed for the shape type, component count and pdf of the input.
        x: grid of the converging pdf.
        x_pdf: the converging pdf.

    Returns:
        ``(fig, z, z_pdf, None)``; the figure has 8 axes of which only the
        first is drawn on.
    """
    rg = np.random.RandomState(int(seed))
    # The draw order (shape type, then count) fixes the input of a seed.
    shape_type = rg.randint(0, 4)
    count = rg.choice([1, 2, 8, 12, 16, 16, 16, 16, 16, 20])
    z, z_pdf = init_x_pdf(g_st, g_et, g_num, modal_count=count, shape_type=shape_type, seed=seed)
    div = jensenshannon(z_pdf*g_res, x_pdf*g_res)

    fig, axes = plt.subplots(nrows=1, ncols=8, figsize=(8*g_fw, 1*g_fh))
    first_ax = axes.flatten()[0]
    axis_pdf(first_ax, x, x_pdf, label="converging pdf", color="blue")
    axis_pdf(first_ax, z, z_pdf, title="random input of inverse transform", label="random input", color="green")
    # Empty artist that only carries the divergence text into the legend.
    first_ax.plot([], [], label="div=%0.2f" % div, color="orange")
    first_ax.legend(handlelength=1.2)
    return fig, z, z_pdf, None
def matrix_power(in_mat, n):
    """Raise ``in_mat`` to the ``n``-th power, renormalising columns as it goes.

    The power is built by repeated squaring. After every multiplication
    each column is divided by its sum (plus 1E-9), so columns keep summing
    to about 1.

    Args:
        in_mat: square matrix.
        n: non-negative exponent; 0 yields the identity matrix.

    Returns:
        The column-normalised ``n``-th power.
    """
    if n == 0:
        return np.eye(in_mat.shape[0])

    def normalized_product(left, right):
        # Operands are scaled by 100 and the product by 1/10000, as in the
        # original code (presumably a numerical safeguard; TODO confirm).
        product = np.matmul(left * 100, right * 100) / 10000
        return product / (product.sum(axis=0, keepdims=True) + 1E-9)

    half_power = matrix_power(in_mat, int(n / 2))
    out_mat = normalized_product(half_power, half_power)
    if n % 2 != 0:
        # Odd exponent: one extra factor of the base matrix.
        out_mat = normalized_product(in_mat, out_mat)
    return out_mat
def fixed_point_apply_iterate(x, x_pdf, zt, zt_pdf, xcz_pdf, iterate_num, is_show_pow):
    """Apply the posterior transform repeatedly to a random input pdf.

    At most 23 iterates are kept (3 rows of 8 axes minus the first one).
    When ``iterate_num`` is larger, the first 6 iterates are kept and the
    remaining 17 are picked by ``power_range``. Only the first axes is
    drawn here; the iterates are drawn by ``fixed_plot_part``.

    Args:
        x, x_pdf: grid and pdf of the converging point.
        zt, zt_pdf: grid and pdf of the random input.
        xcz_pdf: posterior matrix; ``xcz_pdf * g_res`` is the transform.
        iterate_num: number of iterations.
        is_show_pow: also create the figure for the matrix powers.

    Returns:
        ``(fig, pow_fig, plot_state)``; ``pow_fig`` is ``None`` unless
        ``is_show_pow``. ``(None, None, None)`` when an input is missing.
    """
    if x_pdf is None or zt_pdf is None or xcz_pdf is None:
        return None, None, None

    col_count, max_row_count = 8, 3
    max_ax_count = max_row_count*col_count - 1
    ax_count = min(iterate_num, max_ax_count)
    row_count = int(np.ceil((ax_count+1)/col_count))
    figsize = (col_count*g_fw, row_count*g_fh)

    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=figsize)
    axes = axes.flatten()
    axis_pdf(axes[0], x, x_pdf, label="converging point", color="blue")
    axis_pdf(axes[0], zt, zt_pdf, title="random input", label="random input", color="green")
    div = jensenshannon(zt_pdf*g_res, x_pdf*g_res)
    # Empty artist that only carries the divergence text into the legend.
    axes[0].plot([], [], label="div=%0.2f" % div, color="green")
    axes[0].legend(handlelength=1.2)

    idxs = np.arange(iterate_num).tolist()
    if iterate_num > max_ax_count:
        idxs = np.arange(6).tolist() + power_range(6, iterate_num-1, max_ax_count-6, 2.5)

    pow_mats, pdfs = [], []
    for ax_idx, idx in enumerate(idxs, start=1):
        pow_idx = idx + 1
        pow_mat = matrix_power(xcz_pdf*g_res, pow_idx)
        pz_pdf = np.matmul(pow_mat, zt_pdf[:, None]).flatten()
        pow_mats.append([pow_mat, pow_idx, ax_idx])
        pdfs.append([x, x_pdf, pz_pdf, pow_idx, ax_idx])

    pow_fig, pow_axes = None, None
    if is_show_pow:
        pow_fig, pow_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=figsize)
        pow_axes = pow_axes.flatten()

    plot_state = (fig, pow_fig, axes, pow_axes, pdfs, pow_mats, g_res)
    return fig, pow_fig, plot_state
def fixed_plot_part(plot_state, pidx):
    """Draw the ``pidx``-th third of the iterates (and of the matrix powers).

    Args:
        plot_state: the state tuple built by ``fixed_point_apply_iterate``.
        pidx: index of the third to draw.

    Returns:
        ``(fig, pow_fig)``; ``pow_fig`` is ``None`` when the matrix powers
        are not shown. ``(None, None)`` when there is no state.
    """
    if plot_state is None:
        return None, None
    fig, pow_fig, axes, pow_axes, pdfs, pow_mats, res = plot_state

    step = int(len(pdfs)/3) + 1
    part = slice(pidx*step, (pidx+1)*step)

    for x, x_pdf, pz_pdf, pow_idx, ax_idx in pdfs[part]:
        ax = axes[ax_idx]
        axis_pdf(ax, x, x_pdf, label="converging pdf", color="blue")
        axis_pdf(ax, x, pz_pdf, title=r"the %dth iterate" % pow_idx, label="transform result", color="green")
        div = jensenshannon(pz_pdf*res, x_pdf*res)
        # Empty artist that only carries the divergence text into the legend.
        ax.plot([], [], label="div=%0.3f" % div, color="green")
        ax.legend(handlelength=1.2)

    if pow_axes is None:
        return fig, None

    # pdfs and pow_mats are built in lockstep, so the grid stored with each
    # pdf entry is the one its matrix power is drawn on.
    for (x, _, _, _, _), (pow_mat, pow_idx, ax_idx) in zip(pdfs[part], pow_mats[part]):
        axis_2d_pdf(pow_axes[ax_idx], x, x, pow_mat, title="power(mat,%d)" % (pow_idx), xlabel="z", ylabel="x")
    return fig, pow_fig
def hijack(seed, x, x_pdf):
    """Override the generated pdf for a few special seed values.

    * Seeds 16002 and 16003: a fixed pdf (shape type 2, seed 100) whose
      central part is zeroed; (-0.5, 0.5) for 16002, (-0.7, 0.7) for 16003.
      The result is not renormalised here.
    * Seeds in ``[17500 + g_st * 100, 17500 + g_et * 100)``: a single
      Gaussian. The seed rounded down to a multiple of 10 sets the mean;
      the last digit sets the standard deviation, ``(digit + 1) * 0.02``.

    Any other seed returns ``x_pdf`` unchanged.
    """
    if seed in [16002, 16003]:
        x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=2, seed=100)
        left, right = (-0.5, 0.5) if seed == 16002 else (-0.7, 0.7)
        hole = np.logical_and(x > left, x < right)
        x_pdf[hole] = 0

    base = 17500
    left, right = int(base+g_st*100), int(base+g_et*100)
    if seed in range(left, right):
        mu = g_st + (seed//10*10 - left)*0.01
        std = (seed%10+1)*0.02
        x_pdf = norm.pdf(x, mu, std)
    return x_pdf