# Spaces:
# Runtime error
# Runtime error
"""Gradio demo for different clustering techiniques | |
Derived from https://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_comparison.html | |
""" | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.cluster import ( | |
AgglomerativeClustering, Birch, DBSCAN, KMeans, MeanShift, OPTICS, SpectralClustering, estimate_bandwidth | |
) | |
from sklearn.datasets import make_blobs, make_circles, make_moons | |
from sklearn.mixture import GaussianMixture | |
from sklearn.neighbors import kneighbors_graph | |
from sklearn.preprocessing import StandardScaler | |
# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8*"
# and later releases dropped the old names, so pick whichever one this
# installation actually ships instead of failing at import time.
_SEABORN_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_SEABORN_STYLE)

# Seed shared by the dataset generators and the seeded estimators.
SEED = 0
# Number of clusters requested from the models that take a cluster count.
N_CLUSTERS = 4
# Number of points in every generated dataset.
N_SAMPLES = 1000

np.random.seed(SEED)
def normalize(X):
    """Return ``X`` after fitting and applying a fresh ``StandardScaler``."""
    scaler = StandardScaler()
    return scaler.fit_transform(X)
def get_regular():
    """Build the 'regular' dataset: blobs around the four (+-1, +-1) corners.

    Returns:
        A ``(X, labels)`` tuple with the normalized points and the blob
        index of each point as produced by ``make_blobs``.
    """
    # Same corners, same order: [1, 1], [1, -1], [-1, 1], [-1, -1].
    corners = [[x, y] for x in (1, -1) for y in (1, -1)]
    assert len(corners) == N_CLUSTERS
    points, blob_ids = make_blobs(
        n_samples=N_SAMPLES,
        centers=corners,
        cluster_std=0.7,
        random_state=SEED,
    )
    return normalize(points), blob_ids
def get_circles():
    """Build the 'circles' dataset with ``make_circles``.

    Returns:
        A ``(X, labels)`` tuple with the normalized points and the labels
        returned by ``make_circles``.
    """
    points, ring_ids = make_circles(
        n_samples=N_SAMPLES,
        factor=0.5,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), ring_ids
def get_moons():
    """Build the 'moons' dataset with ``make_moons``.

    Returns:
        A ``(X, labels)`` tuple with the normalized points and the labels
        returned by ``make_moons``.
    """
    points, moon_ids = make_moons(
        n_samples=N_SAMPLES,
        noise=0.05,
        random_state=SEED,
    )
    return normalize(points), moon_ids
def get_noise():
    """Build the 'noise' dataset: uniform random 2-D points in one group.

    A private generator seeded with ``SEED`` is used instead of the global
    NumPy random state, so the same points come back on every call, as
    with the other datasets that pass ``random_state``.

    Returns:
        A ``(X, labels)`` tuple with the normalized points and an all-zero
        integer label array (every point belongs to the same group).
    """
    rng = np.random.RandomState(SEED)
    X = rng.rand(N_SAMPLES, 2)
    labels = np.zeros(N_SAMPLES, dtype=int)
    return normalize(X), labels
def get_anisotropic():
    """Build the 'anisotropic' dataset: blobs stretched by a linear map.

    Returns:
        A ``(X, labels)`` tuple with the normalized, transformed points and
        the blob index of each point as produced by ``make_blobs``.
    """
    X, labels = make_blobs(n_samples=N_SAMPLES, centers=N_CLUSTERS, random_state=170)
    transformation = [[0.6, -0.6], [-0.4, 0.8]]
    X = np.dot(X, transformation)
    # Standardize like every other dataset so that distance-based defaults
    # (e.g. the DBSCAN eps) see data on the same scale.
    return normalize(X), labels
def get_varied():
    """Build the 'varied' dataset: blobs with different standard deviations.

    Returns:
        A ``(X, labels)`` tuple with the normalized points and the blob
        index of each point as produced by ``make_blobs``.
    """
    spreads = [1.0, 2.5, 0.5]
    points, blob_ids = make_blobs(
        n_samples=N_SAMPLES,
        cluster_std=spreads,
        random_state=SEED,
    )
    return normalize(points), blob_ids
# Dataset name (as shown in the UI) -> zero-argument factory returning
# an ``(X, labels)`` tuple.
DATA_MAPPING = dict(
    regular=get_regular,
    circles=get_circles,
    moons=get_moons,
    noise=get_noise,
    anisotropic=get_anisotropic,
    varied=get_varied,
)
def get_kmeans(X, **kwargs):
    """Fit a ``KMeans`` model on ``X``.

    Args:
        X: Data passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = KMeans(
        n_clusters=N_CLUSTERS,
        init="k-means++",
        n_init=10,
        random_state=SEED,
    )
    return estimator.set_params(**kwargs).fit(X)
def get_dbscan(X, **kwargs):
    """Fit a ``DBSCAN`` model (``eps=0.3`` by default) on ``X``.

    Args:
        X: Data passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = DBSCAN(eps=0.3)
    return estimator.set_params(**kwargs).fit(X)
def get_agglomerative(X, **kwargs):
    """Fit Ward-linkage ``AgglomerativeClustering`` on ``X``.

    The connectivity constraint is a k-nearest-neighbours graph of ``X``
    (``N_CLUSTERS`` neighbours, self excluded) averaged with its transpose.

    Args:
        X: Data passed to ``kneighbors_graph`` and ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``.

    Returns:
        The fitted estimator.
    """
    knn_graph = kneighbors_graph(X, n_neighbors=N_CLUSTERS, include_self=False)
    # Average with the transpose so the graph is symmetric.
    symmetric_graph = 0.5 * (knn_graph + knn_graph.T)
    estimator = AgglomerativeClustering(
        n_clusters=N_CLUSTERS,
        linkage="ward",
        connectivity=symmetric_graph,
    )
    return estimator.set_params(**kwargs).fit(X)
def get_meanshift(X, **kwargs):
    """Fit a ``MeanShift`` model on ``X``.

    The bandwidth comes from ``estimate_bandwidth(X, quantile=0.3)``.

    Args:
        X: Data used for the bandwidth estimate and passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = MeanShift(
        bandwidth=estimate_bandwidth(X, quantile=0.3),
        bin_seeding=True,
    )
    return estimator.set_params(**kwargs).fit(X)
def get_spectral(X, **kwargs):
    """Fit nearest-neighbours ``SpectralClustering`` on ``X``.

    Args:
        X: Data passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``; these
            may also override the default ``random_state``.

    Returns:
        The fitted estimator.
    """
    model = SpectralClustering(
        n_clusters=N_CLUSTERS,
        eigen_solver="arpack",
        affinity="nearest_neighbors",
        # Pin the seed like the other seeded models so repeated requests
        # give the same clustering.
        random_state=SEED,
    )
    model.set_params(**kwargs)
    return model.fit(X)
def get_optics(X, **kwargs):
    """Fit an ``OPTICS`` model on ``X``.

    Args:
        X: Data passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = OPTICS(min_samples=7, xi=0.05, min_cluster_size=0.1)
    return estimator.set_params(**kwargs).fit(X)
def get_birch(X, **kwargs):
    """Fit a ``Birch`` model on ``X``.

    Args:
        X: Data passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``; these
            may also override the default ``n_clusters``.

    Returns:
        The fitted estimator.
    """
    # Use the shared cluster count rather than a hard-coded 3, so Birch is
    # compared on the same footing as the other count-based models.
    model = Birch(n_clusters=N_CLUSTERS)
    model.set_params(**kwargs)
    return model.fit(X)
def get_gaussianmixture(X, **kwargs):
    """Fit a full-covariance ``GaussianMixture`` on ``X``.

    Args:
        X: Data passed to ``fit``.
        **kwargs: Parameter overrides applied with ``set_params``.

    Returns:
        The fitted estimator.
    """
    estimator = GaussianMixture(
        n_components=N_CLUSTERS,
        covariance_type="full",
        random_state=SEED,
    )
    return estimator.set_params(**kwargs).fit(X)
# Algorithm name (as shown in the UI) -> factory taking the data ``X`` and
# returning a fitted model.
MODEL_MAPPING = dict(
    KMeans=get_kmeans,
    DBSCAN=get_dbscan,
    AgglomerativeClustering=get_agglomerative,
    MeanShift=get_meanshift,
    SpectralClustering=get_spectral,
    OPTICS=get_optics,
    Birch=get_birch,
    GaussianMixture=get_gaussianmixture,
)
def plot_clusters(ax, X, labels):
    """Scatter-plot the points of ``X``, one ``scatter`` call per label.

    Every distinct value in ``labels`` is drawn, so no point is dropped
    when a model returns more groups than expected or marks points as
    noise.

    Args:
        ax: Axes-like object to draw on.
        X: Array of points; column 0 is plotted as x and column 1 as y.
        labels: One label per row of ``X``.

    Returns:
        The same ``ax`` that was passed in, with grid and ticks removed.
    """
    labels = np.asarray(labels)
    for label in np.unique(labels):
        idx = labels == label
        # scikit-learn's DBSCAN/OPTICS use -1 for noise; draw those points
        # in black so they are visible but distinct from real clusters.
        style = {"color": "black"} if label == -1 else {}
        ax.scatter(X[idx, 0], X[idx, 1], **style)
    ax.grid(None)
    ax.set_xticks([])
    ax.set_yticks([])
    return ax
def cluster(clustering_algorithm: str, dataset: str):
    """Fit the chosen algorithm on the chosen dataset and plot the result.

    Args:
        clustering_algorithm: Key of ``MODEL_MAPPING``.
        dataset: Key of ``DATA_MAPPING``.

    Returns:
        A Matplotlib figure with two panels: the dataset's own labels on
        the left and the model's predicted labels on the right.

    Raises:
        KeyError: If either argument is not a key of its mapping.
    """
    X, labels = DATA_MAPPING[dataset]()
    model = MODEL_MAPPING[clustering_algorithm](X)
    # Use ``labels_`` when the fitted model has it; otherwise predict.
    if hasattr(model, "labels_"):
        y_pred = model.labels_.astype(int)
    else:
        y_pred = model.predict(X)

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    panels = (("True clusters", labels), (clustering_algorithm, y_pred))
    for ax, (panel_title, panel_labels) in zip(axes, panels):
        plot_clusters(ax, X, panel_labels)
        ax.set_title(panel_title)

    # Drop the figure from pyplot's global registry so figures do not pile
    # up across requests; the returned Figure object can still be rendered.
    plt.close(fig)
    return fig
# Text shown at the top of the demo page.
title = "Clustering with Scikit-learn"
description = "This example shows how different clustering algorithms work. Simply pick the algorithm and the dataset to see the clusters algorithms make."

# Two radio selectors (algorithm, dataset) feed ``cluster``; its figure is
# shown in a single plot output.
demo = gr.Interface(
    fn=cluster,
    title=title,
    description=description,
    inputs=[
        gr.Radio(
            choices=list(MODEL_MAPPING),
            label="clustering algorithm",
            value="KMeans",
        ),
        gr.Radio(
            choices=list(DATA_MAPPING),
            label="dataset",
            value="regular",
        ),
    ],
    outputs=gr.Plot(),
)
demo.launch() | |