# Hugging Face Spaces page status at the time this file was captured: Runtime error
from functools import lru_cache, partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn import datasets, svm
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
@lru_cache(maxsize=1)
def _fit_classifier():
    """Load iris, split it, and fit a deliberately over-regularized linear SVC.

    The split uses ``random_state=0`` and the model is fit on that fixed split,
    so the result is the same on every call. Caching it means toggling the
    checkbox no longer reloads the data and refits the model.

    Returns:
        Tuple ``(classifier, X_test, y_test, class_names)``.
    """
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    # Split the data into a training set and a test set.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    # Use a model that is too regularized (C too low) to see the impact on
    # the results.
    classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
    return classifier, X_test, y_test, iris.target_names


def train_model(normalize):
    """Plot the confusion matrix of a linear SVC on the iris test split.

    Args:
        normalize: When truthy, pass ``normalize='true'`` to scikit-learn so
            each row is normalized over the true labels; otherwise raw
            counts are shown.

    Returns:
        The ``matplotlib.figure.Figure`` containing the confusion-matrix plot.
    """
    classifier, X_test, y_test, class_names = _fit_classifier()
    title = (
        "Normalized confusion matrix"
        if normalize
        else "Confusion matrix, without normalization"
    )
    # Build the figure with the object-oriented API rather than pyplot:
    # pyplot keeps a reference to every figure it creates, so one figure per
    # UI event would pile up for the lifetime of the server process.
    fig = Figure()
    ax = fig.subplots()
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=class_names,
        cmap=plt.cm.Blues,
        normalize="true" if normalize else None,
        ax=ax,
    )
    disp.ax_.set_title(title)
    return disp.figure_
title = "Confusion matrix"
description = "Example of confusion matrix usage to evaluate the quality of the output of a classifier on the iris data set"

# Assemble the UI: a checkbox toggles normalization and the confusion-matrix
# plot is redrawn by train_model whenever it changes.
with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)
    normalize = gr.Checkbox(label="Normalize")
    plot = gr.Plot(label="Confusion matrix")
    # Render the initial (unnormalized) plot when the page loads; otherwise
    # the plot stays empty until the user first toggles the checkbox.
    demo.load(fn=train_model, inputs=[normalize], outputs=plot)
    normalize.change(fn=train_model, inputs=[normalize], outputs=plot)

if __name__ == "__main__":
    # `launch(enable_queue=...)` was deprecated in Gradio 3.x and removed in
    # 4.0 (passing it raises TypeError); queuing is enabled via `queue()`.
    demo.queue().launch(debug=True)