Hnabil's picture
Add application file
2d09772
raw
history blame
1.64 kB
from functools import lru_cache, partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn import datasets, svm
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
@lru_cache(maxsize=1)
def _fit_classifier():
    """Load iris, split it, and fit the demo linear SVC; cached after first call.

    Returns:
        tuple: ``(classifier, X_test, y_test, class_names)``.
    """
    # import some data to play with
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    # Split the data into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    # Run classifier, using a model that is too regularized (C too low) to see
    # the impact on the results
    classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
    return classifier, X_test, y_test, iris.target_names


def train_model(normalize):
    """Plot the confusion matrix of a linear SVC on the iris test split.

    Args:
        normalize: If truthy, normalize the matrix over the true labels
            (``normalize="true"`` in scikit-learn); otherwise show raw counts.

    Returns:
        matplotlib.figure.Figure: The figure containing the confusion matrix.
    """
    # The split is seeded and the model is deterministic, so fitting once and
    # reusing the result avoids retraining on every checkbox toggle.
    classifier, X_test, y_test, class_names = _fit_classifier()
    title = (
        "Normalized confusion matrix" if normalize
        else "Confusion matrix, without normalization"
    )
    # Use a standalone Figure instead of pyplot: pyplot keeps every figure it
    # creates alive until closed, so redrawing on each UI event would leak.
    fig = Figure()
    ax = fig.subplots()
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=class_names,
        cmap=plt.cm.Blues,
        normalize="true" if normalize else None,
        ax=ax,
    )
    disp.ax_.set_title(title)
    return disp.figure_
title = "Confusion matrix"
description = "Example of confusion matrix usage to evaluate the quality of the output of a classifier on the iris data set"

with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)
    normalize = gr.Checkbox(label="Normalize")
    plot = gr.Plot(label="Confusion matrix")
    # Redraw the plot whenever the checkbox is toggled...
    normalize.change(fn=train_model, inputs=[normalize], outputs=plot)
    # ...and draw it once on page load so it is not empty initially.
    demo.load(fn=train_model, inputs=[normalize], outputs=plot)

if __name__ == "__main__":
    # `launch(enable_queue=...)` was deprecated in Gradio 3 and removed in
    # Gradio 4; `queue()` is the supported way to enable request queuing.
    demo.queue().launch(debug=True)