merve HF staff commited on
Commit
6cb57f7
·
1 Parent(s): eb496ad

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +19 -21
pipeline.py CHANGED
@@ -46,36 +46,34 @@ class PreTrainedPipeline(Pipeline):
46
  # Resize image to expected size
47
 
48
  expected_input_size = self.model.input_shape
 
 
 
 
49
  if expected_input_size[-1] == 1:
50
  inputs = inputs.convert("L")
51
 
 
 
52
  target_size = (expected_input_size[1], expected_input_size[2])
53
  img = tf.image.resize(inputs, target_size)
54
-
55
  img_array = tf.keras.preprocessing.image.img_to_array(img)
56
  img_array = img_array[tf.newaxis, ...]
57
-
58
- predictions = self.model.predict(img_array, axis=-1)
59
 
60
  self.single_output_unit = (
61
  self.model.output_shape[1] == 1
62
  ) # if there are two classes
63
 
64
-
65
- if self.single_output_unit:
66
- score = predictions[0][0]
67
- labels = [{"label":"pet", "score":1.0}, {"label":"other", "score":1.0}]
68
- #labels = [
69
- # {"label": str(self.id2label["1"]), "score": float(score)},
70
- # {"label": str(self.id2label["0"]), "score": float(1 - score)},
71
- ]
72
- else:
73
- labels = [
74
- {
75
- "label": str(self.id2label[str(i)]),
76
- "mask": base64.b64encode(predictions[0][i]),
77
- "score": float(score),
78
- }
79
- for i, score in enumerate(predictions[0])
80
- ]
81
- return sorted(labels, key=lambda tup: tup["score"], reverse=True)[: self.top_k]
 
46
  # Resize image to expected size
47
 
48
  expected_input_size = self.model.input_shape
49
+
50
+
51
+ with Image.open(inputs) as im:
52
+ inputs = np.array(im)
53
  if expected_input_size[-1] == 1:
54
  inputs = inputs.convert("L")
55
 
56
+
57
+
58
  target_size = (expected_input_size[1], expected_input_size[2])
59
  img = tf.image.resize(inputs, target_size)
60
+
61
  img_array = tf.keras.preprocessing.image.img_to_array(img)
62
  img_array = img_array[tf.newaxis, ...]
63
+
64
+ predictions = self.model.predict(img_array)
65
 
66
  self.single_output_unit = (
67
  self.model.output_shape[1] == 1
68
  ) # if there are two classes
69
 
70
+ labels = []
71
+ for i in enumerate(predictions):
72
+
73
+ labels.append({
74
+ "label": str(i[0]),
75
+ "mask": base64.b64encode(i[1]),
76
+ "score": 1.0,
77
+ })
78
+
79
+ return sorted(labels)