# semantic-segmentation / pipeline.py
# Author: merve (HF staff)
# Commit a7ce59e: "Upload pipeline.py" (2.73 kB)
import base64
import io
import json
from typing import Any, Dict, List

import tensorflow as tf
from huggingface_hub import from_pretrained_keras, hf_hub_download
from PIL import Image
from tensorflow import keras

from app.pipelines import Pipeline
# File name of the SavedModel protobuf in the model repository
# (not referenced by the code in this file).
MODEL_FILENAME: str = "saved_model.pb"
# Config file downloaded from the Hub to read the ``id2label`` mapping.
CONFIG_FILENAME: str = "config.json"
class PreTrainedPipeline(Pipeline):
    """Inference pipeline for a Keras image model hosted on the Hugging Face Hub.

    The model is loaded with ``from_pretrained_keras`` and label names are read
    from the repository's ``config.json`` (``id2label``). Calling the pipeline
    on a PIL image returns a list of label/score dicts sorted by descending
    score (see ``__call__`` for the exact shape of each entry).
    """

    def __init__(self, model_id: str):
        """Load the Keras model and its label mapping.

        Args:
            model_id: Hub repository id, passed to both
                ``from_pretrained_keras`` and ``hf_hub_download``.
        """
        # Reload Keras SavedModel
        self.model = from_pretrained_keras(model_id)
        # Number of labels, taken from the LAST output axis. For a
        # (batch, num_classes) output this equals output_shape[1]; for a
        # spatial output such as (batch, H, W, C), output_shape[1] would be
        # the height. Assumes channels-last outputs — TODO confirm per model.
        self.num_labels = self.model.output_shape[-1]
        # Config is required to know the mapping to label.
        config_file = hf_hub_download(model_id, filename=CONFIG_FILENAME)
        with open(config_file, encoding="utf-8") as config_fp:
            config = json.load(config_fp)
        # Fall back to generic "LABEL_<i>" names when the config has no mapping.
        self.id2label = config.get(
            "id2label", {str(i): f"LABEL_{i}" for i in range(self.num_labels)}
        )

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Run the model on one image and return scored labels.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image as PIL. No transformation has been applied to
                it; all preprocessing happens here.

        Returns:
            A list of dicts sorted by decreasing ``"score"``, truncated to
            ``self.top_k`` entries when that attribute has been set:

            * 4-D model output (read as (batch, H, W, C) per-pixel scores):
              one entry per class present in the predicted class map, with
              ``"label"`` (str), ``"score"`` (mean score of that class over
              its pixels, float) and ``"mask"`` (base64-encoded PNG, str).
            * Single output unit: two entries, label "1" scored with the
              prediction and label "0" scored with ``1 - prediction``.
            * Otherwise: one entry per output unit with ``"label"`` (str),
              ``"score"`` (float) and ``"mask"`` (base64 of the unit's raw
              value, str).
        """
        img_array = self._preprocess(inputs)
        # Model.predict() takes no ``axis`` argument; passing one raises TypeError.
        predictions = self.model.predict(img_array)
        # A single output unit is read as one probability for class "1".
        self.single_output_unit = self.model.output_shape[-1] == 1
        if predictions.ndim == 4:
            labels = self._segmentation_labels(predictions[0])
        elif self.single_output_unit:
            score = predictions[0][0]
            labels = [
                {"label": self._label(1), "score": float(score)},
                {"label": self._label(0), "score": float(1 - score)},
            ]
        else:
            labels = [
                {
                    "label": self._label(i),
                    # b64encode returns bytes; decode so the mask is a str as
                    # the return contract promises (and is JSON-serialisable).
                    "mask": base64.b64encode(score).decode("ascii"),
                    "score": float(score),
                }
                for i, score in enumerate(predictions[0])
            ]
        # ``top_k`` is never assigned in this class; slicing with None keeps
        # every entry instead of raising AttributeError.
        top_k = getattr(self, "top_k", None)
        return sorted(labels, key=lambda item: item["score"], reverse=True)[:top_k]

    def _preprocess(self, inputs: "Image.Image"):
        """Return ``inputs`` as a (1, H, W, C) float batch sized for the model."""
        expected_input_size = self.model.input_shape
        # Match the channel count the model expects (also drops alpha/palette).
        if expected_input_size[-1] == 1:
            inputs = inputs.convert("L")
        elif expected_input_size[-1] == 3:
            inputs = inputs.convert("RGB")
        # Convert BEFORE resizing: img_to_array always yields a 3-D (H, W, C)
        # array, whereas a grayscale PIL image would otherwise reach
        # tf.image.resize as a 2-D array, which it rejects.
        img_array = tf.keras.preprocessing.image.img_to_array(inputs)
        target_size = (expected_input_size[1], expected_input_size[2])
        # Skip resizing when the model accepts any spatial size (None dims).
        if None not in target_size:
            img_array = tf.image.resize(img_array, target_size)
        return img_array[tf.newaxis, ...]

    def _segmentation_labels(self, scores) -> List[Dict[str, Any]]:
        """Build one label/mask/score entry per class present in a score map.

        Assumes ``scores`` is a channels-last (H, W, C) array of per-pixel
        class scores — TODO confirm against the deployed model. A single
        channel is read as the probability of class "1" and thresholded at
        0.5. Masks are at the model's output resolution.
        """
        if scores.shape[-1] == 1:
            foreground = scores[..., 0]
            class_scores = [1.0 - foreground, foreground]
            class_map = (foreground > 0.5).astype("int64")
        else:
            class_scores = [scores[..., c] for c in range(scores.shape[-1])]
            class_map = scores.argmax(axis=-1)
        labels = []
        for class_id, per_pixel in enumerate(class_scores):
            mask = class_map == class_id
            if not mask.any():
                # Class absent from the prediction: nothing to report.
                continue
            labels.append(
                {
                    "label": self._label(class_id),
                    "mask": self._encode_mask(mask),
                    "score": float(per_pixel[mask].mean()),
                }
            )
        return labels

    def _label(self, index: int) -> str:
        """Return the configured name of class ``index``, or ``LABEL_<index>``."""
        return str(self.id2label.get(str(index), f"LABEL_{index}"))

    @staticmethod
    def _encode_mask(mask) -> str:
        """Encode a boolean 2-D mask as a base64 string of a black/white PNG."""
        buffer = io.BytesIO()
        Image.fromarray(mask.astype("uint8") * 255).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")