File size: 1,922 Bytes
a7ce59e
 
 
 
 
c6b5997
a7ce59e
d299b84
c6b5997
 
a7ce59e
 
 
c6b5997
a7ce59e
 
39a4527
c6b5997
 
 
 
 
 
 
 
 
 
a7ce59e
6cb57f7
c6b5997
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d299b84
 
 
 
 
 
 
 
c6b5997
d299b84
c6b5997
 
7aad423
6cb57f7
 
7aad423
 
6cb57f7
 
d299b84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import json
from typing import Any, Dict, List

import tensorflow as tf
from tensorflow import keras
from huggingface_hub import from_pretrained_keras, hf_hub_download
import base64
import io
import numpy as np
from PIL import Image



class PreTrainedPipeline:
    """Keras segmentation model wrapped as a callable inference pipeline.

    Calling an instance on an image returns one entry per output class of the
    model, each holding a base64-encoded PNG mask for that class.
    """

    def __init__(self, model_id: str):
        """Load the Keras model from ``./tf_model.h5``.

        Args:
            model_id: Accepted for interface compatibility; currently unused.
                The model path is relative to the working directory.
        """
        # Keras model whose ``predict`` output is treated as per-pixel class
        # scores with the class axis last.
        self.model = keras.models.load_model("./tf_model.h5")

    @staticmethod
    def _to_array(inputs) -> np.ndarray:
        """Return the pixel data of *inputs* as a NumPy array.

        Accepts an already-open PIL image, or anything ``Image.open`` accepts
        (a path or a binary file-like object).
        """
        if isinstance(inputs, Image.Image):
            return np.array(inputs)
        with Image.open(inputs) as img:
            return np.array(img)

    @staticmethod
    def _class_masks(class_map: np.ndarray, num_classes: int) -> List[np.ndarray]:
        """Return one uint8 mask per class id in ``range(num_classes)``.

        Each mask has the same shape as *class_map*, with 255 where
        ``class_map == cls`` and 0 elsewhere. Building the masks directly as
        uint8 avoids an out-of-range float -> int8 cast, whose result is
        platform-dependent.
        """
        return [
            np.where(class_map == cls, 255, 0).astype(np.uint8)
            for cls in range(num_classes)
        ]

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Encode a 2-D uint8 mask as a PNG and return it as a base64 string."""
        with io.BytesIO() as out:
            # A 2-D uint8 array is mapped to Pillow mode "L" (8-bit grayscale).
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment *inputs* and return one mask entry per model output class.

        Args:
            inputs: A PIL image, or a path / binary file object that
                ``Image.open`` can read.

        Returns:
            A list of dicts with keys ``"label"`` (``"LABEL_<cls>"``),
            ``"mask"`` (base64-encoded PNG, 255 where the pixel's argmax class
            is ``cls``) and ``"score"`` (always 1.0).
        """
        pixels = self._to_array(inputs)

        # NOTE(review): assumes an (H, W, C) image whose channel count matches
        # the model input and 8-bit pixel values - TODO confirm.
        im = tf.image.resize(pixels, (128, 128))
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # Per-pixel winning class for the single batch element. The full
        # (H, W) map is used, so non-square outputs are handled correctly.
        class_map = np.argmax(pred_mask, axis=-1)[0]

        return [
            {
                "label": f"LABEL_{cls}",
                "mask": self._encode_png(mask),
                "score": 1.0,
            }
            for cls, mask in enumerate(
                self._class_masks(class_map, pred_mask.shape[-1])
            )
        ]