import streamlit as st
import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, transforms, GaussianProcess
import numpy as np
import matplotlib.pyplot as plt
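# Hard-coded benchmark datasets (presumably the classic motorcycle-crash data and an Olympic
# marathon winning-pace series), both commonly used to illustrate input-dependent noise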
mcycle_x = np.array([ 35.2, 27.6, 35.6, 28.2, 57.6, 26.4, 46.6, 55. ,
16.6, 8.2, 32.8, 19.2, 14.6, 42.8, 9.6, 50.6,
2.4, 34.8, 33.4, 6.2, 34.4, 29.4, 25.6, 16. ,
13.8, 15.6, 20.2, 44.4, 3.6, 21.2, 8.8, 24.2,
45. , 33.8, 7.8, 38. , 2.6, 41.6, 20.4, 23.2,
26. , 40. , 28.4, 55.4, 15.4, 53.2, 11.4, 48.8,
25.4, 13.6, 39.2, 40.4, 16.8, 4. , 43. , 15.8,
24.6, 16.4, 28.6, 52. , 30.2, 18.6, 10.2, 32. ,
6.6, 17.6, 19.6, 24. , 16.2, 42.4, 22. , 23.4,
44. , 6.8, 11. , 10.6, 26.2, 39.4, 31.2, 19.4,
21.4, 27.2, 47.8, 35.4, 13.2, 31. , 36.2, 14.8,
3.2, 25. , 21.8, 17.8, 27. , 10. ])
mcycle_y = np.array([ -16. , 4. , 34.8, 12. , 10.7, -65.6, 10.7, -2.7,
-59. , -2.7, 46.9, -123.1, -13.3, 0. , -2.7, 0. ,
0. , 75. , 16. , -2.7, 1.3, -17.4, -26.8, -42.9,
0. , -40.2, -123.1, 0. , 0. , -134. , -1.3, -95.1,
10.7, 45.6, -2.7, 46.9, -1.3, 30.8, -117.9, -123.1,
-5.4, -21.5, -21.5, -2.7, -22.8, -14.7, 0. , -13.3,
-72.3, -2.7, 5.4, -13.3, -71. , -2.7, 14.7, -21.5,
-53.5, -5.4, 46.9, 10.7, 36.2, -112.5, -5.4, 54.9,
-2.7, -37.5, -127.2, -112.5, -21.5, 29.4, -123.1, -128.5,
-1.3, -1.3, -5.4, -2.7, -107.1, -1.3, 8.1, -85.6,
-101.9, -45.6, -26.8, 69.6, -2.7, 75. , -37.5, -2.7,
-2.7, -64.4, -108.4, -99.1, -16. , -2.7])
oly_x = np.array([1896. , 1900. , 1904. , 1908. ,
1912. , 1920. , 1924. , 1928. ,
1932. , 1936. , 1948. , 1952. ,
1956. , 1960. , 1964. , 1968. ,
1972. , 1976. , 1980. , 1984. ,
1988. , 1992. , 1996. , 2000. ,
2004. , 2008. , 2012. ])
oly_y = np.array([ 4.47083333, 4.46472926, 5.22208333, 4.15467867,
3.90331675, 3.56951267, 3.82454477, 3.62483707,
3.59284275, 3.53880792, 3.67010309, 3.39029111,
3.43642612, 3.20583007, 3.13275665, 3.32819844,
3.13583758, 3.0789588 , 3.10581822, 3.06552909,
3.09357349, 3.16111704, 3.14255244, 3.08527867,
3.10265829, 2.99877553, 3.03392977])
st.title("Heteroscedastic Gaussian Processes")
st.markdown(r"""
Gaussian processes generally assume Homoskedastic noise such as:
$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} \stackrel{\text { i.i.d. }}{\sim} \mathcal{N}\left(0, \sigma_{\epsilon}^{2}\right), \quad 1 \leq i \leq n
$$
We can also assume separate distribution of noise over each data point:
$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} {\sim} \mathcal{N}\left(0, \sigma_{\epsilon_i}^{2}\right), \quad 1 \leq i \leq n
$$
However, this may not be straightforward to extend for inference or conditioning. A simple idea can be to learn a non-linear neural network function to model the relationship between inputs and noise:
$$
\sigma_{\epsilon_i} = f(\mathbf{x}_i)
$$
This demo is an attempt to experiment with this idea on several synthetic and real datasets with Heteroskedastic noise.
""")
data = st.selectbox("Data", ["Motorcycle", "Olympic", "Linear", "GPflow"])
if data == "Motorcycle":
    data_x, data_y = mcycle_x, mcycle_y
elif data == "Olympic":
    data_x, data_y = oly_x, oly_y
elif data == "Linear":
    data_x = np.linspace(0, 1, 99)
    data_y = 3 * data_x + 2 + (np.random.randn(data_x.shape[0]) * data_x)
elif data == "GPflow":
    N = 1001
    # Build inputs X
    data_x = np.linspace(0, 4 * np.pi, N)
    # Deterministic functions in place of latent ones
    f1 = np.sin
    f2 = np.cos
    # Use transform = exp to ensure positive-only scale values
    transform = np.exp
    # Compute loc and scale as functions of input X
    loc = f1(data_x)
    scale = transform(f2(data_x))
    # Sample outputs Y from Gaussian likelihood
    data_y = np.random.normal(loc, scale)
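# Standardize inputs and targets to zero mean and unit variance before fitting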
x = (data_x - data_x.mean()) / data_x.std()
y = (data_y - data_y.mean()) / data_y.std()
n_tests = st.number_input(
"Number of test points", min_value=50, max_value=1000, value=100
)
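# Evenly spaced test inputs spanning the (standardized) training range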
t = np.linspace(x.min(), x.max(), n_tests)
noise = 0.01  # small fixed base noise (standard deviation) added to the GP diagonal
# Define a small neural network that maps each input to a log noise variance
fet1 = st.slider("Number of neurons in Layer1", min_value=2, max_value=30, value=15)
fet2 = st.slider("Number of neurons in Layer2", min_value=2, max_value=30, value=10)
class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=fet1)(x)
        x = nn.relu(x)
        x = nn.Dense(features=fet2)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x
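# Homoskedastic baseline: Matern-3/2 GP with a single learned (log) noise parameter shared by
# all data points; returns the negative log likelihood, the predictive mean/variance at the
# test points, and the learned noise variance.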
class BaseGPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))
        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter))
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(log_jitter),
        )
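# Heteroskedastic model: same Matern-3/2 kernel, but the per-point log noise variance is
# predicted from the inputs by the Transformer network.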
class GPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))
        # Predict a per-point log noise variance from the inputs with the `Transformer` network
        transform = Transformer()
        log_jitter = transform(x.reshape(-1, 1)).ravel()
        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter))
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(transform(t[:, None])).ravel(),  # predicted noise variance at the test points
        )
# Define and train the models
base_model = BaseGPLoss()
model = GPLoss()

# Each loss is the corresponding GP negative log likelihood
def base_loss(p):
    return base_model.apply(p, x, y, t)[0]

def loss(p):
    return model.apply(p, x, y, t)[0]

# Initialize both models with independent PRNG keys
seed = np.random.randint(0, 100)
base_key, key = jax.random.split(jax.random.PRNGKey(seed))
base_params = base_model.init(base_key, x, y, t)
params = model.init(key, x, y, t)
n_iters = st.number_input("Number of iterations", min_value=1, max_value=200, value=100)
lr = st.selectbox("Learning rate", [0.1, 0.01, 0.001, 0.0001], 1)
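# One SGD optimizer definition; separate optimizer states are kept for the two models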
tx = optax.sgd(learning_rate=lr)
base_opt_state = tx.init(base_params)
opt_state = tx.init(params)
base_loss_grad_fn = jax.jit(jax.value_and_grad(base_loss))
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
base_losses = []
losses = []
my_bar = st.progress(0)
for i in range(n_iters):
    # One SGD step for each model
    base_loss_val, base_grads = base_loss_grad_fn(base_params)
    loss_val, grads = loss_grad_fn(params)
    base_updates, base_opt_state = tx.update(base_grads, base_opt_state)
    updates, opt_state = tx.update(grads, opt_state)
    base_params = optax.apply_updates(base_params, base_updates)
    params = optax.apply_updates(params, updates)
    losses.append(loss_val)
    base_losses.append(base_loss_val)
    my_bar.progress((i + 1) / n_iters)
# Plot the predictive mean and 95% interval of both models
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
m_list = [base_model, model]
p_list = [base_params, params]
t_list = ['Homoskedastic GP', 'Heteroskedastic GP']
j_list = []
for i in range(2):
    _, (mu, var), jitter = m_list[i].apply(p_list[i], x, y, t)
    var += jitter.ravel()  # add the learned noise variance to the latent predictive variance
    j_list.append(jitter)
    ax[i].plot(x, y, ".k", label="data")
    ax[i].plot(t, mu)
    ax[i].set_title(t_list[i])
    ax[i].fill_between(
        t, mu + 2 * np.sqrt(var), mu - 2 * np.sqrt(var), alpha=0.5, label="95% conf"
    )
ax[1].legend()
col = st.columns(1)[0]
with col:
    st.pyplot(fig)
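# Second figure: the learned noise, constant for the homoskedastic GP (left) and
# input-dependent for the heteroskedastic GP (right)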
col2 = st.columns(1)[0]
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
idx = np.argsort(t)
ax[1].plot(t[idx], j_list[1][idx], label="learned noise")
ax[0].hlines(j_list[0], t[idx].min(), t[idx].max())
ax[0].set_xlabel('x')
ax[1].set_xlabel('x')
ax[0].set_ylabel('learned noise')
ax[1].legend()
with col2:
    st.pyplot(fig)
col3 = st.columns(1)[0]
# Training curves (negative log likelihood) for both models
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
ax[0].plot(base_losses)
ax[0].set_title(t_list[0])
ax[1].plot(losses)
ax[1].set_title(t_list[1])
ax[0].set_xlabel('iterations')
ax[0].set_ylabel('loss')
ax[1].set_xlabel('iterations')
with col3:
    st.pyplot(fig)