|
import json |
|
from typing import Any, Dict, List |
|
|
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from app.pipelines import Pipeline |
|
from huggingface_hub import from_pretrained_keras, hf_hub_download |
|
from PIL import Image |
|
import base64 |
|
|
|
|
|
# Hub file holding the model configuration; its optional "id2label" entry
# supplies the human-readable class names.
CONFIG_FILENAME: str = "config.json"
# Presumably the SavedModel graph file of the Keras model. Not referenced by the
# visible code: the model is loaded via from_pretrained_keras(model_id) instead.
MODEL_FILENAME: str = "saved_model.pb"
|
|
|
|
|
class PreTrainedPipeline(Pipeline):
    """Score an image with a Keras model from the Hugging Face Hub.

    The model is loaded with ``from_pretrained_keras`` and the class names are
    read from the repository's ``config.json`` (``id2label``), falling back to
    ``LABEL_<i>`` names when the config does not provide a mapping.
    """

    def __init__(self, model_id: str):
        """Load the Keras model and its label mapping.

        Args:
            model_id: Hugging Face Hub repository id of the Keras model.
        """
        self.model = from_pretrained_keras(model_id)

        # Width of the output layer: one unit per class, or a single unit
        # carrying the score of class "1" for a binary classifier.
        self.num_labels = self.model.output_shape[1]
        self.single_output_unit = self.num_labels == 1

        # A single output unit still describes two classes ("0" and "1"), so the
        # fallback mapping must contain both keys; otherwise __call__ raises
        # KeyError when it looks up label "1".
        num_classes = 2 if self.single_output_unit else self.num_labels

        config_file = hf_hub_download(model_id, filename=CONFIG_FILENAME)
        with open(config_file, encoding="utf-8") as config_fp:
            config = json.load(config_fp)

        self.id2label = config.get(
            "id2label", {str(i): f"LABEL_{i}" for i in range(num_classes)}
        )

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Run the model on a single image.

        Args:
            inputs (:obj:`PIL.Image`):
                The raw image representation as PIL. No transformation has been
                applied to it; all necessary preprocessing is done here.

        Return:
            A :obj:`list` of dicts sorted by decreasing ``"score"``. Each dict
            has a ``"label"`` (str) and a ``"score"`` (float); for models with
            more than one output unit it also has a ``"mask"`` (str). The list
            is truncated to ``self.top_k`` entries when that attribute is set.
        """
        # NOTE(review): assumes a channels-last input shape (batch, H, W, C).
        expected_input_size = self.model.input_shape

        # Match the colour mode the model expects. Converting to RGB for
        # 3-channel models also drops an alpha channel (RGBA/LA/P images) that
        # would otherwise give the model the wrong number of channels.
        if expected_input_size[-1] == 1:
            inputs = inputs.convert("L")
        elif expected_input_size[-1] == 3:
            inputs = inputs.convert("RGB")

        # Convert to an (H, W, C) float array *before* resizing: img_to_array
        # adds the channel axis to 2-D grayscale images, whereas tf.image.resize
        # rejects 2-D input.
        img_array = tf.keras.preprocessing.image.img_to_array(inputs)

        # Resize only when the model declares a fixed spatial size; None
        # dimensions mean the model accepts any height/width.
        target_size = (expected_input_size[1], expected_input_size[2])
        if None not in target_size:
            img_array = tf.image.resize(img_array, target_size)

        # Add the leading batch dimension.
        img_array = img_array[tf.newaxis, ...]

        # Model.predict takes no ``axis`` argument; passing one raised TypeError.
        predictions = self.model.predict(img_array)
        scores = predictions[0]

        if self.single_output_unit:
            # The single unit is treated as the score of class "1".
            score = scores[0]
            labels = [
                {"label": str(self.id2label["1"]), "score": float(score)},
                {"label": str(self.id2label["0"]), "score": float(1 - score)},
            ]
        else:
            labels = [
                {
                    "label": str(self.id2label[str(i)]),
                    # NOTE(review): this is the base64 of the raw bytes of the
                    # class score, not an image mask. It is decoded to str so the
                    # result is JSON-serializable, as documented above.
                    "mask": base64.b64encode(score).decode("ascii"),
                    "score": float(score),
                }
                for i, score in enumerate(scores)
            ]

        # top_k is never assigned in this class; when neither a base class nor
        # the caller provides it, return every label instead of failing.
        top_k = getattr(self, "top_k", None)
        return sorted(labels, key=lambda item: item["score"], reverse=True)[:top_k]
|
|