mralamdari commited on
Commit
88f711a
·
1 Parent(s): 1b51486

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -17
app.py CHANGED
@@ -26,27 +26,98 @@ import tensorflow as tf
26
  # model.load_model("my_model.keras")
27
 
28
 
29
- def image_mod(image):
30
 
31
- # img = Image.fromarray(image['composite'])
32
- model = tf.keras.models.load_model('weights_1.h5')
33
- test_img = np.array(image['composite']).reshape(1, 28, 28, 1)
34
- # test_img = cv2.resize(np.array(image['composite']), (28, 28, 1))
35
- prediction = model.predict(test_img)
36
- pred = np.argmax(prediction, axis=1)[0]
37
- return pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  title = "Draw to Search"
41
  description = "Using the power of AI to detect the number you draw!"
 
42
 
43
- demo = gr.Interface(
44
- fn=image_mod,
45
- inputs='sketchpad',
46
- outputs='text',
47
- title=title,
48
- description=description,
49
- live=True)
50
 
51
- demo.launch(share=False)
52
- # demo.launch(debug=True)
 
 
 
 
 
 
 
 
26
  # model.load_model("my_model.keras")
27
 
28
 
29
+ # def image_mod(image):
30
 
31
+ # # img = Image.fromarray(image['composite'])
32
+ # model = tf.keras.models.load_model('weights_1.h5')
33
+ # test_img = np.array(image['composite']).reshape(1, 28, 28, 1)
34
+ # # test_img = cv2.resize(np.array(image['composite']), (28, 28, 1))
35
+ # prediction = model.predict(test_img)
36
+ # pred = np.argmax(prediction, axis=1)[0]
37
+ # return pred
38
+
39
+
40
+ # title = "Draw to Search"
41
+ # description = "Using the power of AI to detect the number you draw!"
42
+
43
+ # demo = gr.Interface(
44
+ # fn=image_mod,
45
+ # inputs='sketchpad',
46
+ # outputs='text',
47
+ # title=title,
48
+ # description=description,
49
+ # live=True)
50
+
51
+ # demo.launch(share=False)
52
+ # demo.launch(debug=True)
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+ model = tf.keras.models.Sequential([
77
+ tf.keras.layers.Input(shape=(28, 28, 1)),
78
+ tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same', activation='relu', input_shape=(28, 28, 1)),
79
+ tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=1, padding='same', activation='relu'),
80
+ tf.keras.layers.BatchNormalization(),
81
+ tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same', activation='relu'),
82
+ tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=1, padding='same', activation='relu'),
83
+ tf.keras.layers.BatchNormalization(),
84
+ tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu'),
85
+ tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, padding='same', activation='relu'),
86
+ tf.keras.layers.BatchNormalization(),
87
+ tf.keras.layers.GlobalAveragePooling2D(),
88
+ tf.keras.layers.Dense(10, activation='softmax')
89
+ ])
90
+
91
+ model.compile(optimizer=tf.keras.optimizers.Adam(),
92
+ loss=tf.keras.losses.CategoricalCrossentropy(),
93
+ metrics=[tf.keras.metrics.MeanSquaredError(), tf.keras.metrics.AUC(), tf.keras.metrics.CategoricalAccuracy()])
94
+
95
+ model = tf.keras.models.load_model('my_model.h5')
96
+
97
+ def classify_image(image):
98
+ if len(np.array(image).shape) == 3:
99
+ image = tf.image.rgb_to_grayscale(image)
100
+ image_tensor = tf.convert_to_tensor(image)
101
+ image_tensor = tf.cast(image_tensor, tf.float32)
102
+ image_tensor = tf.expand_dims(image_tensor, 0)
103
+ image_tensor = image_tensor / 255.0
104
+ prediction = model.predict(image_tensor)
105
+ prediction_label = str(prediction.argmax())
106
+ return prediction_label
107
 
108
 
109
  title = "Draw to Search"
110
  description = "Using the power of AI to detect the number you draw!"
111
+ article = "for source code you can visit [my github](https://github.com/mralamdari)"
112
 
113
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
 
 
 
 
 
 
114
 
115
+ interface = gr.Interface(fn=classify_image,
116
+ inputs=gr.Image(type="pil"),
117
+ outputs=gr.Label(num_top_classes=3, label="Predictions"),
118
+ examples=example_list,
119
+ title=title,
120
+ description=description,
121
+ article=article)
122
+
123
+ interface.launch()