srossi93 commited on
Commit
3dce43d
·
1 Parent(s): d7d6530

Update to space

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. app.py +73 -34
  3. requirements.txt +2 -1
  4. utils.py +72 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -1,24 +1,18 @@
1
- from functools import partial
2
-
3
- import gradio as gr
4
- import numpy as np
5
- from scipy.ndimage import rotate
6
-
7
-
8
  from functools import partial
9
  from typing import Any, Callable, Sequence, Tuple
10
 
11
  import dirichlet
12
-
13
  import jax
14
  import jax.numpy as jnp
15
  import numpy as np
16
  import plotly.graph_objects as go
17
  import tensorflow_probability.substrates.jax as tfp
18
  from flax import linen as nn
19
-
20
-
21
  from scipy.ndimage import rotate
 
22
 
23
  tfd = tfp.distributions
24
 
@@ -27,6 +21,59 @@ image = np.load("image_2.npy")[0, ..., 0]
27
 
28
  rotate_image = partial(rotate, reshape=False, axes=(0, 1))
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  with gr.Blocks() as demo:
32
  with gr.Row():
@@ -41,13 +88,13 @@ with gr.Blocks() as demo:
41
  slider = gr.Slider(0, 360, step=1, label="Rotation angle")
42
  output = gr.Image(image_mode="L", label="Rotated", height=300, width=300)
43
 
44
- # with gr.Column():
45
- # plot_map = gr.Plot()
46
- # entropy_map = gr.Label(label="Entropy of predictions (deterministic)")
47
 
48
- # with gr.Column():
49
- # plot = gr.Plot()
50
- # entropy = gr.Label(label="Entropy of predictions (probabilistic)")
51
 
52
  slider.release(
53
  rotate_image,
@@ -56,22 +103,14 @@ with gr.Blocks() as demo:
56
  api_name="rotate",
57
  )
58
 
59
- # slider.release(
60
- # predict_label,
61
- # inputs=[output],
62
- # outputs=[plot, plot_map, entropy, entropy_map],
63
- # api_name="predict",
64
- # )
65
- # # with gr.Row():
66
- # # number = gr.Number(label="On release")
67
- # # number2 = gr.Number(label="Number of events fired")
68
- # # slider.release(
69
- # # identity,
70
- # # inputs=[slider, state],
71
- # # outputs=[number, state, number2],
72
- # # api_name="predict",
73
- # # )
74
-
75
-
76
- # demo.launch(share=False, server_name="0.0.0.0", server_port=43000)
77
  demo.launch()
 
 
 
 
 
 
 
 
1
  from functools import partial
2
  from typing import Any, Callable, Sequence, Tuple
3
 
4
  import dirichlet
5
+ import gradio as gr
6
  import jax
7
  import jax.numpy as jnp
8
  import numpy as np
9
  import plotly.graph_objects as go
10
  import tensorflow_probability.substrates.jax as tfp
11
  from flax import linen as nn
12
+ from flax.traverse_util import flatten_dict, unflatten_dict
13
+ from safetensors.flax import load_file, save_file
14
  from scipy.ndimage import rotate
15
+ from utils import MLP, get_apply_fn
16
 
17
  tfd = tfp.distributions
18
 
 
21
 
22
  rotate_image = partial(rotate, reshape=False, axes=(0, 1))
23
 
24
+ model = MLP(512)
25
+ apply_fn, apply_map_fn = get_apply_fn(model)
26
+ params = unflatten_dict(load_file("model.safetensors"), sep=".")
27
+
28
+ rng_test = jax.random.split(jax.random.PRNGKey(0), 1000)
29
+
30
+
31
+ def predict(image):
32
+ image = image[None, ..., None] / 255
33
+
34
+ y_logits_map = apply_map_fn(params, image, rng_test)[:, 0, :]
35
+ y_logits = apply_fn(params, image, rng_test)[:, 0, :]
36
+
37
+ y_probs = jax.nn.softmax(y_logits)
38
+ y_probs_map = jax.nn.softmax(y_logits_map)
39
+
40
+ alpha = dirichlet.mle(y_probs, method="fixedpoint")
41
+
42
+ predictive = tfd.DirichletMultinomial(1, alpha)
43
+ entropy = -predictive.log_prob(y_probs).mean()
44
+
45
+ predictive_map = tfd.Categorical(logits=y_logits_map)
46
+ entropy_map = predictive_map.entropy().mean()
47
+
48
+
49
+ fig = go.Figure()
50
+ for class_id in range(10):
51
+ fig.add_trace(
52
+ go.Box(y=y_probs[:, class_id], name=str(class_id), boxpoints=False)
53
+ )
54
+
55
+ fig.update_layout(
56
+ title="Predicted labels (probabilistic model)",
57
+ xaxis_title="Class",
58
+ yaxis_title="Probability",
59
+ yaxis=dict(range=[0, 1.1]),
60
+ )
61
+
62
+ fig_map = go.Figure()
63
+ for class_id in range(10):
64
+ fig_map.add_trace(
65
+ go.Box(y=y_probs_map[:, class_id], name=str(class_id), boxpoints=False)
66
+ )
67
+
68
+ fig_map.update_layout(
69
+ title="Predicted labels (deterministic model)",
70
+ xaxis_title="Class",
71
+ yaxis_title="Probability",
72
+ yaxis=dict(range=[0, 1.1]),
73
+ )
74
+
75
+ return fig, fig_map, f"{entropy:.3f}", f"{entropy_map:.3f}"
76
+
77
 
78
  with gr.Blocks() as demo:
79
  with gr.Row():
 
88
  slider = gr.Slider(0, 360, step=1, label="Rotation angle")
89
  output = gr.Image(image_mode="L", label="Rotated", height=300, width=300)
90
 
91
+ with gr.Column():
92
+ plot_map = gr.Plot()
93
+ entropy_map = gr.Label(label="Entropy of predictions (deterministic)")
94
 
95
+ with gr.Column():
96
+ plot = gr.Plot()
97
+ entropy = gr.Label(label="Entropy of predictions (probabilistic)")
98
 
99
  slider.release(
100
  rotate_image,
 
103
  api_name="rotate",
104
  )
105
 
106
+ slider.release(
107
+ predict,
108
+ inputs=[output],
109
+ outputs=[plot, plot_map, entropy, entropy_map],
110
+ api_name="predict",
111
+ )
112
+
113
+
114
+
115
+
 
 
 
 
 
 
 
 
116
  demo.launch()
requirements.txt CHANGED
@@ -3,4 +3,5 @@ git+https://github.com/ericsuh/dirichlet.git
3
  jax[cpu]==0.4.14
4
  plotly==5.18.0
5
  tensorflow-probability==0.23.0
6
- flax==0.7.2
 
 
3
  jax[cpu]==0.4.14
4
  plotly==5.18.0
5
  tensorflow-probability==0.23.0
6
+ flax==0.7.2
7
+ safetensors==0.4.2
utils.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Callable, Sequence, Tuple
3
+
4
+ import dirichlet
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ import tensorflow_probability.substrates.jax as tfp
9
+ from flax import linen as nn
10
+ from flax.training.checkpoints import restore_checkpoint
11
+
12
+ tfd = tfp.distributions
13
+
14
+
15
+ def split_tree(a, rng_key):
16
+ treedef = jax.tree_util.tree_structure(a)
17
+ num_vars = len(jax.tree_util.tree_leaves(a))
18
+ all_keys = jax.random.split(rng_key, num=(num_vars + 1))
19
+ return jax.tree_util.tree_unflatten(treedef, all_keys[1:])
20
+
21
+
22
+ def sample_fn(rng, vi_params: nn.FrozenDict):
23
+ rng = split_tree(vi_params["mean"], rng)
24
+ params = jax.tree_map(
25
+ lambda m, ls, k: tfd.Normal(loc=m, scale=jnp.exp(ls)).sample(seed=k),
26
+ vi_params["mean"],
27
+ vi_params["log_scale"],
28
+ rng,
29
+ ) # type: nn.FrozenDict
30
+ return params
31
+
32
+
33
+ def get_apply_fn(model: nn.Module):
34
+ """Returns the model forward function"""
35
+
36
+ @jax.jit
37
+ @partial(jax.vmap, in_axes=(None, None, 0))
38
+ def apply_fn(vi_params, inputs, rng):
39
+ params = sample_fn(rng, vi_params)
40
+ outputs = model.apply({"params": params}, inputs)
41
+ return outputs
42
+
43
+ @jax.jit
44
+ def apply_map_fn(params, inputs, rng):
45
+ outputs = model.apply({"params": params["mean"]}, inputs)
46
+ return outputs[None, ...]
47
+
48
+ return apply_fn, apply_map_fn
49
+
50
+
51
+ class MLP(nn.Module):
52
+ n_features: int = 512
53
+ n_layers: int = 3
54
+ n_classes: int = 10
55
+ n_features_mult: int = 1
56
+ bias_init: Callable = nn.initializers.zeros_init()
57
+ act: Callable = nn.relu
58
+ dtype: str = "float32"
59
+
60
+ @nn.compact
61
+ def __call__(self, x):
62
+ dense = partial(
63
+ nn.Dense,
64
+ dtype=self.dtype,
65
+ bias_init=self.bias_init,
66
+ )
67
+ x = jnp.reshape(x, (x.shape[0], -1))
68
+ for _ in range(self.n_layers):
69
+ x = dense(int(self.n_features * self.n_features_mult))(x)
70
+ x = nn.relu(x)
71
+ x = dense(self.n_classes)(x)
72
+ return x