import streamlit as st
import jax
import optax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import zeros
from tinygp import kernels, GaussianProcess
import numpy as np
import matplotlib.pyplot as plt
mcycle_x = np.array([
    35.2, 27.6, 35.6, 28.2, 57.6, 26.4, 46.6, 55.0,
    16.6, 8.2, 32.8, 19.2, 14.6, 42.8, 9.6, 50.6,
    2.4, 34.8, 33.4, 6.2, 34.4, 29.4, 25.6, 16.0,
    13.8, 15.6, 20.2, 44.4, 3.6, 21.2, 8.8, 24.2,
    45.0, 33.8, 7.8, 38.0, 2.6, 41.6, 20.4, 23.2,
    26.0, 40.0, 28.4, 55.4, 15.4, 53.2, 11.4, 48.8,
    25.4, 13.6, 39.2, 40.4, 16.8, 4.0, 43.0, 15.8,
    24.6, 16.4, 28.6, 52.0, 30.2, 18.6, 10.2, 32.0,
    6.6, 17.6, 19.6, 24.0, 16.2, 42.4, 22.0, 23.4,
    44.0, 6.8, 11.0, 10.6, 26.2, 39.4, 31.2, 19.4,
    21.4, 27.2, 47.8, 35.4, 13.2, 31.0, 36.2, 14.8,
    3.2, 25.0, 21.8, 17.8, 27.0, 10.0])
mcycle_y = np.array([
    -16.0, 4.0, 34.8, 12.0, 10.7, -65.6, 10.7, -2.7,
    -59.0, -2.7, 46.9, -123.1, -13.3, 0.0, -2.7, 0.0,
    0.0, 75.0, 16.0, -2.7, 1.3, -17.4, -26.8, -42.9,
    0.0, -40.2, -123.1, 0.0, 0.0, -134.0, -1.3, -95.1,
    10.7, 45.6, -2.7, 46.9, -1.3, 30.8, -117.9, -123.1,
    -5.4, -21.5, -21.5, -2.7, -22.8, -14.7, 0.0, -13.3,
    -72.3, -2.7, 5.4, -13.3, -71.0, -2.7, 14.7, -21.5,
    -53.5, -5.4, 46.9, 10.7, 36.2, -112.5, -5.4, 54.9,
    -2.7, -37.5, -127.2, -112.5, -21.5, 29.4, -123.1, -128.5,
    -1.3, -1.3, -5.4, -2.7, -107.1, -1.3, 8.1, -85.6,
    -101.9, -45.6, -26.8, 69.6, -2.7, 75.0, -37.5, -2.7,
    -2.7, -64.4, -108.4, -99.1, -16.0, -2.7])
oly_x = np.array([
    1896.0, 1900.0, 1904.0, 1908.0,
    1912.0, 1920.0, 1924.0, 1928.0,
    1932.0, 1936.0, 1948.0, 1952.0,
    1956.0, 1960.0, 1964.0, 1968.0,
    1972.0, 1976.0, 1980.0, 1984.0,
    1988.0, 1992.0, 1996.0, 2000.0,
    2004.0, 2008.0, 2012.0])
oly_y = np.array([
    4.47083333, 4.46472926, 5.22208333, 4.15467867,
    3.90331675, 3.56951267, 3.82454477, 3.62483707,
    3.59284275, 3.53880792, 3.67010309, 3.39029111,
    3.43642612, 3.20583007, 3.13275665, 3.32819844,
    3.13583758, 3.0789588, 3.10581822, 3.06552909,
    3.09357349, 3.16111704, 3.14255244, 3.08527867,
    3.10265829, 2.99877553, 3.03392977])
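# Hard-coded benchmark datasets:
# - mcycle: Silverman's motorcycle-crash data (time in ms vs. head acceleration),
#   a standard example of input-dependent (heteroscedastic) noise.
# - oly: Olympic marathon winning results by year (values consistent with pace in min/km).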
st.title("Heteroscedastic Gaussian Processes")
st.markdown(r"""
Gaussian processes typically assume homoscedastic noise:
$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} \stackrel{\text{i.i.d.}}{\sim} \mathcal{N}\left(0, \sigma_{\epsilon}^{2}\right), \quad 1 \leq i \leq n
$$
We can instead give each data point its own noise variance:
$$
y_{i}=f\left(\mathbf{x}_{i}\right)+\epsilon_{i}, \quad \epsilon_{i} \sim \mathcal{N}\left(0, \sigma_{\epsilon_{i}}^{2}\right), \quad 1 \leq i \leq n
$$
However, with one free noise parameter per training point it is not straightforward to decide what noise
to use when conditioning on, or predicting at, new inputs. A simple idea is to learn a non-linear
function $g$ (here, a small neural network) that maps inputs to noise levels:
$$
\sigma_{\epsilon_{i}} = g(\mathbf{x}_{i})
$$
This demo experiments with this idea on several synthetic and real datasets with heteroscedastic noise.
""")
data = st.selectbox("Data", ["Motorcycle", "Olympic", "Linear", "GPflow"])
if data == "Motorcycle":
    data_x, data_y = mcycle_x, mcycle_y
elif data == "Olympic":
    data_x, data_y = oly_x, oly_y
elif data == "Linear":
    data_x = np.linspace(0, 1, 99)
    data_y = 3 * data_x + 2 + (np.random.randn(data_x.shape[0]) * data_x)
elif data == "GPflow":
    N = 1001
    # Build inputs X
    data_x = np.linspace(0, 4 * np.pi, N)
    # Deterministic functions in place of latent ones
    f1 = np.sin
    f2 = np.cos
    # Use transform = exp to ensure positive-only scale values
    transform = np.exp
    # Compute loc and scale as functions of input X
    loc = f1(data_x)
    scale = transform(f2(data_x))
    # Sample outputs Y from a Gaussian likelihood
    data_y = np.random.normal(loc, scale)
x = (data_x - data_x.mean()) / data_x.std()
y = (data_y - data_y.mean()) / data_y.std()
n_tests = st.number_input(
    "Number of test points", min_value=50, max_value=1000, value=100
)
t = np.linspace(x.min(), x.max(), n_tests)
noise = 0.01
# x = np.sort(random.uniform(-1, 1, 100))
# y = 2 * (x > 0) - 1 + random.normal(0.0, noise, len(x))
# t = np.linspace(-1.5, 1.5, 500)
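# Note: standardising x and y means the zero-initialised log-hyperparameters below
# (i.e. unit amplitude and length scale) start at a sensible scale for every dataset.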
# Define a small neural network that maps each (scalar) input to an unconstrained
# log-noise value; it is exponentiated later so the noise variance stays positive.
fet1 = st.slider("Number of neurons in Layer1", min_value=2, max_value=30, value=15)
fet2 = st.slider("Number of neurons in Layer2", min_value=2, max_value=30, value=10)

class Transformer(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=fet1)(x)
        x = nn.relu(x)
        x = nn.Dense(features=fet2)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1)(x)
        return x
class BaseGPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        log_jitter = self.param("log_jitter", zeros, ())
        kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))
        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter))
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(log_jitter),
        )
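# BaseGPLoss above learns a single global `log_jitter` (the homoscedastic baseline).
# GPLoss below replaces it with the per-point output of the Transformer network.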
class GPLoss(nn.Module):
    @nn.compact
    def __call__(self, x, y, t):
        # Set up a typical Matern-3/2 kernel
        log_sigma = self.param("log_sigma", zeros, ())
        log_rho = self.param("log_rho", zeros, ())
        base_kernel = jnp.exp(2 * log_sigma) * kernels.Matern32(jnp.exp(log_rho))
        # Use the `Transformer` network from above to predict a per-point log-noise term
        transform = Transformer()
        log_jitter = transform(x.reshape(-1, 1)).ravel()
        kernel = base_kernel
        # Evaluate and return the GP negative log likelihood as usual
        gp = GaussianProcess(kernel, x[:, None], diag=noise**2 + jnp.exp(log_jitter))
        log_prob, gp_cond = gp.condition(y, t[:, None])
        return (
            -log_prob,
            (gp_cond.loc, gp_cond.variance),
            jnp.exp(transform(t[:, None])),
        )
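# Both modules return a triple: the negative log marginal likelihood (used as the training
# loss), the predictive mean/variance on the test grid `t`, and the learned noise variance
# (a scalar for the baseline, one value per test point for the heteroscedastic model).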
# Define and train the models: one loss function per model
def base_loss(params):
    return base_model.apply(params, x, y, t)[0]

def loss(params):
    return model.apply(params, x, y, t)[0]

base_model = BaseGPLoss()
model = GPLoss()
key = jax.random.PRNGKey(np.random.randint(0, 100))
base_key, model_key = jax.random.split(key)
base_params = base_model.init(base_key, x, y, t)
params = model.init(model_key, x, y, t)
n_iters = st.number_input("Number of iterations", min_value=1, max_value=200, value=100)
lr = st.selectbox("Learning rate", [0.1, 0.01, 0.001, 0.0001], 1)
tx = optax.sgd(learning_rate=lr)
base_opt_state = tx.init(base_params)
opt_state = tx.init(params)
base_loss_grad_fn = jax.jit(jax.value_and_grad(base_loss))
loss_grad_fn = jax.jit(jax.value_and_grad(loss))
base_losses = []
losses = []
my_bar = st.progress(0)
for i in range(n_iters):
    base_loss_val, base_grads = base_loss_grad_fn(base_params)
    loss_val, grads = loss_grad_fn(params)
    base_updates, base_opt_state = tx.update(base_grads, base_opt_state)
    updates, opt_state = tx.update(grads, opt_state)
    base_params = optax.apply_updates(base_params, base_updates)
    params = optax.apply_updates(params, updates)
    losses.append(loss_val)
    base_losses.append(base_loss_val)
    my_bar.progress((i + 1) / n_iters)
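# Each iteration above is one full-batch gradient step on each model's exact GP negative
# log marginal likelihood, so the cost per step scales as O(n^3) in the number of points.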
# Plot the predictive distributions of the two models side by side
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
m_list = [base_model, model]
p_list = [base_params, params]
t_list = ["Homoscedastic GP", "Heteroscedastic GP"]
j_list = []
for i in range(2):
    _, (mu, var), jitter = m_list[i].apply(p_list[i], x, y, t)
    var += jitter.ravel()
    j_list.append(jitter)
    # plt.plot(t, 2 * (t > 0) - 1, "k", lw=1, label="truth")
    ax[i].plot(x, y, ".k", label="data")
    ax[i].plot(t, mu)
    ax[i].set_title(t_list[i])
    ax[i].fill_between(
        t, mu + 2 * np.sqrt(var), mu - 2 * np.sqrt(var), alpha=0.5, label="95% conf"
    )
ax[1].legend()
col = st.columns(1)[0]
with col:
    st.pyplot(fig)
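# Second figure: the learned noise term. The homoscedastic baseline has a single constant
# jitter (left panel), while the heteroscedastic model predicts an input-dependent noise
# curve over the test grid (right panel).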
col2 = st.columns(1)[0]
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
idx = np.argsort(t)
ax[1].plot(t[idx], j_list[1][idx], label="learned noise")
ax[0].hlines(j_list[0], t[idx].min(), t[idx].max())
ax[0].set_xlabel("x")
ax[1].set_xlabel("x")
ax[0].set_ylabel("learned noise")
ax[1].legend()
with col2:
    st.pyplot(fig)
col3 = st.columns(1)[0]
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 4))
ax[0].plot(base_losses)
ax[1].plot(losses, label="loss")
ax[0].set_xlabel("iterations")
ax[0].set_ylabel("loss")
ax[1].set_xlabel("iterations")
ax[1].legend()
with col3:
    st.pyplot(fig)