from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from flax import linen as nn
from flax.core import FrozenDict

tfd = tfp.distributions


def split_tree(a, rng_key):
    """Splits `rng_key` into one independent key per leaf of pytree `a`."""
    treedef = jax.tree_util.tree_structure(a)
    num_vars = len(jax.tree_util.tree_leaves(a))
    all_keys = jax.random.split(rng_key, num=(num_vars + 1))
    return jax.tree_util.tree_unflatten(treedef, all_keys[1:])


def sample_fn(rng, vi_params: FrozenDict):
    """Draws one set of weights from the mean-field Gaussian posterior."""
    rng = split_tree(vi_params["mean"], rng)
    params = jax.tree_util.tree_map(
        lambda m, ls, k: tfd.Normal(loc=m, scale=jnp.exp(ls)).sample(seed=k),
        vi_params["mean"],
        vi_params["log_scale"],
        rng,
    )
    return params


def get_apply_fn(model: nn.Module):
    """Returns forward functions for posterior samples and for the mean weights."""

    @jax.jit
    @partial(jax.vmap, in_axes=(None, None, 0))
    def apply_fn(vi_params, inputs, rng):
        # One posterior sample per rng key; vmapping over keys yields an
        # ensemble of predictions with shape (n_samples, batch, n_classes).
        params = sample_fn(rng, vi_params)
        outputs = model.apply({"params": params}, inputs)
        return outputs

    @jax.jit
    def apply_map_fn(params, inputs, rng):
        # Deterministic forward pass at the posterior mean; `rng` is unused
        # but kept so both functions share the same signature.
        outputs = model.apply({"params": params["mean"]}, inputs)
        return outputs[None, ...]

    return apply_fn, apply_map_fn


class MLP(nn.Module):
    n_features: int = 512
    n_layers: int = 3
    n_classes: int = 10
    n_features_mult: int = 1
    bias_init: Callable = nn.initializers.zeros_init()
    act: Callable = nn.relu
    dtype: str = "float32"

    @nn.compact
    def __call__(self, x):
        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            bias_init=self.bias_init,
        )
        x = jnp.reshape(x, (x.shape[0], -1))  # flatten everything but the batch dim
        for _ in range(self.n_layers):
            x = dense(int(self.n_features * self.n_features_mult))(x)
            x = self.act(x)  # use the configured activation, not hard-coded relu
        x = dense(self.n_classes)(x)
        return x
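
# --- Usage sketch (illustrative, not part of the original module) ---
# One plausible way to build the variational parameters and query the model.
# The zero-mean/log-scale initialization, the batch shape, and the sample
# count below are assumptions for demonstration only.

if __name__ == "__main__":
    model = MLP()
    rng = jax.random.PRNGKey(0)
    init_rng, sample_rng = jax.random.split(rng)

    dummy_inputs = jnp.ones((8, 28 * 28))  # assumed MNIST-like batch
    init_params = model.init(init_rng, dummy_inputs)["params"]

    # Mean-field Gaussian posterior: a (mean, log_scale) pair per weight.
    vi_params = {
        "mean": init_params,
        "log_scale": jax.tree_util.tree_map(
            lambda p: jnp.full_like(p, -5.0), init_params
        ),
    }

    apply_fn, apply_map_fn = get_apply_fn(model)

    # Ten posterior samples -> logits of shape (10, 8, 10).
    sample_keys = jax.random.split(sample_rng, num=10)
    ensemble_logits = apply_fn(vi_params, dummy_inputs, sample_keys)

    # Deterministic prediction at the posterior mean -> shape (1, 8, 10),
    # broadcastable against the ensemble above.
    map_logits = apply_map_fn(vi_params, dummy_inputs, sample_rng)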