"""Gradio demo: rotate an image and compare two models' class predictions.

The page shows an input image, a rotation slider, and, for each of the two
loaded parameter sets, a box plot of predicted class probabilities together
with a scalar labelled "entropy" in the UI.
"""

from functools import partial

import dirichlet
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
import tensorflow_probability.substrates.jax as tfp
from flax.traverse_util import unflatten_dict
from safetensors.flax import load_file
from scipy.ndimage import rotate

from utils import MLP, get_apply_fn

tfd = tfp.distributions

# Number of class columns drawn in each box plot.
NUM_CLASSES = 10

# Default picture shown in the "Original" image component.
image = np.load("image_2.npy")[0, ..., 0]

# Called by the UI as rotate_image(image, angle); reshape=False keeps the
# output the same shape as the input.
rotate_image = partial(rotate, reshape=False, axes=(0, 1))

model = MLP(512)
apply_fn, apply_map_fn = get_apply_fn(model)

# Two parameter sets: `params` feeds the plot titled "probabilistic model",
# `params_gd` feeds the plot titled "deterministic model".
params = unflatten_dict(load_file("model.safetensors"), sep=".")
params_gd = unflatten_dict(load_file("model-gd.safetensors"), sep=".")

# 1000 PRNG keys handed to both apply functions on every prediction
# (presumably one forward pass per key -- confirm in utils.get_apply_fn).
rng_test = jax.random.split(jax.random.PRNGKey(0), 1000)


def _box_figure(probs, title):
    """Return a box-plot figure with one box per class column of ``probs``.

    Args:
        probs: 2-D array indexed as ``probs[:, class_id]``.
        title: Figure title.
    """
    fig = go.Figure()
    for class_id in range(NUM_CLASSES):
        fig.add_trace(
            go.Box(y=probs[:, class_id], name=str(class_id), boxpoints=False)
        )
    fig.update_layout(
        title=title,
        xaxis_title="Class",
        yaxis_title="Probability",
        # Slight headroom above 1 so boxes touching 1.0 are not clipped.
        yaxis=dict(range=[0, 1.1]),
    )
    return fig


def predict(image):
    """Run both parameter sets on ``image`` and summarise their predictions.

    Args:
        image: 2-D image array as delivered by the "Rotated" image component
            (assumed to hold 0-255 pixel values -- TODO confirm).

    Returns:
        A 4-tuple ``(fig, fig_map, entropy, entropy_map)``: the box plot for
        ``params``, the box plot for ``params_gd``, and the two scalar
        summaries formatted as strings with three decimals.
    """
    # Add leading and trailing singleton axes and rescale by 1/255.
    image = image[None, ..., None] / 255

    y_logits_map = apply_map_fn(params_gd, image, rng_test)[:, 0, :]
    y_logits = apply_fn(params, image, rng_test)[:, 0, :]
    y_probs = jax.nn.softmax(y_logits)
    y_probs_map = jax.nn.softmax(y_logits_map)

    # Fit a Dirichlet to the predicted probability vectors, then score those
    # same vectors under a single-trial Dirichlet-multinomial. The negative
    # mean log-probability is what the UI displays as "entropy".
    alpha = dirichlet.mle(y_probs, method="fixedpoint").astype(jnp.float32)
    predictive = tfd.DirichletMultinomial(1, alpha)
    entropy = -predictive.log_prob(y_probs).mean()

    # For the other parameter set: mean categorical entropy over the outputs.
    predictive_map = tfd.Categorical(logits=y_logits_map)
    entropy_map = predictive_map.entropy().mean()

    fig = _box_figure(y_probs, "Predicted labels (probabilistic model)")
    fig_map = _box_figure(y_probs_map, "Predicted labels (deterministic model)")

    return fig, fig_map, f"{entropy:.3f}", f"{entropy_map:.3f}"


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                image,
                image_mode="L",
                label="Original",
                height=300,
                width=300,
            )
            slider = gr.Slider(0, 360, step=1, label="Rotation angle")
            output = gr.Image(
                image_mode="L",
                label="Rotated",
                height=300,
                width=300,
                interactive=False,
            )
        with gr.Column():
            plot_map = gr.Plot()
            entropy_map = gr.Label(label="Entropy of predictions (deterministic)")
        with gr.Column():
            plot = gr.Plot()
            entropy = gr.Label(label="Entropy of predictions (probabilistic)")

    # On slider release: rotate the original image first, then run the
    # prediction on the rotated result.
    slider.release(
        rotate_image,
        inputs=[input_image, slider],
        outputs=[output],
        api_name="rotate",
    ).then(
        predict,
        inputs=[output],
        outputs=[plot, plot_map, entropy, entropy_map],
        api_name="predict",
    )

# To bind to all interfaces on a fixed port instead, use:
#   demo.launch(server_name="0.0.0.0", server_port=42000)
demo.launch()