File size: 2,343 Bytes
2aa6515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import cv2

import numpy as np


# Class index -> RGB colour palette.
# The position of each colour in the tuple below is its class index.
color_dict = dict(enumerate((
    (0, 0, 0),        # 0: BG
    (239, 164, 0),    # 1: EX
    (0, 186, 127),    # 2: HE
    (0, 185, 255),    # 3: SE
    (34, 80, 242),    # 4: MA
    (73, 73, 73),     # 5: OD
    (255, 255, 255),  # 6: VB
)))


def rgb_to_onehot(rgb_arr, color_dict):
    """
    Converts a rgb label map to onehot label map defined by color_dict
        Parameters:
            rgb_arr (array): rgb label mask with shape (H x W x 3)
            color_dict (dict): dictionary mapping of class to colour; the
              n-th key (in iteration order) fills output channel n
        Returns:
            arr (array): int8 onehot label map of shape (H x W x n_classes),
              a pixel whose colour is not in color_dict is zero in
              every channel
    """
    num_classes = len(color_dict)
    shape = rgb_arr.shape[:2] + (num_classes,)
    arr = np.zeros(shape, dtype=np.int8)
    for channel, cls in enumerate(color_dict):
        # Look the colour up by its dictionary key rather than by the
        # channel position, so keys need not be exactly 0..n_classes-1.
        # A pixel belongs to the class only if all colour components match.
        arr[:, :, channel] = np.all(rgb_arr == color_dict[cls], axis=-1)
    return arr


def onehot_to_rgb(onehot_arr, color_dict):
    """
    Converts an onehot label map to rgb label map defined by color_dict
        Parameters:
            onehot_arr (array): onehot label mask with shape (H x W x n_classes)
            color_dict (dict): dictionary mapping of class to colour; the
              n-th key (in iteration order) colours channel n
        Returns:
            arr (array): uint8 rgb label map of shape (H x W x 3), pixels
              whose winning channel has no entry in color_dict stay (0, 0, 0)
    """
    # Winning channel per pixel; ties resolve to the lowest channel index.
    mask = np.argmax(onehot_arr, axis=-1)
    arr = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for channel, cls in enumerate(color_dict):
        # argmax yields channel positions, so compare against the position
        # and fetch the colour by key. Assigning in place keeps the result
        # uint8 instead of upcasting it through integer arithmetic.
        arr[mask == channel] = color_dict[cls]
    return arr


def fix_pred_label(labels):
    """
    Post-processing fixes for the prediction of VB and BG label class,
    the Vitreous Body should be a consistent filled disc, centred in the
    image, on a black background
        Parameters:
            labels (tensor): A 4-D array of predicted label
              with shape (batch x H x W x 7); channel 0 (BG) and the last
              channel (VB) are discarded and rebuilt, the channels in
              between are passed through unchanged
        Returns:
            fixed_labels (array): shape (batch x H x W x 7)
    """
    shape = labels.shape[1:-1]  # (H, W)

    # cv2.circle takes its centre as (x, y) == (column, row), so the width
    # half comes first; passing (row, column) puts the disc off-centre on
    # non-square images.
    center = (shape[1] // 2, shape[0] // 2)
    # Filled (thickness -1) disc of ones, diameter = the shorter image side.
    VB = np.uint8(cv2.circle(np.zeros(shape), center, min(shape) // 2, 1, -1))[..., None]
    # Background is everything outside the disc.
    BG = np.uint8(VB == 0)

    # Remove from VB every pixel already claimed by one of the middle classes.
    # NOTE(review): where a middle class is predicted outside the disc this
    # goes negative (0 - 1) -- confirm that is acceptable downstream.
    VB = VB - np.sum(labels[..., 1:-1], axis=-1)[..., None]
    # BG is (H x W x 1); expand it across the batch to match VB.
    BG = np.broadcast_to(BG, VB.shape)

    fixed_labels = np.concatenate([BG, labels[..., 1:-1], VB], axis=-1)

    return fixed_labels