chandan06 commited on
Commit
4197a00
·
1 Parent(s): 125bba7

Update classification.py

Browse files
Files changed (1) hide show
  1. classification.py +19 -19
classification.py CHANGED
@@ -3,18 +3,18 @@ import time
3
  from tensorflow.keras.preprocessing import image
4
  # from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
  import tensorflow as tf
6
- # gpus = tf.config.experimental.list_physical_devices('GPU')
7
- # if gpus:
8
- # try:
9
- # for gpu in gpus:
10
- # tf.config.experimental.set_memory_growth(gpu, True)
11
- # except RuntimeError as e:
12
- # # Memory growth must be set before GPUs have been initialized
13
- # print(e)
14
  import streamlit as st
15
- with tf.device('/cpu:0'):
16
  # Load the saved model
17
- model = tf.keras.models.load_model('best_resnet152_model.h5')
18
 
19
  class_names = {0: '1099_Div', 1: '1099_Int', 2: 'Non_Form', 3: 'w_2', 4: 'w_3'}
20
  # print(class_names)
@@ -29,16 +29,16 @@ def predict(pil_img):
29
  img_array /= 255.0 # Rescale pixel values
30
 
31
  # Predict the class
32
- with tf.device('/cpu:0'):
33
- start_time = time.time()
34
- predictions = model.predict(img_array)
35
- end_time = time.time()
36
- predicted_class_index = np.argmax(predictions, axis=1)[0]
37
 
38
- # Get the predicted class name
39
- predicted_class_name = class_names[predicted_class_index]
40
- print("Predicted class:", predicted_class_name)
41
- print("Execution time: ", end_time - start_time)
42
  return predicted_class_name
43
  # import numpy as np
44
  # import time
 
3
  from tensorflow.keras.preprocessing import image
4
  # from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
  import tensorflow as tf
6
+ gpus = tf.config.experimental.list_physical_devices('GPU')
7
+ if gpus:
8
+ try:
9
+ for gpu in gpus:
10
+ tf.config.experimental.set_memory_growth(gpu, True)
11
+ except RuntimeError as e:
12
+ # Memory growth must be set before GPUs have been initialized
13
+ print(e)
14
  import streamlit as st
15
+ # with tf.device('/cpu:0'):
16
  # Load the saved model
17
+ model = tf.keras.models.load_model('best_resnet152_model.h5')
18
 
19
  class_names = {0: '1099_Div', 1: '1099_Int', 2: 'Non_Form', 3: 'w_2', 4: 'w_3'}
20
  # print(class_names)
 
29
  img_array /= 255.0 # Rescale pixel values
30
 
31
  # Predict the class
32
+ # with tf.device('/cpu:0'):
33
+ start_time = time.time()
34
+ predictions = model.predict(img_array)
35
+ end_time = time.time()
36
+ predicted_class_index = np.argmax(predictions, axis=1)[0]
37
 
38
+ # Get the predicted class name
39
+ predicted_class_name = class_names[predicted_class_index]
40
+ print("Predicted class:", predicted_class_name)
41
+ print("Execution time: ", end_time - start_time)
42
  return predicted_class_name
43
  # import numpy as np
44
  # import time