rashmi commited on
Commit
e069ab8
·
1 Parent(s): 1064ced

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Scikit learn example https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html
2
+
3
+ import gradio as gr
4
+
5
+ from sklearn.cluster import OPTICS, cluster_optics_dbscan
6
+ import matplotlib.gridspec as gridspec
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+ plt.switch_backend("agg")
11
+
12
+ # Theme from - https://huggingface.co/spaces/trl-lib/stack-llama/blob/main/app.py
13
+ theme = gr.themes.Monochrome(
14
+ primary_hue="indigo",
15
+ secondary_hue="blue",
16
+ neutral_hue="slate",
17
+ radius_size=gr.themes.sizes.radius_sm,
18
+ font=[
19
+ gr.themes.GoogleFont("Open Sans"),
20
+ "ui-sans-serif",
21
+ "system-ui",
22
+ "sans-serif",
23
+ ],
24
+ )
25
+
26
+
27
+ def do_submit(n_points_per_cluster, min_samples, xi, min_cluster_size):
28
+ # # Generate sample data
29
+ np.random.seed(0)
30
+ n_points_per_cluster = int(n_points_per_cluster)
31
+
32
+ C1 = [-5, -2] + 0.8 * np.random.randn(n_points_per_cluster, 2)
33
+ C2 = [4, -1] + 0.1 * np.random.randn(n_points_per_cluster, 2)
34
+ C3 = [1, -2] + 0.2 * np.random.randn(n_points_per_cluster, 2)
35
+ C4 = [-2, 3] + 0.3 * np.random.randn(n_points_per_cluster, 2)
36
+ C5 = [3, -2] + 1.6 * np.random.randn(n_points_per_cluster, 2)
37
+ C6 = [5, 6] + 2 * np.random.randn(n_points_per_cluster, 2)
38
+ X = np.vstack((C1, C2, C3, C4, C5, C6))
39
+
40
+ clust = OPTICS(
41
+ min_samples=int(min_samples),
42
+ xi=float(xi),
43
+ min_cluster_size=float(min_cluster_size),
44
+ )
45
+
46
+ # Run the fit
47
+ clust.fit(X)
48
+
49
+ labels_050 = cluster_optics_dbscan(
50
+ reachability=clust.reachability_,
51
+ core_distances=clust.core_distances_,
52
+ ordering=clust.ordering_,
53
+ eps=0.5,
54
+ )
55
+ labels_200 = cluster_optics_dbscan(
56
+ reachability=clust.reachability_,
57
+ core_distances=clust.core_distances_,
58
+ ordering=clust.ordering_,
59
+ eps=2,
60
+ )
61
+
62
+ space = np.arange(len(X))
63
+ reachability = clust.reachability_[clust.ordering_]
64
+ labels = clust.labels_[clust.ordering_]
65
+
66
+ plt.figure(figsize=(10, 7))
67
+ G = gridspec.GridSpec(2, 3)
68
+ ax1 = plt.subplot(G[0, :])
69
+ ax2 = plt.subplot(G[1, 0])
70
+ ax3 = plt.subplot(G[1, 1])
71
+ ax4 = plt.subplot(G[1, 2])
72
+
73
+ # Reachability plot
74
+ colors = ["g.", "r.", "b.", "y.", "c."]
75
+ for klass, color in zip(range(0, 5), colors):
76
+ Xk = space[labels == klass]
77
+ Rk = reachability[labels == klass]
78
+ ax1.plot(Xk, Rk, color, alpha=0.3)
79
+ ax1.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
80
+ ax1.plot(space, np.full_like(space, 2.0, dtype=float), "k-", alpha=0.5)
81
+ ax1.plot(space, np.full_like(space, 0.5, dtype=float), "k-.", alpha=0.5)
82
+ ax1.set_ylabel("Reachability (epsilon distance)")
83
+ ax1.set_title("Reachability Plot")
84
+
85
+ # OPTICS
86
+ colors = ["g.", "r.", "b.", "y.", "c."]
87
+ for klass, color in zip(range(0, 5), colors):
88
+ Xk = X[clust.labels_ == klass]
89
+ ax2.plot(Xk[:, 0], Xk[:, 1], color, alpha=0.3)
90
+ ax2.plot(X[clust.labels_ == -1, 0], X[clust.labels_ == -1, 1], "k+", alpha=0.1)
91
+ ax2.set_title("Automatic Clustering\nOPTICS")
92
+
93
+ # DBSCAN at 0.5
94
+ colors = ["g.", "r.", "b.", "c."]
95
+ for klass, color in zip(range(0, 4), colors):
96
+ Xk = X[labels_050 == klass]
97
+ ax3.plot(Xk[:, 0], Xk[:, 1], color, alpha=0.3)
98
+ ax3.plot(X[labels_050 == -1, 0], X[labels_050 == -1, 1], "k+", alpha=0.1)
99
+ ax3.set_title("Clustering at 0.5 epsilon cut\nDBSCAN")
100
+
101
+ # DBSCAN at 2.
102
+ colors = ["g.", "m.", "y.", "c."]
103
+ for klass, color in zip(range(0, 4), colors):
104
+ Xk = X[labels_200 == klass]
105
+ ax4.plot(Xk[:, 0], Xk[:, 1], color, alpha=0.3)
106
+ ax4.plot(X[labels_200 == -1, 0], X[labels_200 == -1, 1], "k+", alpha=0.1)
107
+ ax4.set_title("Clustering at 2.0 epsilon cut\nDBSCAN")
108
+
109
+ plt.tight_layout()
110
+
111
+ return plt
112
+
113
+
114
+ title = "Demo of OPTICS clustering algorithm"
115
+ with gr.Blocks(title=title, theme=theme) as demo:
116
+ gr.Markdown(f"## {title}")
117
+ gr.Markdown(
118
+ "[Scikit-learn Example](https://scikit-learn.org/stable/auto_examples/cluster/plot_optics.html)"
119
+ )
120
+
121
+ gr.Markdown(
122
+ "Finds core samples of high density and expands clusters from them. This example uses data that is \
123
+ generated so that the clusters have different densities. The [OPTICS](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html#sklearn.cluster.OPTICS) is first used with its Xi cluster detection \
124
+ method, and then setting specific thresholds on the reachability, which corresponds to [DBSCAN](https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html#sklearn.cluster.DBSCAN). We can see that \
125
+ the different clusters of OPTICS’s Xi method can be recovered with different choices of thresholds in DBSCAN."
126
+ )
127
+
128
+ n_points_per_cluster = gr.Slider(
129
+ minimum=200,
130
+ maximum=500,
131
+ label="Number of points per cluster",
132
+ step=50,
133
+ value=250,
134
+ )
135
+ min_samples = gr.Slider(
136
+ minimum=10,
137
+ maximum=100,
138
+ label="OPTICS - Minimum number of samples",
139
+ step=5,
140
+ value=50,
141
+ info="The number of samples in a neighborhood for a point to be considered as a core point.",
142
+ )
143
+ xi = gr.Slider(
144
+ minimum=0,
145
+ maximum=0.2,
146
+ label="OPTICS - Xi",
147
+ step=0.05,
148
+ value=0.05,
149
+ info="Determines the minimum steepness on the reachability plot that constitutes a cluster boundary. ",
150
+ )
151
+
152
+ min_cluster_size = gr.Slider(
153
+ minimum=0.01,
154
+ maximum=0.1,
155
+ label="OPTICS - Minimum cluster size",
156
+ step=0.01,
157
+ value=0.05,
158
+ info="Minimum number of samples in an OPTICS cluster, expressed as an absolute number or a fraction of the number of samples (rounded to be at least 2).",
159
+ )
160
+
161
+ plt_out = gr.Plot()
162
+
163
+ sub_btn = gr.Button("Submit")
164
+ sub_btn.click(
165
+ fn=do_submit,
166
+ inputs=[n_points_per_cluster, min_samples, xi, min_cluster_size],
167
+ outputs=[plt_out],
168
+ )
169
+
170
+
171
+ if __name__ == "__main__":
172
+ demo.launch()
173
+