File size: 2,602 Bytes
b388df6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e6f614
b388df6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# from matplotlib import image as mpimg
# from matplotlib import pyplot as plt

class api():
    """Classify an image with a pre-trained TensorFlow 1.x CNN checkpoint.

    On every call to :meth:`predict` the graph is re-imported from
    ``<model_dir>/<model_name>.meta`` and the weights are restored from
    ``<model_dir>/<model_name>.ckpt``. The preprocessed image is fed to the
    tensor ``'Placeholder:0'`` and the predicted index is read from
    ``'ArgMax:0'``; that index is mapped to a label through :attr:`classes`.
    """

    # Spatial size and channel count the network input is reshaped to.
    height = 64
    width = 64
    channels = 3
    # Base file name of the saved meta graph / checkpoint (no extension).
    model_name = 'cnn_model'
    # Directory holding the saved model; overridable per instance via __init__.
    model_dir = 'signs_api'
    # Predicted class index -> label returned to the caller.
    classes = {0: 'Zero', 1: 'One', 2: 'Two', 3: 'Three', 4: 'Four', 5: 'Five'}

    def reset_graph(self, seed=42):
        """Clear the default TF graph and seed the TF and NumPy RNGs.

        Note: ``np.random.seed`` reseeds NumPy's *global* generator, so this
        also affects any other code in the process that uses ``np.random``.
        """
        tf.reset_default_graph()
        tf.set_random_seed(seed)
        np.random.seed(seed)

    def __init__(self, upload_path='uploads', model_dir='signs_api'):
        """
        Args:
            upload_path: directory for uploaded files. It is stored on the
                instance, but no code in this class reads it.
            model_dir: directory containing ``<model_name>.meta`` and the
                ``<model_name>.ckpt`` checkpoint. The default keeps the
                previously hard-coded location.
        """
        self.upload_path = upload_path
        self.model_dir = model_dir
        # Debug trace of the meta-graph file name (same output as before).
        print('print', '{}.meta'.format(self.model_name))

    def _prepare_image(self, im):
        """Return *im* as a ``(1, height, width, channels)`` pixel array."""
        # Normalize the mode first. RGBA, grayscale ('L') and palette ('P')
        # images yield 4 or 1 values per pixel, so the reshape below would
        # raise ValueError. This assumes self.channels == 3.
        rgb = im.convert('RGB')
        # PIL's resize takes (width, height), not (height, width).
        resized = rgb.resize((self.width, self.height))
        pixels = np.array(resized.getdata())
        return pixels.reshape((-1, self.height, self.width, self.channels))

    def predict(self, im):
        """Predict the class of a PIL image.

        Args:
            im: a PIL ``Image`` in any mode; it is converted to RGB and
                resized to ``(width, height)`` before inference.

        Returns:
            ``{'value': index, 'class': label}`` on success, or
            ``{'error': True}`` if an ``OSError``/``IOError`` is raised, for
            example while decoding the image or reading the meta-graph file.
        """
        try:
            test_image = self._prepare_image(im)

            # The meta graph is re-imported on every call, so start from a
            # clean default graph; seeding keeps output stable across runs.
            self.reset_graph()

            base_path = os.path.join(self.model_dir, self.model_name)
            saver = tf.train.import_meta_graph('{}.meta'.format(base_path))

            with tf.Session() as sess:
                saver.restore(sess, '{}.ckpt'.format(base_path))

                arg_max = sess.graph.get_tensor_by_name('ArgMax:0')
                # .eval() runs in `sess` because it is the default session here.
                arg_max_val = arg_max.eval({'Placeholder:0': test_image})

                print('ArgMax', arg_max_val)
                index = arg_max_val.tolist()[0]
                return {'value': index, 'class': self.classes[index]}

        except (OSError, IOError) as e:
            print('error', e)
            return {'error': True}