Jayabalambika commited on
Commit
f1c2119
·
1 Parent(s): 8f951b8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import gradio as gr
4
+
5
+
6
+ def modified_huber_loss(y_true, y_pred):
7
+ z = y_pred * y_true
8
+ loss = -4 * z
9
+ loss[z >= -1] = (1 - z[z >= -1]) ** 2
10
+ loss[z >= 1.0] = 0
11
+ return loss
12
+
13
+
14
+ def plot_loss_func():
15
+ xmin, xmax = -4, 4
16
+ xx = np.linspace(xmin, xmax, 100)
17
+ lw = 2
18
+ plt.clf()
19
+
20
+ fig = plt.figure(figsize=(10, 10), dpi=100)
21
+ plt.plot([xmin, 0, 0, xmax], [1, 1, 0, 0], color="gold", lw=lw, label="Zero-one loss")
22
+ plt.plot(xx, np.where(xx < 1, 1 - xx, 0), color="teal", lw=lw, label="Hinge loss")
23
+ plt.plot(xx, -np.minimum(xx, 0), color="yellowgreen", lw=lw, label="Perceptron loss")
24
+ plt.plot(xx, np.log2(1 + np.exp(-xx)), color="cornflowerblue", lw=lw, label="Log loss")
25
+ plt.plot(
26
+ xx,
27
+ np.where(xx < 1, 1 - xx, 0) ** 2,
28
+ color="orange",
29
+ lw=lw,
30
+ label="Squared hinge loss",
31
+ )
32
+ plt.plot(
33
+ xx,
34
+ modified_huber_loss(xx, 1),
35
+ color="darkorchid",
36
+ lw=lw,
37
+ linestyle="--",
38
+ label="Modified Huber loss",
39
+ )
40
+ plt.ylim((0, 8))
41
+ plt.legend(loc="upper right")
42
+ plt.xlabel(r"Decision function $f(x)$")
43
+ plt.ylabel("$L(y=1, f(x))$")
44
+ return fig
45
+
46
+ title = "SGD convex loss functions"
47
+
48
+ # def greet(name):
49
+ # return "Hello " + name + "!"
50
+ with gr.Blocks(title=title) as demo:
51
+ gr.Markdown(f"# {title}")
52
+
53
+
54
+ gr.Markdown(" **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_loss_functions.html#sphx-glr-auto-examples-linear-model-plot-sgd-loss-functions-py)**")
55
+
56
+ btn = gr.Button(value="SGD convex loss functions")
57
+ btn.click(plot_loss_func, outputs= gr.Plot() ) #
58
+
59
+
60
+
61
+ demo.launch()