vahidrezanezhad commited on
Commit
98abfe5
1 Parent(s): 0b19517

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -2,12 +2,13 @@ import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
  import cv2
 
5
  from huggingface_hub import from_pretrained_keras
6
 
7
  def resize_image(img_in,input_height,input_width):
8
  return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
9
 
10
- def visualize_model_output(prediction):
11
  unique_classes = np.unique(prediction[:,:,0])
12
  rgb_colors = {'0' : [0, 0, 0],
13
  '1' : [255, 0, 0],
@@ -37,7 +38,13 @@ def visualize_model_output(prediction):
37
  output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
38
  output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
39
 
40
- return output.astype(np.uint8)
 
 
 
 
 
 
41
 
42
 
43
  def do_prediction(model_name, img):
@@ -169,7 +176,7 @@ def do_prediction(model_name, img):
169
  '''
170
  #prediction_true = prediction_true * -1
171
  #prediction_true = prediction_true + 1
172
- return "No numerical output", visualize_model_output(prediction_true)
173
 
174
  # catch-all (we should not reach this)
175
  case _:
 
2
  import tensorflow as tf
3
  import numpy as np
4
  import cv2
5
+ from PIL import Image
6
  from huggingface_hub import from_pretrained_keras
7
 
8
  def resize_image(img_in,input_height,input_width):
9
  return cv2.resize( img_in, ( input_width,input_height) ,interpolation=cv2.INTER_NEAREST)
10
 
11
+ def visualize_model_output(prediction, img):
12
  unique_classes = np.unique(prediction[:,:,0])
13
  rgb_colors = {'0' : [0, 0, 0],
14
  '1' : [255, 0, 0],
 
38
  output[:,:,1][prediction[:,:,0]==unq_class] = rgb_class_unique[1]
39
  output[:,:,2][prediction[:,:,0]==unq_class] = rgb_class_unique[2]
40
 
41
+ output = output.astype(np.uint8)
42
+ im_pil_output = Image.fromarray(output)
43
+ im_pil = Image.fromarray(img)
44
+
45
+ im_pil.paste(im_pil_output, (0,0))
46
+
47
+ return im_pil.astype(np.uint8)
48
 
49
 
50
  def do_prediction(model_name, img):
 
176
  '''
177
  #prediction_true = prediction_true * -1
178
  #prediction_true = prediction_true + 1
179
+ return "No numerical output", visualize_model_output(prediction_true,img)
180
 
181
  # catch-all (we should not reach this)
182
  case _: