# mnist-bnn-vi / utils.py
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from flax import linen as nn

tfd = tfp.distributions


def split_tree(a, rng_key):
    """Split `rng_key` into one key per leaf of `a`, returned as a pytree
    with the same structure as `a`."""
    treedef = jax.tree_util.tree_structure(a)
    num_vars = len(jax.tree_util.tree_leaves(a))
    all_keys = jax.random.split(rng_key, num=(num_vars + 1))
    return jax.tree_util.tree_unflatten(treedef, all_keys[1:])
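
# Usage sketch (illustrative; the shapes and names below are not taken from
# the training script):
#
#   params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros((2,))}
#   keys = split_tree(params, jax.random.PRNGKey(0))
#   # `keys` has the same structure as `params`: one PRNG key per leaf.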


def sample_fn(rng, vi_params: nn.FrozenDict):
    """Draw one weight sample from the mean-field Gaussian posterior."""
    rng = split_tree(vi_params["mean"], rng)
    params = jax.tree_util.tree_map(
        lambda m, ls, k: tfd.Normal(loc=m, scale=jnp.exp(ls)).sample(seed=k),
        vi_params["mean"],
        vi_params["log_scale"],
        rng,
    )  # type: nn.FrozenDict
    return params
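
# Usage sketch for `sample_fn`: `vi_params` holds two pytrees of identical
# structure, the variational means and the log standard deviations. The tiny
# trees below are placeholders, not the real Flax parameter trees:
#
#   vi_params = {"mean": {"w": jnp.zeros((4,))},
#                "log_scale": {"w": jnp.full((4,), -3.0)}}
#   sampled = sample_fn(jax.random.PRNGKey(0), vi_params)
#   # each element of sampled["w"] is drawn from Normal(0, exp(-3)).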


def get_apply_fn(model: nn.Module):
    """Returns the model forward functions: `apply_fn` draws one posterior
    sample per PRNG key (vmapped over keys); `apply_map_fn` evaluates the
    variational mean weights only."""

    @jax.jit
    @partial(jax.vmap, in_axes=(None, None, 0))
    def apply_fn(vi_params, inputs, rng):
        params = sample_fn(rng, vi_params)
        outputs = model.apply({"params": params}, inputs)
        return outputs

    @jax.jit
    def apply_map_fn(params, inputs, rng):
        del rng  # unused; kept so the signature matches `apply_fn`
        outputs = model.apply({"params": params["mean"]}, inputs)
        return outputs[None, ...]  # add a leading samples axis of size 1

    return apply_fn, apply_map_fn
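
# Usage sketch (the names `model`, `vi_params`, and `inputs` are
# placeholders): stacking PRNG keys yields a Monte Carlo ensemble.
#
#   apply_fn, apply_map_fn = get_apply_fn(model)
#   keys = jax.random.split(jax.random.PRNGKey(0), num=8)
#   logits = apply_fn(vi_params, inputs, keys)          # (8, batch, n_classes)
#   map_logits = apply_map_fn(vi_params, inputs, keys)  # (1, batch, n_classes)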


class MLP(nn.Module):
    n_features: int = 512
    n_layers: int = 3
    n_classes: int = 10
    n_features_mult: int = 1
    bias_init: Callable = nn.initializers.zeros_init()
    act: Callable = nn.relu
    dtype: str = "float32"

    @nn.compact
    def __call__(self, x):
        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            bias_init=self.bias_init,
        )
        x = jnp.reshape(x, (x.shape[0], -1))  # flatten images to vectors
        for _ in range(self.n_layers):
            x = dense(int(self.n_features * self.n_features_mult))(x)
            x = self.act(x)  # was hard-coded to nn.relu, ignoring `act`
        x = dense(self.n_classes)(x)
        return x
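

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the sizes and scales below are arbitrary,
    # not taken from the training script). It wraps the MLP's initial weights
    # into a mean-field posterior with small scales and draws a 4-sample
    # Monte Carlo prediction for a fake MNIST batch.
    model = MLP(n_features=64, n_layers=2)
    rng = jax.random.PRNGKey(0)
    x = jnp.zeros((5, 28, 28, 1))  # fake batch of MNIST-shaped inputs

    init_params = model.init(rng, x)["params"]
    vi_params = {
        "mean": init_params,
        "log_scale": jax.tree_util.tree_map(
            lambda p: jnp.full_like(p, -5.0), init_params
        ),
    }

    apply_fn, apply_map_fn = get_apply_fn(model)
    keys = jax.random.split(rng, num=4)
    print(apply_fn(vi_params, x, keys).shape)      # (4, 5, 10)
    print(apply_map_fn(vi_params, x, keys).shape)  # (1, 5, 10)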