import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from sklearn import linear_model


def _fit_boundary(ax, X, y, xx, yy, rng, linestyle, sample_weight=None):
    """Fit an SGD classifier on (X, y) and draw its decision boundary on *ax*.

    The boundary is the zero level of ``decision_function`` evaluated on the
    ``xx``/``yy`` grid. Passing ``sample_weight=None`` fits the unweighted
    model. Returns the ``ContourSet`` so the caller can build a legend entry.
    """
    clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100, random_state=rng)
    clf.fit(X, y, sample_weight=sample_weight)
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return ax.contour(xx, yy, Z, levels=[0], linestyles=[linestyle])


def plot(seed, num_points):
    """Plot weighted sample points and SGD decision boundaries with/without weights.

    Args:
        seed: Seed for the random generator; ``-1`` draws fresh random data
            on every call.
        num_points: Total number of points; odd values are rounded up to the
            next even number so both classes get the same count.

    Returns:
        A matplotlib ``Figure`` for ``gr.Plot``.
    """
    # Slider values may arrive as floats; seeding requires an integer.
    seed = int(seed)
    num_points = int(num_points)

    # Use a local generator instead of np.random.seed(): the app runs as one
    # long-lived server process, and reseeding the global RNG would leak state
    # between requests. RandomState(seed) produces the same stream that
    # np.random.seed(seed) did, so a given seed still yields the same points.
    rng = np.random.RandomState(None if seed == -1 else seed)

    # Ensure the number of points is even so both classes have equal size.
    num_points += num_points % 2
    half_num_points = num_points // 2

    # First half: class +1 centred at (1, 1); second half: class -1 at (0, 0).
    X = np.r_[rng.randn(half_num_points, 2) + [1, 1], rng.randn(half_num_points, 2)]
    y = [1] * half_num_points + [-1] * half_num_points
    sample_weight = 100 * np.abs(rng.randn(num_points))
    # Assign a bigger weight to the first half of the samples (class +1).
    sample_weight[:half_num_points] *= 10

    # 500x500 evaluation grid over [-4, 5]^2 for the decision functions.
    xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive until it is explicitly closed, which leaks memory in a server.
    fig = Figure()
    ax = fig.subplots()
    # Marker area is proportional to each sample's weight.
    ax.scatter(
        X[:, 0],
        X[:, 1],
        c=y,
        s=sample_weight,
        alpha=0.9,
        cmap=plt.cm.bone,
        edgecolor="black",
    )

    no_weights = _fit_boundary(ax, X, y, xx, yy, rng, "solid")
    samples_weights = _fit_boundary(
        ax, X, y, xx, yy, rng, "dashed", sample_weight=sample_weight
    )

    no_weights_handles, _ = no_weights.legend_elements()
    weights_handles, _ = samples_weights.legend_elements()
    ax.legend(
        [no_weights_handles[0], weights_handles[0]],
        ["no weights", "with weights"],
        loc="lower left",
    )
    ax.set(xticks=(), yticks=())
    return fig


info = """
# SGD: Weighted samples

This is a demonstration of a modified version of [SGD](https://scikit-learn.org/stable/modules/sgd.html#id5) that takes into account the weights of the samples.

The size of each point is proportional to its weight.

The algorithm is demonstrated using points sampled from standard normal distributions: the more heavily weighted class is centered at (1, 1), while the other class is centered at (0, 0).

Created by [@Nahrawy](https://huggingface.co/Nahrawy) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_weighted_samples.html).
"""

with gr.Blocks() as demo:
    gr.Markdown(info)
    with gr.Row():
        with gr.Column():
            seed = gr.Slider(
                label="Seed",
                minimum=-1,
                maximum=10000,
                step=1,
                info="Set to -1 to generate new random points each run ",
                value=-1,
            )
            # Even-only steps; plot() would round an odd count up anyway.
            num_points = gr.Slider(
                label="Number of Points", value=20, minimum=6, maximum=100, step=2
            )
        out = gr.Plot()

    plot_inputs = [seed, num_points]
    for control in (seed, num_points):
        control.change(fn=plot, inputs=plot_inputs, outputs=out)
    # Render an initial plot when the page loads instead of an empty panel.
    demo.load(fn=plot, inputs=plot_inputs, outputs=out)

if __name__ == "__main__":
    demo.launch()