Spaces:
Sleeping
Sleeping
from functools import partial | |
from typing import Callable | |
import jax | |
import jax.numpy as jnp | |
import tensorflow_probability.substrates.jax as tfp | |
from flax import linen as nn | |
# Shorthand for the TensorFlow Probability distributions module (JAX substrate).
tfd = tfp.distributions
def split_tree(a, rng_key):
    """Split ``rng_key`` into one PRNG key per leaf of the pytree ``a``.

    Args:
        a: Any pytree. Only its structure (tree definition / leaf count) is
            used; the leaf values themselves are ignored.
        rng_key: JAX PRNG key to split.

    Returns:
        A pytree with the same structure as ``a`` whose leaves are PRNG keys.
    """
    treedef = jax.tree_util.tree_structure(a)
    # ``treedef`` already knows the leaf count, so there is no need for a
    # second traversal via ``tree_leaves``.
    # Split one extra key and drop the first, exactly as before, so the keys
    # handed out are unchanged.
    all_keys = jax.random.split(rng_key, num=treedef.num_leaves + 1)
    return jax.tree_util.tree_unflatten(treedef, all_keys[1:])
def sample_fn(rng, vi_params: nn.FrozenDict):
    """Draw one parameter sample from a factorised Gaussian over parameters.

    Each leaf is sampled independently from
    ``Normal(loc=mean, scale=exp(log_scale))``.

    Args:
        rng: JAX PRNG key. It is split into one key per parameter leaf.
        vi_params: Mapping with keys ``"mean"`` and ``"log_scale"``. These are
            two pytrees that must share the same structure.

    Returns:
        A pytree shaped like ``vi_params["mean"]`` holding the sampled values.
    """
    leaf_keys = split_tree(vi_params["mean"], rng)

    def _sample_leaf(mean, log_scale, key):
        # The scale is stored in log space; exp() keeps it strictly positive.
        return tfd.Normal(loc=mean, scale=jnp.exp(log_scale)).sample(seed=key)

    # ``jax.tree_map`` is deprecated (and removed in recent JAX releases);
    # ``jax.tree_util.tree_map`` is the supported equivalent.
    return jax.tree_util.tree_map(
        _sample_leaf,
        vi_params["mean"],
        vi_params["log_scale"],
        leaf_keys,
    )
def get_apply_fn(model: nn.Module):
    """Build the sampled and MAP forward functions for ``model``.

    Args:
        model: Flax module. It is invoked as
            ``model.apply({"params": ...}, inputs)``.

    Returns:
        A pair ``(apply_fn, apply_map_fn)``:

        * ``apply_fn(vi_params, inputs, rng)`` draws one parameter sample with
          ``sample_fn`` and returns the model outputs for that sample.
        * ``apply_map_fn(params, inputs, rng)`` runs the model on
          ``params["mean"]`` and returns the outputs with a new leading axis
          of size 1. ``rng`` is accepted only for signature parity.
    """

    def apply_fn(vi_params, inputs, rng):
        sampled_params = sample_fn(rng, vi_params)
        return model.apply({"params": sampled_params}, inputs)

    def apply_map_fn(params, inputs, rng):
        mean_params = params["mean"]
        outputs = model.apply({"params": mean_params}, inputs)
        # Prepend a singleton axis to the outputs.
        return outputs[jnp.newaxis, ...]

    return apply_fn, apply_map_fn
class MLP(nn.Module):
    """Multi-layer perceptron.

    Flattens every axis of the input after the first, applies ``n_layers``
    hidden ``nn.Dense`` layers each followed by ``act``, then a final
    ``nn.Dense`` with ``n_classes`` outputs. No activation is applied to the
    final layer.
    """

    # Base width of each hidden layer (scaled by ``n_features_mult``).
    n_features: int = 512
    # Number of hidden Dense + activation layers.
    n_layers: int = 3
    # Number of output features of the final Dense layer.
    n_classes: int = 10
    # Multiplier on ``n_features``; the product is truncated with ``int()``.
    n_features_mult: int = 1
    # Bias initializer passed to every Dense layer.
    bias_init: Callable = nn.initializers.zeros_init()
    # Activation applied after each hidden layer.
    act: Callable = nn.relu
    # Forwarded as ``dtype`` to every Dense layer.
    dtype: str = "float32"

    # Submodules are created inline below; Flax only permits that inside a
    # method decorated with ``@nn.compact`` (or in ``setup()``).
    @nn.compact
    def __call__(self, x):
        """Run the network.

        Args:
            x: Input array. It is reshaped to ``(x.shape[0], -1)``, so the
                first axis is kept and the rest are flattened.

        Returns:
            Output of the final Dense layer, shape ``(x.shape[0], n_classes)``.
        """
        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            bias_init=self.bias_init,
        )
        hidden_width = int(self.n_features * self.n_features_mult)

        x = jnp.reshape(x, (x.shape[0], -1))
        for _ in range(self.n_layers):
            x = dense(hidden_width)(x)
            # Use the configured activation instead of a hard-coded nn.relu.
            x = self.act(x)
        return dense(self.n_classes)(x)