# DPMInteractive / DPMInteractive.py
# blairzheng
# add approx_gauss; add eigenvalue; replace contraction with non_expanding
# ec03973
import os
import random
import time
import numpy as np
import pandas as pd
import gradio as gr
import copy
import scipy
import scipy.signal
from scipy.stats import norm
import matplotlib
import matplotlib.pyplot as plt
from scipy.spatial.distance import jensenshannon
from scipy.optimize import curve_fit
import multiprocessing
from multiprocessing import Pool, Queue, Manager
# Use constrained layout for every figure created through pyplot.
plt.rcParams['figure.constrained_layout.use'] = True
# Number of simultaneously open figures above which matplotlib emits a warning.
plt.rcParams['figure.max_open_warning'] = 10
# Turn matplotlib's interactive mode off.
matplotlib.rcParams['interactive'] = False
# Shared sampling grid: domain start, domain end and number of intervals.
g_st, g_et, g_num = -2.3, 2.3, 460
# Grid resolution, i.e. the width of one interval of the shared domain.
g_res = (g_et-g_st)/g_num
# Width and height of a single plot panel (passed to ``figsize``).
g_fw, g_fh = 3, 3.2
###################################################################################
# common function
###################################################################################
def rs(str_label):
    """Return ``str_label`` with the ``z_{0}`` and ``z_0`` placeholders rewritten to ``x``."""
    renamed = str_label
    # The braced form is replaced first, exactly as in a chained ``replace``.
    for placeholder in ("z_{0}", "z_0"):
        renamed = renamed.replace(placeholder, "x")
    return renamed
def set_axis(axis, x_range, y_range, x_label, y_label):
    """Apply the shared look (limits, labels, tick layout) to a matplotlib axis.

    Parameters
    ----------
    axis : matplotlib axes object to configure.
    x_range, y_range : ``(low, high)`` pairs, or ``None`` to leave the limit untouched.
    x_label, y_label : axis labels, or ``None`` to leave the label untouched.

    Major x ticks are placed at every integer. Minor x ticks are placed every
    0.2, but only when ``x_range`` is given, because their positions are
    derived from it.
    """
    matplotlib.rcParams.update({'font.size': 10, "axes.linewidth": 0.5, "lines.linewidth": 0.7, "figure.dpi": 100})
    if x_range is not None:
        axis.set_xlim(*x_range)
    if y_range is not None:
        axis.set_ylim(*y_range)
    if x_label is not None:
        axis.set_xlabel(x_label)
    if y_label is not None:
        axis.set_ylabel(y_label)
    axis.xaxis.set_major_locator(plt.MultipleLocator(1))
    if x_range is not None:
        # Snap both ends down to a multiple of 0.2 before spreading the minor ticks.
        st, et = x_range[0]//0.2*0.2, x_range[1]//0.2*0.2
        # round() instead of truncation: (et - st)/0.2 can land just below an
        # integer because of floating point error, which would drop one tick.
        count = int(round((et - st)/0.2))
        axis.set_xticks(np.linspace(st, et, count+1), minor=True)
    return
def plot_pdf(x, x_pdf, max_y=3.2, title=None, titlesize=10,
             label=None, xlabel="domain", ylabel="pdf", style="solid", color="blue"):
    """Create a new single-panel figure, draw a 1-D pdf curve on it and return the figure.

    All drawing options are forwarded unchanged to ``axis_pdf``.
    """
    figure = plt.figure(figsize=(g_fw, g_fh))
    panel = figure.add_subplot(111)
    axis_pdf(panel, x, x_pdf, max_y=max_y, title=title, titlesize=titlesize,
             label=label, xlabel=xlabel, ylabel=ylabel, style=style, color=color)
    return figure
def plot_2d_pdf(x, y, pdf, cond_val=None, label=None, title=None, titlesize=10, xlabel="x", ylabel="y"):
    """Create a new single-panel figure, draw a 2-D pdf on it and return the figure.

    All drawing options are forwarded unchanged to ``axis_2d_pdf``.
    """
    figure = plt.figure(figsize=(g_fw, g_fh))
    panel = figure.add_subplot(111)
    axis_2d_pdf(panel, x, y, pdf, cond_val=cond_val, title=title, titlesize=titlesize,
                label=label, xlabel=xlabel, ylabel=ylabel)
    return figure
def axis_pdf(ax, x, x_pdf, max_y=3.2, title=None, titlesize=10,
             label=None, xlabel="domain", ylabel="pdf", style="solid", color="blue"):
    """Draw the curve ``x_pdf`` over ``x`` on ``ax``.

    The x limits are taken from the first and last entries of ``x`` and the
    y limits are ``(0, max_y)``. A title is set only when one is given, and
    the legend is always (re)built.
    """
    x_limits = (x[0], x[-1])
    y_limits = (0, max_y)
    set_axis(ax, x_limits, y_limits, xlabel, ylabel)
    ax.plot(x, x_pdf, label=label, color=color, linestyle=style)
    if title is not None:
        ax.set_title(title, fontsize=titlesize)
    ax.legend()
def axis_2d_pdf(ax, x, y, pdf, cond_val=None, title=None, titlesize=10, label=None, xlabel="x", ylabel="y"):
    """Draw ``pdf`` as a filled contour plot on ``ax``.

    Axis limits are taken from the first and last entries of ``x`` and ``y``.
    When ``cond_val`` is given, an orange vertical line is drawn at that x
    position across the whole y extent.
    """
    x_limits, y_limits = (x[0], x[-1]), (y[0], y[-1])
    set_axis(ax, x_limits, y_limits, xlabel, ylabel)
    ax.contourf(x, y, pdf, label=label)
    if title is not None:
        ax.set_title(title, fontsize=titlesize)
    if cond_val is not None:
        # Vertical marker at x == cond_val.
        ax.plot([cond_val, cond_val], [y[-1], y[0]], color="orange")
    ax.legend()
def add_random_noise(x_pdf, noise_ratio, seed, st, et, num, res):
    """Blend ``x_pdf`` with a seeded random pdf and renormalise the mixture.

    The noise pdf comes from ``init_x_pdf(st, et, num, seed=seed)``. The
    result is ``(1 - noise_ratio) * x_pdf + noise_ratio * noise`` scaled so
    that ``(result * res).sum()`` equals 1.
    """
    noise_pdf = init_x_pdf(st, et, num, seed=seed)[1]
    mixed_pdf = (1-noise_ratio)*x_pdf + noise_ratio*noise_pdf
    return mixed_pdf/(res*mixed_pdf.sum())
def power_range(st, et, num, coeff=2):
    """Return ``num`` strictly increasing integer nodes starting at ``st``.

    The nodes are spaced by a power law: ``linspace(0, 1, num) ** coeff``
    scaled to ``et - st`` and rounded up. Nodes that would repeat or go
    backwards are pushed to ``previous + 1``, so the last node can exceed
    ``et``.
    """
    fractions = np.linspace(0, 1, num=num) ** coeff
    nodes = (st + np.ceil(fractions * (et - st))).astype(int)
    for pos in range(1, len(nodes)):
        lowest_allowed = nodes[pos - 1] + 1
        if nodes[pos] < lowest_allowed:
            nodes[pos] = lowest_allowed
    return list(nodes)
def init_x_pdf(st, et, num, modal_count=16, shape_type=0, seed=200):
    """Build a seeded random Gaussian-mixture pdf sampled on a uniform grid.

    Parameters
    ----------
    st, et : start and end of the domain.
    num : number of intervals; the grid has ``num + 1`` points.
    modal_count : number of Gaussian components.
    shape_type : selects the ranges the component means and standard
        deviations are drawn from (0, 1, 2; any other value uses the
        fallback ranges).
    seed : seed of the ``RandomState`` driving all draws.

    Returns
    -------
    (x, x_pdf) : the grid and the pdf, normalised so that
        ``(x_pdf * res).sum() == 1`` with ``res = (et - st) / num``.
    """
    rng = np.random.RandomState(int(seed))
    res = (et - st) / num

    # shape_type -> ((mean_low, mean_high), (std_low, std_high))
    presets = {
        0: ((-1.1, 1.1), (0.03, 0.20)),
        1: ((-1.5, 1.5), (0.01, 0.09)),
        2: ((-1.5, 1.5), (0.05, 0.35)),
    }
    fallback = ((-0.8, 0.8), (0.05, 0.35))
    (mean_lo, mean_hi), (std_lo, std_hi) = presets.get(shape_type, fallback)

    # The draw order (means, stds, weights) fixes the result for a given seed.
    means = mean_lo + rng.random(modal_count) * (mean_hi - mean_lo)
    stds = std_lo + rng.random(modal_count) * (std_hi - std_lo)
    weights = 1 + rng.random(modal_count) * 10
    weights = weights / weights.sum()

    x = np.linspace(st, et, num + 1, dtype=np.float64)
    x_pdf = np.zeros_like(x, dtype=np.float64)
    for mu, sigma, w in zip(means, stds, weights):
        x_pdf += w * norm.pdf(x, mu, sigma)
    # Small floor so that the pdf is strictly positive everywhere.
    x_pdf += 1E-8
    x_pdf = x_pdf / (x_pdf * res).sum()
    return x, x_pdf
def forward_next_pdf(x, x_pdf, alpha, res):
    """Push a sampled pdf through one step ``y = sqrt(alpha) * x + sqrt(1 - alpha) * eps``.

    Parameters
    ----------
    x : input domain (1-D grid).
    x_pdf : pdf of the continuous input variable sampled on ``x``.
    alpha : step coefficient.
    res : resolution of the grid, used to approximate integrals by sums.

    Returns
    -------
    (y, y_pdf, xy_pdf, xcy_pdf, ycx_pdf) : output domain, output pdf, joint
        pdf (rows indexed by x, columns by y), pdf of x conditioned on y and
        pdf of y conditioned on x. When ``alpha`` is close to 1 the input is
        returned unchanged and the three matrices are ``None``.
    """
    if np.isclose(alpha, 1.0):
        return x, x_pdf, None, None, None

    y = copy.deepcopy(x)
    mean_scale = np.sqrt(alpha)
    noise_std = np.sqrt(1 - alpha)

    joint = np.zeros([*x.shape, *y.shape], dtype=np.float64)
    for row, (x_val, p_x) in enumerate(zip(x, x_pdf)):
        cond = norm.pdf(y, x_val * mean_scale, noise_std)
        # Normalise each conditional on the grid. No epsilon is added here
        # because it would distort the posterior where the pdf is near zero.
        cond = cond/(cond*res).sum()
        joint[row] = p_x * cond

    joint = joint / (joint * res * res).sum()
    y_pdf = (joint * res).sum(axis=0)
    xcy_pdf = joint / (y_pdf[None, :] + 1E-10)
    ycx_pdf = joint / (x_pdf[:, None] + 1E-10)
    return y, y_pdf, joint, xcy_pdf, ycx_pdf
###################################################################################
# transform block function
###################################################################################
def shrink(x, x_pdf, alpha, st, res):
    """Return the pdf of ``y = sqrt(alpha) * x`` sampled on the same grid as ``x``.

    Change of variables: the inverse map is ``x = y / sqrt(alpha)`` and the
    derivative of the forward map is ``sqrt(alpha)``, so
    ``pdf_y(y) = pdf_x(y / sqrt(alpha)) / sqrt(alpha)``.

    Parameters
    ----------
    x : input domain (uniform grid starting at ``st`` with spacing ``res``).
    x_pdf : pdf of the continuous input variable sampled on ``x``.
    alpha : squared scale factor; must be positive.
    st : first value of the grid.
    res : grid spacing.

    Returns
    -------
    numpy.ndarray : pdf of ``y`` on the grid. Grid points whose pre-image
        falls outside the input domain are set to 0.

    Raises
    ------
    ValueError : if ``alpha`` is not positive.
    """
    if not alpha > 0:
        raise ValueError("alpha must be positive, got %r" % (alpha,))

    x_pdf = np.asarray(x_pdf, dtype=np.float64)
    # y's domain is the same as x's.
    y = np.asarray(x, dtype=np.float64)
    sqrt_alpha = np.sqrt(alpha)

    # Index of the grid cell containing the pre-image of every y. floor() is
    # required: truncation toward zero would map pre-images lying up to one
    # cell to the left of the domain onto index 0 instead of rejecting them.
    src_idx = np.floor((y / sqrt_alpha - st) / res).astype(int)
    inside = (src_idx >= 0) & (src_idx < len(x_pdf))

    shrink_pdf = np.zeros_like(x_pdf, dtype=np.float64)
    # Scale with the reciprocal of the derivative of the forward map.
    shrink_pdf[inside] = (1 / sqrt_alpha) * x_pdf[src_idx[inside]]
    return shrink_pdf
def conv(x, x_pdf, alpha, res):
    """Return the pdf of ``x + sqrt(1 - alpha) * eps`` by convolving with a Gaussian.

    Both pdfs are first turned into discrete probabilities (multiplied by
    ``res``), convolved in ``"same"`` mode, and the result is converted back
    to a density by dividing by ``res``.
    """
    noise_std = np.sqrt(1 - alpha)
    kernel_mass = norm.pdf(x, 0, noise_std) * res
    signal_mass = x_pdf * res
    return scipy.signal.convolve(signal_mass, kernel_mass, "same") / res
def shrink_conv(x, x_pdf, shrink_alpha, conv_alpha, st, res):
    """Apply the linear transform (``shrink``) followed by the noise convolution (``conv``).

    Adding independent noise is equivalent to convolving the pdfs, so the
    two sub-steps together model ``sqrt(shrink_alpha) * x + sqrt(1 - conv_alpha) * eps``.
    """
    scaled_pdf = shrink(x, x_pdf, shrink_alpha, st, res)
    return conv(x, scaled_pdf, conv_alpha, res)
def plot_init_pdf(seed, st, et, num):
    """Generate the seeded input pdf and plot it.

    Returns ``(fig, x, x_pdf)``.
    """
    x, x_pdf = init_x_pdf(st, et, num, shape_type=0, seed=seed)
    figure = plot_pdf(x, x_pdf, label="x", title="input variable's pdf")
    first_axis = figure.axes[0]
    first_axis.title.set_size(9)
    return figure, x, x_pdf
def plot_shrink_pdf(x, x_pdf, alpha, st, res):
    """Plot the pdf after the linear transform; returns ``None`` when an input is missing."""
    inputs_missing = x is None or x_pdf is None
    if inputs_missing:
        return None
    scaled_pdf = shrink(x, x_pdf, alpha, st, res)
    return plot_pdf(x, scaled_pdf, label=r"$y=\sqrt{\alpha}x$", title="pdf after linear transform", titlesize=9)
def plot_conv_pdf(x, x_pdf, alpha, res):
    """Plot the pdf after adding Gaussian noise; returns ``None`` when an input is missing."""
    inputs_missing = x is None or x_pdf is None
    if inputs_missing:
        return None
    noisy_pdf = conv(x, x_pdf, alpha, res)
    return plot_pdf(x, noisy_pdf, label=r"$y=x+\sqrt{1-\alpha}\epsilon$", title="pdf after add noises", titlesize=9)
def plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, st, res):
    """Plot the pdf after both sub transforms; returns ``None`` when an input is missing."""
    inputs_missing = x is None or x_pdf is None
    if inputs_missing:
        return None
    combined_pdf = shrink_conv(x, x_pdf, shrink_alpha, conv_alpha, st, res)
    curve_label = r"$y=\sqrt{\alpha_s}x + \sqrt{1-\alpha_e}\epsilon$"
    panel_title = r"pdf after two sub transforms"
    return plot_pdf(x, combined_pdf, label=curve_label, title=panel_title, titlesize=9)
def init_change(seed, shrink_alpha, conv_alpha):
    """Rebuild every figure of the transform block for a new seed.

    Returns ``(init_fig, x, x_pdf, shrink_fig, conv_fig, shrink_conv_fig)``,
    all computed on the module-level grid.
    """
    init_fig, x, x_pdf = plot_init_pdf(seed, g_st, g_et, g_num)
    figures = (
        plot_shrink_pdf(x, x_pdf, shrink_alpha, g_st, g_res),
        plot_conv_pdf(x, x_pdf, conv_alpha, g_res),
        plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res),
    )
    return (init_fig, x, x_pdf, *figures)
def shrink_change(x, x_pdf, shrink_alpha, conv_alpha):
    """Redraw the two figures that depend on ``shrink_alpha``.

    Returns ``(shrink_fig, shrink_conv_fig)``.
    """
    scaled_fig = plot_shrink_pdf(x, x_pdf, shrink_alpha, g_st, g_res)
    combined_fig = plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res)
    return scaled_fig, combined_fig
def conv_change(x, x_pdf, shrink_alpha, conv_alpha):
    """Redraw the two figures that depend on ``conv_alpha``.

    Returns ``(conv_fig, shrink_conv_fig)``.
    """
    noisy_fig = plot_conv_pdf(x, x_pdf, conv_alpha, g_res)
    combined_fig = plot_shrink_conv_pdf(x, x_pdf, shrink_alpha, conv_alpha, g_st, g_res)
    return noisy_fig, combined_fig
###################################################################################
# cond prob block function
###################################################################################
def cond_prob_init_change(seed, alpha, cond_val):
    """Rebuild the whole conditional-probability block for a new seed.

    Returns ``(x, x_pdf, z, xcz_pdf, fig_x, fig_z, fig_zcx, fig_xcz, fig_fix_xcz)``.
    The fourth item is the matrix ``cond_prob_alpha_change`` returns in its
    second position, i.e. the pdf of x conditioned on z.
    """
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=0, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    fig_x = plot_pdf(x, x_pdf, xlabel="x domain", ylabel="pdf", title="input variable's pdf", titlesize=9)
    z, posterior, fig_z, fig_zcx, fig_xcz, fig_fix_xcz = cond_prob_alpha_change(x, x_pdf, alpha, cond_val)
    return x, x_pdf, z, posterior, fig_x, fig_z, fig_zcx, fig_xcz, fig_fix_xcz
def cond_prob_alpha_change(x, x_pdf, alpha, cond_val):
    """Recompute the forward step for a new ``alpha`` and redraw the dependent figures.

    Returns ``(z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz)``.

    NOTE(review): ``forward_next_pdf`` returns ``None`` matrices when
    ``alpha`` is close to 1; that case is not handled here.
    """
    z, z_pdf, _joint, xcz_pdf, zcx_pdf = forward_next_pdf(x, x_pdf, alpha, g_res)

    fig_z = plot_pdf(z, z_pdf, label=r"$z=\sqrt{\alpha}x + \sqrt{1-\alpha}\epsilon$",
                     title=r"output variable's pdf", titlesize=9, xlabel="z domain", ylabel="pdf")
    # The forward conditional is transposed so that the conditioning variable x runs along the horizontal axis.
    fig_zcx = plot_2d_pdf(x, z, zcx_pdf.transpose(), label="$q(z|x)$",
                          title=r"forward conditional pdf", titlesize=9,
                          xlabel="x domain(cond)", ylabel="z domain")
    fig_xcz, fig_fix_xcz = cond_prob_cond_change(x, x_pdf, z, xcz_pdf, alpha, cond_val)
    return z, xcz_pdf, fig_z, fig_zcx, fig_xcz, fig_fix_xcz
def cond_prob_cond_change(x, x_pdf, z, xcz_pdf, alpha, cond_val):
    """Redraw the posterior figures for a new conditioning value ``cond_val``.

    Returns ``(fig_xcz, fig_fix_xcz)``: the 2-D backward conditional pdf with
    a marker at ``cond_val``, and a panel comparing the prior, the posterior
    column selected by ``cond_val`` and the Gaussian
    ``N(cond_val / sqrt(alpha), (1 - alpha) / alpha)``.
    """
    # Column of the posterior matrix that corresponds to z == cond_val.
    column = int((cond_val - g_st) / g_res)
    posterior_slice = xcz_pdf[:, column]

    fig_xcz = plot_2d_pdf(x, z, xcz_pdf, cond_val, label="$q(x|z)$",
                          title="backward conditional pdf", xlabel="z domain(cond)", ylabel="x domain")
    fig_xcz.axes[0].title.set_size(9)

    gauss_mean = cond_val / np.sqrt(alpha)
    gauss_std = np.sqrt((1 - alpha) / alpha)
    gauss = norm.pdf(x, gauss_mean, gauss_std)

    fig_fix_xcz = plt.figure(figsize=(g_fw, g_fh))
    panel = fig_fix_xcz.add_subplot(111)
    axis_pdf(panel, x, gauss, max_y=5, label="$gauss$", style="dashed", color="green")
    axis_pdf(panel, x, x_pdf, max_y=5, label="$q(x)$", style="dashed", color="blue")
    axis_pdf(panel, x, posterior_slice, max_y=5, label="$q(x|z=%s)$" % cond_val,
             title="posterior with fixed condition", titlesize=9, xlabel="x domain", color="orange")

    # Two separate legends: the dashed reference curves on the left, the posterior on the right.
    handles, labels = panel.get_legend_handles_labels()
    left_legend = panel.legend(handles[:2], labels[:2], handlelength=0.8, loc="upper left")
    panel.add_artist(left_legend)
    right_legend = panel.legend(handles[2:], labels[2:], handlelength=0.8, loc="upper right")
    panel.add_artist(right_legend)
    return fig_xcz, fig_fix_xcz
###################################################################################
# forward block function
###################################################################################
def plot_first_pdf(x, x_pdf, ax):
    """Draw the original variable's pdf on ``ax`` with the forward-pass title and legend."""
    curve_label = r"forward q(x)"
    axis_pdf(ax, x, x_pdf, title=r"origin var pdf", label=curve_label,
             xlabel=rs(r"x domain"), ylabel="pdf", color="blue")
    ax.legend(handlelength=1.2, labels=[curve_label])
def forward_init_change(seed):
    """Generate the input pdf for the forward block and draw it in the first of eight panels.

    Returns ``(x, x_pdf, fig, None)``.
    """
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    fig, panels = plt.subplots(nrows=1, ncols=8, figsize=(8 * g_fw, 1 * g_fh))
    first_panel = panels.flatten()[0]
    plot_first_pdf(x, x_pdf, first_panel)
    return x, x_pdf, fig, None
def plot_forward_pdf(axes, seq_info, color, pidx=-1):
    """Draw the forward pdfs of ``seq_info``, each on the panel given by its step index.

    With ``pidx >= 0`` only the ``pidx``-th third of the sequence is drawn.
    On the panel of the last step of the full sequence the standard normal
    pdf is overlaid in green.
    """
    total = len(seq_info)
    last_step = total - 1
    if pidx >= 0:
        chunk = int(total/3+1)
        seq_info = seq_info[pidx*chunk:(pidx+1)*chunk]

    for _, _, nz, nz_pdf, _, _, nidx, alpha in seq_info:
        if nidx == 0:
            title, label = "origin var pdf", r"forward $q(x)$"
        else:
            title = rs(r"$q(z_{%d})\ \alpha=%0.3f$"%(nidx, alpha))
            label = r"forward $q(z_{%d})$"%nidx
        xlabel = rs(r"$z_{%d}\ domain$"%nidx)

        panel = axes[nidx]
        axis_pdf(panel, nz, nz_pdf, title=title, label=label, xlabel=xlabel, ylabel="pdf", color=color)
        panel.legend(handlelength=1.2)
        if nidx == last_step:
            panel.plot(nz, norm.pdf(nz, 0, 1), label=r"$\mathcal{N}\/(0, 1)$", color="green")
            panel.legend()
def plot_backward_pdf(axes, fore_seq_info, back_seq_info, label_prefix, res, color, pidx=-1):
    """Draw the backward pdfs and label each with its Jensen-Shannon distance to the forward pdf.

    With ``pidx >= 0`` only one third of the sequence is drawn, counted from
    the end (``pidx == 0`` selects the last third).
    """
    total = len(fore_seq_info)
    if pidx >= 0:
        chunk = int(total/3 + 1)
        part = 2 - pidx  # thirds are visited in reverse order
        window = slice(part*chunk, (part+1)*chunk)
        fore_seq_info, back_seq_info = fore_seq_info[window], back_seq_info[window]

    for fore_info, back_info in zip(fore_seq_info, back_seq_info):
        nz, fore_nz_pdf, nidx = fore_info[2], fore_info[3], fore_info[6]
        back_nz_pdf = back_info[3]
        div = jensenshannon(back_nz_pdf*res, fore_nz_pdf*res)
        # The end point of the sequence gets a specific name.
        if nidx == total-1:
            name = r"$\mathcal{N}\/(0,1)$"
        else:
            name = "revert"
        label = rs(label_prefix + name + " div=%0.2f"%div)
        xlabel = rs(r"$z_{%d}\ domain$" % nidx)
        panel = axes[nidx]
        axis_pdf(panel, nz, back_nz_pdf, label=label, xlabel=xlabel, ylabel="pdf", color=color)
        panel.legend(handlelength=1.2)
def plot_backward_cond_pdf(axes, seq_info, reverse=True, pidx=-1):
    """Draw the backward conditional pdf of every step that has one.

    With ``pidx >= 0`` only one third of the sequence is drawn; ``reverse``
    selects whether thirds are counted from the end or from the start.
    """
    if pidx >= 0:
        chunk = int(len(seq_info)/3+1)
        part = (2 - pidx) if reverse else pidx
        seq_info = seq_info[part*chunk:(part+1)*chunk]

    for cz, _cz_pdf, nz, _nz_pdf, bc_pdf, cidx, nidx, alpha in seq_info:
        if bc_pdf is None:
            continue
        title = rs(r"$q(z_{%d}|z_{%d})\ \alpha=%0.3f$" % (cidx, nidx, alpha))
        xlabel = rs(r"$z_{%d}$" % nidx)
        ylabel = rs(r"$z_{%d}$" % cidx)
        axis_2d_pdf(axes[nidx], cz, nz, bc_pdf, title=title, xlabel=xlabel, ylabel=ylabel)
def get_back_seq_info(ez, ez_pdf, fore_seq_info, res):
    """Propagate an end-point pdf backwards through the stored posterior matrices.

    Works on a deep copy of ``fore_seq_info`` and walks it from the last
    entry to the first. For an entry with a posterior matrix (slot 4) the
    previous pdf is ``posterior @ current_pdf * res``; slots 0-3 are
    overwritten with ``(domain, previous_pdf, domain, current_pdf)``. For an
    entry without a posterior only slots 2-3 receive the current domain and pdf.

    Returns the modified copy.
    """
    back_seq_info = copy.deepcopy(fore_seq_info)
    nz, nz_pdf = ez, ez_pdf
    for entry in reversed(back_seq_info):
        posterior = entry[4]
        if posterior is None:
            entry[2:4] = nz, nz_pdf
            continue
        cz_pdf = (np.matmul(posterior, nz_pdf[:, None]) * res).flatten()
        # The domain is shared by every step, so it is reused for both slots.
        entry[:4] = nz, cz_pdf, nz, nz_pdf
        nz_pdf = cz_pdf
    return back_seq_info
def forward_seq_apply(x, x_pdf, st_alpha, et_alpha, step):
    """Run ``step`` forward transforms and prepare the figures for plotting them.

    The alphas are spaced linearly from ``st_alpha`` to ``et_alpha``.

    Returns
    -------
    (seq_info, forward_plot_state)
        ``seq_info`` is a list with one entry per variable, starting with the
        input: ``[cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha]``
        (current domain/pdf, next domain/pdf, posterior matrix, current and
        next index, alpha). ``forward_plot_state`` bundles two empty figure
        grids, their flattened axes, ``seq_info``, the grid resolution and
        the curve colour for ``forward_plot_part``.
        Both are ``None`` when ``x_pdf`` is ``None``.
    """
    if x_pdf is None:
        # Same arity as the regular return value (the guard used to return four items).
        return None, None

    alphas = np.linspace(st_alpha, et_alpha, step)

    col_count = 8
    row_count = int(np.ceil((step+1)/8))
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    axes = axes.flatten()
    pos_fig, pos_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    pos_axes = pos_axes.flatten()

    # The first entry describes the input itself: it has no predecessor and no posterior.
    seq_info = [[None, None, x, x_pdf, None, -1, 0, None]]
    cz, cz_pdf = x, x_pdf
    for ii, alpha in enumerate(alphas):
        forward_info = forward_next_pdf(cz, cz_pdf, alpha, g_res)
        nz, nz_pdf, joint_pdf, bc_pdf, fc_pdf = forward_info
        cidx, nidx = ii, ii+1
        seq_info.append([cz, cz_pdf, nz, nz_pdf, bc_pdf, cidx, nidx, alpha])
        cz, cz_pdf = nz, nz_pdf

    forward_plot_state = fig, axes, pos_fig, pos_axes, seq_info, g_res, "blue"
    return seq_info, forward_plot_state
def forward_plot_part(plot_state, pidx):
    """Draw part ``pidx`` of the forward pdfs and of the posterior matrices.

    Returns ``(fig, back_fig)``, or ``(None, None)`` without a plot state.
    """
    if plot_state is None:
        return None, None
    pdf_fig, pdf_axes, cond_fig, cond_axes, seq_info, _res, color = plot_state
    plot_forward_pdf(pdf_axes, seq_info, color, pidx)
    plot_backward_cond_pdf(cond_axes, seq_info, False, pidx)
    return pdf_fig, cond_fig
def backward_seq_apply(fore_seq_info, is_forward_pdf, is_backward_pdf, noise_seed, noise_ratio):
    """Prepare the backward pass: compute the restored pdf sequences and the figure to draw them on.

    The backward pass starts from a standard normal pdf sampled on the grid
    and, when ``noise_ratio > 0``, additionally from a noisy version of it.

    Returns ``(fig, plot_state)`` where ``plot_state`` is consumed by
    ``backward_plot_part``; ``(None, None)`` when ``fore_seq_info`` is ``None``.
    """
    global g_st, g_et, g_num, g_res
    if fore_seq_info is None:
        return None, None
    col_count = 8
    step = len(fore_seq_info)-1
    row_count = int(np.ceil((step+1)/8))
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    axes = axes.flatten()
    x, x_pdf = fore_seq_info[0][2:4]
    if is_forward_pdf:
        plot_forward_pdf(axes, fore_seq_info, "blue")
    # End point of the backward pass: standard normal pdf on the last domain.
    ez, ez_pdf = fore_seq_info[-1][2], norm.pdf(x, 0, 1)
    std_back_seq_info, noise_back_seq_info = None, None
    if is_backward_pdf:
        std_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
    # NOTE(review): the original indentation was lost; this block is assumed to be
    # independent of ``is_backward_pdf`` rather than nested in it - confirm.
    if noise_ratio > 0:
        ez_pdf = add_random_noise(ez_pdf, noise_ratio, noise_seed, g_st, g_et, g_num, g_res)
        noise_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
    plot_state = fig, axes, fore_seq_info, std_back_seq_info, noise_back_seq_info, g_res
    return fig, plot_state
def backward_plot_part(plot_state, pidx=-1):
    """Draw part ``pidx`` of the available backward sequences and return the figure.

    The standard sequence is drawn in green and the noisy one in red; a
    sequence that is ``None`` is skipped. Returns ``None`` without a plot state.
    """
    if plot_state is None:
        return None
    fig, axes, fore_seq_info, std_back_seq_info, noise_back_seq_info, res = plot_state
    variants = (
        (std_back_seq_info, "std ", "green"),
        (noise_back_seq_info, "noise ", "red"),
    )
    for back_seq_info, prefix, color in variants:
        if back_seq_info is not None:
            plot_backward_pdf(axes, fore_seq_info, back_seq_info, prefix, res, color, pidx)
    return fig
def fit_pos_with_gauss(idx, x, bc_pdf, queue):
    """Replace every column of ``bc_pdf`` by the Gaussian pdf fitted to it.

    Each column is fitted over ``x`` with ``curve_fit`` starting from mean 0
    and standard deviation 1. ``bc_pdf`` is modified in place and returned.
    ``idx`` and ``queue`` are not used.
    """
    # Rows of the transpose are views on the columns, so writes reach bc_pdf.
    for column in bc_pdf.T:
        (mu, std), _ = curve_fit(norm.pdf, x, column, p0=[0, 1])
        column[:] = norm.pdf(x, mu, std)
    return bc_pdf
def seq_fit_pos_with_gauss(fore_seq_info):
    """Return a deep copy of ``fore_seq_info`` whose posterior matrices are Gaussian fits.

    Entries without a posterior matrix (slot 4 is ``None``) are left as they
    are. The fit is done over the domain stored in slot 0 of each entry.
    """
    fit_seq_info = copy.deepcopy(fore_seq_info)
    for position, entry in enumerate(fit_seq_info):
        domain, _, _, _, posterior = entry[:5]
        if posterior is None:
            continue
        entry[4] = fit_pos_with_gauss(position, domain, posterior, None)
    return fit_seq_info
def fit_and_backward_apply(fore_seq_info, is_forward_pdf, is_backward_pdf):
    """Prepare the backward pass that uses Gaussian-fitted posteriors.

    Returns ``(fig, pos_fig, fit_plot_state)`` where ``fit_plot_state`` is
    consumed by ``fit_plot_part``; three ``None`` values when
    ``fore_seq_info`` is ``None``.
    """
    global g_st, g_et, g_num, g_res
    if fore_seq_info is None:
        return None, None, None
    col_count = 8
    step = len(fore_seq_info)-1
    row_count = int(np.ceil((step+1) / 8))
    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    axes = axes.flatten()
    pos_fig, pos_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=(col_count*g_fw, row_count*g_fh))
    pos_axes = pos_axes.flatten()
    x, x_pdf = fore_seq_info[0][2:4]
    if is_forward_pdf:
        plot_forward_pdf(axes, fore_seq_info, "blue")
    # End point of the backward pass: standard normal pdf on the last domain.
    ez, ez_pdf = fore_seq_info[-1][2], norm.pdf(x, 0, 1)
    if is_backward_pdf:
        std_back_seq_info = get_back_seq_info(ez, ez_pdf, fore_seq_info, g_res)
        plot_backward_pdf(axes, fore_seq_info, std_back_seq_info, "std ", g_res, "green")
    # NOTE(review): the original indentation was lost; the fit is assumed to run
    # regardless of ``is_backward_pdf`` because ``fit_back_seq_info`` is needed
    # unconditionally below - confirm.
    fit_back_seq_info = seq_fit_pos_with_gauss(fore_seq_info)
    fit_back_seq_info = get_back_seq_info(ez, ez_pdf, fit_back_seq_info, g_res)
    fit_plot_state = fig, axes, pos_fig, pos_axes, fore_seq_info, fit_back_seq_info, g_res
    return fig, pos_fig, fit_plot_state
def fit_plot_part(plot_state, is_show_pos, pidx=-1):
    """Draw part ``pidx`` of the fitted backward sequence, optionally with the fitted posteriors.

    Returns ``(fig, back_fig)``, or ``(None, None)`` without a plot state.
    """
    if plot_state is None:
        return None, None
    pdf_fig, pdf_axes, cond_fig, cond_axes, fore_seq_info, fit_back_seq_info, res = plot_state
    plot_backward_pdf(pdf_axes, fore_seq_info, fit_back_seq_info, "fit ", res, "orange", pidx)
    if is_show_pos:
        plot_backward_cond_pdf(cond_axes, fit_back_seq_info, True, pidx)
    return pdf_fig, cond_fig
###################################################################################
# contraction block function
###################################################################################
def contraction_init_change(seed, alpha, two_inputs_seed):
    """Rebuild the whole contraction block for a new seed.

    The shape type of the input pdf is drawn from a ``RandomState`` seeded
    with ``seed``. The pdf is renormalised after ``hijack``.

    Returns ``(fig, x, x_pdf, fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2)``.
    """
    shape_type = np.random.RandomState(int(seed)).randint(0, 4)
    x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=shape_type, seed=seed)
    x_pdf = hijack(seed, x, x_pdf)
    # Normalise to 1 on the grid.
    x_pdf = x_pdf / (x_pdf * g_res).sum()

    fig = plot_pdf(x, x_pdf, title="input variable pdf", titlesize=9)
    fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2 = contraction_alpha_change(x, x_pdf, alpha, two_inputs_seed)
    return fig, x, x_pdf, fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2
def contraction_alpha_change(x, x_pdf, alpha, two_inputs_seed):
    """Recompute the posterior for a new ``alpha`` and redraw the dependent figures.

    Returns
    -------
    (fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2)
        ``lambda_2`` is the second largest eigenvalue modulus of the
        column-normalised posterior matrix.

    NOTE(review): ``forward_next_pdf`` returns ``None`` matrices when
    ``alpha`` is close to 1; that case is not handled here.
    """
    global g_st, g_et, g_num, g_res
    forward_info = forward_next_pdf(x, x_pdf, alpha, g_res)
    z, z_pdf, xz_pdf, xcz_pdf, zcx_pdf = forward_info
    label = r"$z=\sqrt{\alpha}x + \sqrt{1-\alpha}\epsilon$"
    z_title = r"output variable pdf"
    xcz_title = r"posterior pdf"
    fig_z = plot_pdf(z, z_pdf, label=label, title=z_title, titlesize=9, xlabel="z domain", ylabel="pdf")
    fig_xcz = plot_2d_pdf(x, z, xcz_pdf, None, label="$q(x|z)$",
                          title=xcz_title, titlesize=9, xlabel="z domain(cond)", ylabel="x domain")
    # Turn every column into a probability vector. The epsilon (same guard as
    # in matrix_power) keeps all-zero columns from becoming NaN, which would
    # make eigvals raise.
    xcz = xcz_pdf/(xcz_pdf.sum(axis=0, keepdims=True) + 1E-9)
    evals = np.linalg.eigvals(xcz)
    evals = sorted(np.absolute(evals), reverse=True)
    lambda_2 = evals[1]
    fig_inp_out = contraction_apply(x, x_pdf, xcz_pdf, two_inputs_seed)
    return fig_xcz, fig_z, z, xcz_pdf, fig_inp_out, lambda_2
def change_two_inputs_seed():
    """Return a random integer seed in ``[0, 1000000]`` (both ends included)."""
    # The bound must be an int: random.randint rejects floats such as 1E6 on
    # Python 3.12+ and only accepted them with a deprecation warning before.
    seed = random.randint(0, 10**6)
    return seed
def contraction_apply(x, x_pdf, bc_pdf, seed):
    """Map two random input pdfs through the posterior ``bc_pdf`` and plot inputs and outputs.

    Both panels show the Jensen-Shannon distance between the two curves
    (``div_in`` before and ``div_out`` after the transform). Returns the figure.
    """
    global g_st, g_et, g_num, g_res
    rg = np.random.RandomState(int(seed))
    # Candidate component counts; 16 is repeated, so it is the most likely choice.
    modals = [1, 2, 8, 12, 16, 16, 16, 16, 16, 20]
    count1, count2 = rg.choice(modals), rg.choice(modals)
    seed1, seed2 = rg.randint(0, 1E6, 2)
    shape1, shape2 = rg.randint(0, 4, 2)
    z1, z1_pdf = init_x_pdf(g_st, g_et, g_num, count1, shape_type=shape1, seed=seed1)
    z2, z2_pdf = init_x_pdf(g_st, g_et, g_num, count2, shape_type=shape2, seed=seed2)
    div_z = jensenshannon(z1_pdf*g_res, z2_pdf*g_res)
    # Apply the posterior: integrate bc_pdf against each input pdf on the grid.
    x1_pdf = np.matmul(bc_pdf, z1_pdf[:, None]) * g_res
    x2_pdf = np.matmul(bc_pdf, z2_pdf[:, None]) * g_res
    x1_pdf, x2_pdf = x1_pdf.flatten(), x2_pdf.flatten()
    div_x = jensenshannon(x1_pdf*g_res, x2_pdf*g_res)
    div_in_label, div_out_label = r"$div_{in}=%0.3f$"%div_z, r"$div_{out}=%0.3f$"%div_x
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(2*g_fw, 1*g_fh))
    # Left panel: the two inputs.
    axis_pdf(axes[0], z1, z1_pdf, max_y=3.8, label="input1",
             title="two random input", titlesize=9, xlabel="z domain", ylabel="pdf", color="orange")
    axis_pdf(axes[0], z2, z2_pdf, max_y=3.8, label="input2", xlabel="z domain", ylabel="pdf", color="green")
    # Empty plot whose only purpose is to carry the divergence text into the legend.
    axes[0].plot([], [], label=div_in_label, color="blue")
    handles, labels = axes[0].get_legend_handles_labels()
    # Two legends: curve names on the left, divergence value on the right.
    axes[0].add_artist(axes[0].legend(handles[:2], labels[:2], handlelength=1.0, loc="upper left"))
    axes[0].add_artist(axes[0].legend(handles[2:], labels[2:], handlelength=0, loc="upper right"))
    # Right panel: the two outputs, laid out the same way.
    axis_pdf(axes[1], z1, x1_pdf, max_y=3.8, label="output1",
             title="two output", titlesize=9, xlabel="x domain", ylabel="pdf", color="orange")
    axis_pdf(axes[1], z2, x2_pdf, max_y=3.8, label="output2", xlabel="x domain", ylabel="pdf", color="green")
    axes[1].plot([], [], label=div_out_label, color="blue")
    handles, labels = axes[1].get_legend_handles_labels()
    axes[1].add_artist(axes[1].legend(handles[:2], labels[:2], handlelength=1.0, loc="upper left"))
    axes[1].add_artist(axes[1].legend(handles[2:], labels[2:], handlelength=0, loc="upper right"))
    return fig
def fixed_point_init_change(seed, x, x_pdf):
    """Generate the random starting pdf of the fixed-point iteration and plot it against ``x_pdf``.

    Returns ``(fig, z, z_pdf, None)``.
    """
    rng = np.random.RandomState(int(seed))
    # Draw order (shape type first, then component count) fixes the result for a seed.
    shape_type = rng.randint(0, 4)
    modal_count = rng.choice([1, 2, 8, 12, 16, 16, 16, 16, 16, 20])
    z, z_pdf = init_x_pdf(g_st, g_et, g_num, modal_count=modal_count, shape_type=shape_type, seed=seed)
    div = jensenshannon(z_pdf*g_res, x_pdf*g_res)

    fig, panels = plt.subplots(nrows=1, ncols=8, figsize=(8*g_fw, 1*g_fh))
    first_panel = panels.flatten()[0]
    axis_pdf(first_panel, x, x_pdf, label="converging pdf", color="blue")
    axis_pdf(first_panel, z, z_pdf, title="random input of inverse transform", label="random input", color="green")
    # Empty plot that only carries the divergence text into the legend.
    first_panel.plot([], [], label="div=%0.2f"%div, color="orange")
    first_panel.legend(handlelength=1.2)
    return fig, z, z_pdf, None
def matrix_power(in_mat, n):
    """Return ``in_mat`` raised to the ``n``-th power with column renormalisation.

    The power is built by repeated squaring. After every product the columns
    are rescaled to sum to (almost) 1; the epsilon in the denominator guards
    against all-zero columns. ``n == 0`` yields the identity matrix.
    """
    def _normalised_product(left, right):
        # Operands are scaled by 100 before the product and the result is scaled back.
        product = np.matmul(left * 100, right * 100) / 10000
        return product / (product.sum(axis=0, keepdims=True) + 1E-9)

    if n == 0:
        return np.eye(in_mat.shape[0])

    half_power = matrix_power(in_mat, int(n / 2))
    squared = _normalised_product(half_power, half_power)
    if n % 2 == 0:
        return squared
    return _normalised_product(in_mat, squared)
def fixed_point_apply_iterate(x, x_pdf, zt, zt_pdf, xcz_pdf, iterate_num, is_show_pow):
    """Apply the posterior transform repeatedly to a random pdf and prepare the plots.

    At most 23 iterations are shown. When ``iterate_num`` is larger, the
    first six are kept and the rest are picked by ``power_range``.

    Returns ``(fig, pow_fig, plot_state)``; ``pow_fig`` is ``None`` unless
    ``is_show_pow`` is set. Three ``None`` values are returned when one of
    the pdf inputs is missing.
    """
    if x_pdf is None or zt_pdf is None or xcz_pdf is None:
        return None, None, None

    col_count, max_row_count = 8, 3
    max_ax_count = max_row_count*col_count - 1  # the first panel is reserved for the input
    ax_count = min(iterate_num, max_ax_count)
    row_count = int(np.ceil((ax_count+1)/col_count))
    grid_size = (col_count*g_fw, row_count*g_fh)

    fig, axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=grid_size)
    axes = axes.flatten()
    first_panel = axes[0]
    axis_pdf(first_panel, x, x_pdf, label="converging point", color="blue")
    axis_pdf(first_panel, zt, zt_pdf, title="random input", label="random input", color="green")
    div = jensenshannon(zt_pdf*g_res, x_pdf*g_res)
    first_panel.plot([], [], label="div=%0.2f"%div, color="green")
    first_panel.legend(handlelength=1.2)

    if iterate_num > max_ax_count:
        idxs = np.arange(6).tolist() + power_range(6, iterate_num-1, max_ax_count-6, 2.5)
    else:
        idxs = np.arange(iterate_num).tolist()

    step_matrix = xcz_pdf*g_res
    pow_mats, pdfs = [], []
    for ax_idx, idx in enumerate(idxs, start=1):
        pow_idx = idx + 1
        pow_mat = matrix_power(step_matrix, pow_idx)
        pz_pdf = np.matmul(pow_mat, zt_pdf[:, None]).flatten()
        pow_mats.append([pow_mat, pow_idx, ax_idx])
        pdfs.append([x, x_pdf, pz_pdf, pow_idx, ax_idx])

    pow_fig, pow_axes = None, None
    if is_show_pow:
        pow_fig, pow_axes = plt.subplots(nrows=row_count, ncols=col_count, figsize=grid_size)
        pow_axes = pow_axes.flatten()

    plot_state = (fig, pow_fig, axes, pow_axes, pdfs, pow_mats, g_res)
    return fig, pow_fig, plot_state
def fixed_plot_part(plot_state, pidx):
    """Draw part ``pidx`` (one third) of the iteration results, plus the matrix powers if enabled.

    Returns ``(fig, pow_fig)``; ``pow_fig`` is ``None`` when the power
    matrices are not shown. ``(None, None)`` is returned without a plot state.
    """
    if plot_state is None:
        return None, None
    fig, pow_fig, axes, pow_axes, pdfs, pow_mats, res = plot_state
    chunk = int(len(pdfs)/3) + 1
    window = slice(pidx*chunk, (pidx+1)*chunk)

    for x, x_pdf, pz_pdf, pow_idx, ax_idx in pdfs[window]:
        panel = axes[ax_idx]
        axis_pdf(panel, x, x_pdf, label="converging pdf", color="blue")
        axis_pdf(panel, x, pz_pdf, title=r"the %dth iterate" % pow_idx, label="transform result", color="green")
        div = jensenshannon(pz_pdf*res, x_pdf*res)
        # Empty plot that only carries the divergence text into the legend.
        panel.plot([], [], label="div=%0.3f"%div, color="green")
        panel.legend(handlelength=1.2)

    if pow_axes is None:
        return fig, None

    # ``x`` is the domain left over from the loop above.
    for pow_mat, pow_idx, ax_idx in pow_mats[window]:
        axis_2d_pdf(pow_axes[ax_idx], x, x, pow_mat, title="power(mat,%d)"%(pow_idx), xlabel="z", ylabel="x")
    return fig, pow_fig
def hijack(seed, x, x_pdf):
    """Return a special-case pdf for reserved seeds, otherwise ``x_pdf`` unchanged.

    * Seeds 16002 and 16003 give a fixed mixture (shape type 2, seed 100)
      with the pdf set to zero on ``(-0.5, 0.5)`` or ``(-0.7, 0.7)``.
    * Seeds in the window starting near ``17500 + g_st * 100`` give a single
      Gaussian whose mean is encoded in the tens and whose standard
      deviation is encoded in the last digit.
    """
    if seed in [16002, 16003]:
        x, x_pdf = init_x_pdf(g_st, g_et, g_num, shape_type=2, seed=100)
        if seed == 16002:
            hole_lo, hole_hi = -0.5, 0.5
        else:
            hole_lo, hole_hi = -0.7, 0.7
        x_pdf[np.logical_and(x > hole_lo, x < hole_hi)] = 0

    base = 17500
    first_seed = int(base+g_st*100)
    stop_seed = int(base+g_et*100)
    if seed in range(first_seed, stop_seed):
        mu = g_st + (seed//10*10 - first_seed)*0.01
        std = (seed%10+1)*0.02
        x_pdf = norm.pdf(x, mu, std)
    return x_pdf