Update to space
Files changed:
- .gitignore +1 -0
- app.py +73 -34
- requirements.txt +2 -1
- utils.py +72 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+__pycache__
app.py
CHANGED
@@ -1,24 +1,18 @@
-from functools import partial
-
-import gradio as gr
-import numpy as np
-from scipy.ndimage import rotate
-
-
 from functools import partial
 from typing import Any, Callable, Sequence, Tuple
 
 import dirichlet
-
+import gradio as gr
 import jax
 import jax.numpy as jnp
 import numpy as np
 import plotly.graph_objects as go
 import tensorflow_probability.substrates.jax as tfp
 from flax import linen as nn
-
-
+from flax.traverse_util import flatten_dict, unflatten_dict
+from safetensors.flax import load_file, save_file
 from scipy.ndimage import rotate
+from utils import MLP, get_apply_fn
 
 tfd = tfp.distributions
 
@@ -27,6 +21,59 @@ image = np.load("image_2.npy")[0, ..., 0]
 
 rotate_image = partial(rotate, reshape=False, axes=(0, 1))
 
+model = MLP(512)
+apply_fn, apply_map_fn = get_apply_fn(model)
+params = unflatten_dict(load_file("model.safetensors"), sep=".")
+
+rng_test = jax.random.split(jax.random.PRNGKey(0), 1000)
+
+
+def predict(image):
+    image = image[None, ..., None] / 255
+
+    y_logits_map = apply_map_fn(params, image, rng_test)[:, 0, :]
+    y_logits = apply_fn(params, image, rng_test)[:, 0, :]
+
+    y_probs = jax.nn.softmax(y_logits)
+    y_probs_map = jax.nn.softmax(y_logits_map)
+
+    alpha = dirichlet.mle(y_probs, method="fixedpoint")
+
+    predictive = tfd.DirichletMultinomial(1, alpha)
+    entropy = -predictive.log_prob(y_probs).mean()
+
+    predictive_map = tfd.Categorical(logits=y_logits_map)
+    entropy_map = predictive_map.entropy().mean()
+
+
+    fig = go.Figure()
+    for class_id in range(10):
+        fig.add_trace(
+            go.Box(y=y_probs[:, class_id], name=str(class_id), boxpoints=False)
+        )
+
+    fig.update_layout(
+        title="Predicted labels (probabilistic model)",
+        xaxis_title="Class",
+        yaxis_title="Probability",
+        yaxis=dict(range=[0, 1.1]),
+    )
+
+    fig_map = go.Figure()
+    for class_id in range(10):
+        fig_map.add_trace(
+            go.Box(y=y_probs_map[:, class_id], name=str(class_id), boxpoints=False)
+        )
+
+    fig_map.update_layout(
+        title="Predicted labels (deterministic model)",
+        xaxis_title="Class",
+        yaxis_title="Probability",
+        yaxis=dict(range=[0, 1.1]),
+    )
+
+    return fig, fig_map, f"{entropy:.3f}", f"{entropy_map:.3f}"
+
 
 with gr.Blocks() as demo:
     with gr.Row():
@@ -41,13 +88,13 @@ with gr.Blocks() as demo:
         slider = gr.Slider(0, 360, step=1, label="Rotation angle")
         output = gr.Image(image_mode="L", label="Rotated", height=300, width=300)
 
-
-
-
+        with gr.Column():
+            plot_map = gr.Plot()
+            entropy_map = gr.Label(label="Entropy of predictions (deterministic)")
 
-
-
-
+        with gr.Column():
+            plot = gr.Plot()
+            entropy = gr.Label(label="Entropy of predictions (probabilistic)")
 
     slider.release(
        rotate_image,
@@ -56,22 +103,14 @@ with gr.Blocks() as demo:
         api_name="rotate",
     )
 
-
-
-
-
-
-
-
-
-
-
-    # # identity,
-    # # inputs=[slider, state],
-    # # outputs=[number, state, number2],
-    # # api_name="predict",
-    # # )
-
-
-    # demo.launch(share=False, server_name="0.0.0.0", server_port=43000)
+    slider.release(
+        predict,
+        inputs=[output],
+        outputs=[plot, plot_map, entropy, entropy_map],
+        api_name="predict",
+    )
+
+
+
+
 demo.launch()
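Note on the new checkpoint handling: app.py imports flatten_dict and save_file alongside unflatten_dict and load_file, but only the loading pair is used (params = unflatten_dict(load_file("model.safetensors"), sep=".")). The checkpoint is therefore expected to be a flat safetensors file whose dot-joined keys rebuild the variational parameter tree with "mean" and "log_scale" subtrees (see utils.py below). A minimal sketch of how such a file could be written; the vi_params variable and the training step that produces it are assumptions, not part of this commit:

from flax.traverse_util import flatten_dict
from safetensors.flax import save_file

# vi_params: a (frozen) dict with "mean" and "log_scale" subtrees of jnp arrays,
# produced by whatever variational training script created model.safetensors.
flat = flatten_dict(vi_params, sep=".")  # e.g. {"mean.Dense_0.kernel": ..., ...}
save_file(flat, "model.safetensors")     # inverse of the load in app.py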
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ git+https://github.com/ericsuh/dirichlet.git
 jax[cpu]==0.4.14
 plotly==5.18.0
 tensorflow-probability==0.23.0
-flax==0.7.2
+flax==0.7.2
+safetensors==0.4.2
utils.py
ADDED
@@ -0,0 +1,72 @@
+from functools import partial
+from typing import Any, Callable, Sequence, Tuple
+
+import dirichlet
+import jax
+import jax.numpy as jnp
+import numpy as np
+import tensorflow_probability.substrates.jax as tfp
+from flax import linen as nn
+from flax.training.checkpoints import restore_checkpoint
+
+tfd = tfp.distributions
+
+
+def split_tree(a, rng_key):
+    treedef = jax.tree_util.tree_structure(a)
+    num_vars = len(jax.tree_util.tree_leaves(a))
+    all_keys = jax.random.split(rng_key, num=(num_vars + 1))
+    return jax.tree_util.tree_unflatten(treedef, all_keys[1:])
+
+
+def sample_fn(rng, vi_params: nn.FrozenDict):
+    rng = split_tree(vi_params["mean"], rng)
+    params = jax.tree_map(
+        lambda m, ls, k: tfd.Normal(loc=m, scale=jnp.exp(ls)).sample(seed=k),
+        vi_params["mean"],
+        vi_params["log_scale"],
+        rng,
+    )  # type: nn.FrozenDict
+    return params
+
+
+def get_apply_fn(model: nn.Module):
+    """Returns the model forward function"""
+
+    @jax.jit
+    @partial(jax.vmap, in_axes=(None, None, 0))
+    def apply_fn(vi_params, inputs, rng):
+        params = sample_fn(rng, vi_params)
+        outputs = model.apply({"params": params}, inputs)
+        return outputs
+
+    @jax.jit
+    def apply_map_fn(params, inputs, rng):
+        outputs = model.apply({"params": params["mean"]}, inputs)
+        return outputs[None, ...]
+
+    return apply_fn, apply_map_fn
+
+
+class MLP(nn.Module):
+    n_features: int = 512
+    n_layers: int = 3
+    n_classes: int = 10
+    n_features_mult: int = 1
+    bias_init: Callable = nn.initializers.zeros_init()
+    act: Callable = nn.relu
+    dtype: str = "float32"
+
+    @nn.compact
+    def __call__(self, x):
+        dense = partial(
+            nn.Dense,
+            dtype=self.dtype,
+            bias_init=self.bias_init,
+        )
+        x = jnp.reshape(x, (x.shape[0], -1))
+        for _ in range(self.n_layers):
+            x = dense(int(self.n_features * self.n_features_mult))(x)
+            x = nn.relu(x)
+        x = dense(self.n_classes)(x)
+        return x
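In get_apply_fn, apply_fn draws one weight sample per PRNG key from the factorized Gaussian posterior (mean, exp(log_scale)) and is vmapped over the keys, so a batch of keys yields an ensemble of logits; apply_map_fn evaluates the posterior mean only and keeps a leading sample axis of size 1 so both functions return the same rank. A minimal usage sketch, assuming an MNIST-like 28x28 grayscale input and a hypothetical constant log-scale of -5.0 (in the Space the real variational parameters come from model.safetensors):

import jax
import jax.numpy as jnp

from utils import MLP, get_apply_fn

model = MLP(512)
apply_fn, apply_map_fn = get_apply_fn(model)

x = jnp.zeros((1, 28, 28, 1))                              # one grayscale image
template = model.init(jax.random.PRNGKey(0), x)["params"]  # parameter-tree skeleton
vi_params = {
    "mean": template,
    "log_scale": jax.tree_map(lambda p: jnp.full_like(p, -5.0), template),
}

keys = jax.random.split(jax.random.PRNGKey(1), 100)  # 100 posterior samples
logits = apply_fn(vi_params, x, keys)                # shape (100, 1, 10)
logits_map = apply_map_fn(vi_params, x, keys)        # shape (1, 1, 10)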