MNIST-Digit-Classifier / modelutil.py
papasega's picture
Update modelutil.py
c9c1322
raw
history blame
No virus
947 Bytes
import tensorflow as tf
def create_model(weights_path="./checkpoint"):
    """Build the MNIST digit-classifier CNN and load its trained weights.

    The layer stack is:
    - two 3x3 ReLU convolutions,
    - a flatten,
    - two ReLU dense layers (300 and 100 units),
    - a 10-way softmax head.

    It must match the architecture the checkpoint was saved from. Changing
    layer types, sizes or order may make weight loading fail.

    The model is returned uncompiled. ``model.predict`` works without
    compiling. Call ``model.compile(...)`` yourself before ``fit`` or
    ``evaluate``.

    Args:
        weights_path: Path handed to ``model.load_weights``. Defaults to
            ``"./checkpoint"``. A relative path resolves against the current
            working directory, not against this module's location.

    Returns:
        A ``tf.keras.Sequential`` model with weights loaded. It takes inputs
        of shape ``(batch, 28, 28, 1)`` and outputs 10 class probabilities.
    """
    layers = [
        tf.keras.layers.Conv2D(
            32, (3, 3), activation="relu", input_shape=(28, 28, 1), name="convlayer1"
        ),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu", name="convlayer2"),
        # No input_shape here. It is ignored on a non-first Sequential layer,
        # and the real input is the 24x24x64 conv output, not 28x28.
        # The name "inputlayer" is misleading but kept unchanged so existing
        # name-based lookups (e.g. model.get_layer) keep working.
        tf.keras.layers.Flatten(name="inputlayer"),
        tf.keras.layers.Dense(300, activation="relu", name="hiddenlayer1"),
        tf.keras.layers.Dense(100, activation="relu", name="hiddenlayer2"),
        tf.keras.layers.Dense(10, activation="softmax", name="outputlayer"),
    ]
    model = tf.keras.models.Sequential(layers)
    model.load_weights(weights_path)
    return model