|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
import threading |
|
|
|
_palette = [ |
|
0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, |
|
128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, |
|
128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, |
|
191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, |
|
24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, |
|
30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, |
|
37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, |
|
43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, |
|
49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, |
|
56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, |
|
62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, |
|
68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, |
|
75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, |
|
81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, |
|
87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, |
|
94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, |
|
100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, |
|
105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, |
|
110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, |
|
115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, |
|
120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, |
|
125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, |
|
130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, |
|
135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, |
|
140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, |
|
145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, |
|
150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, |
|
155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, |
|
160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, |
|
165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, |
|
170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, |
|
175, 175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, |
|
180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, |
|
185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, |
|
190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, |
|
195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, |
|
200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, |
|
205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, |
|
210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, |
|
215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, |
|
220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, |
|
225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, |
|
230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, |
|
235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, |
|
240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, |
|
245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, |
|
250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, |
|
255, 255, 255 |
|
] |
|
|
|
|
|
def label2colormap(label):
    """Convert an integer label array into an RGB colour map.

    Each label value is cast to uint8, and its bits are spread across the
    three channels:

    * R takes bits 0, 3, 6 (placed at bit positions 7, 6, 5)
    * G takes bits 1, 4, 7 (placed at bit positions 7, 6, 5)
    * B takes bits 2, 5    (placed at bit positions 7, 6)

    Args:
        label: array-like of labels of any shape. Previously only 2-D
            ``np.ndarray`` inputs were accepted; 2-D results are unchanged.
            Because of the uint8 cast, only the low 8 bits of integer labels
            affect the colour.

    Returns:
        ``np.uint8`` array of shape ``label.shape + (3,)``.
    """
    m = np.asarray(label).astype(np.uint8)
    cmap = np.zeros(m.shape + (3,), dtype=np.uint8)
    # `...` indexing lets the same code handle 1-D, 2-D, 3-D, ... labels.
    cmap[..., 0] = (m & 1) << 7 | (m & 8) << 3 | (m & 64) >> 1
    cmap[..., 1] = (m & 2) << 6 | (m & 16) << 2 | (m & 128) >> 2
    cmap[..., 2] = (m & 4) << 5 | (m & 32) << 1
    return cmap
|
|
|
|
|
def one_hot_mask(mask, cls_num):
    """One-hot encode a label tensor along dimension 1.

    A 3-D ``mask`` first gets a singleton axis inserted at dim 1 (presumably
    ``(B, H, W)`` -> ``(B, 1, H, W)``; confirm against callers). Every value is
    then compared with the class ids ``0 .. cls_num`` (inclusive), so dim 1 of
    the result has ``cls_num + 1`` channels. Values outside that range yield
    all-zero entries rather than an error.

    Returns:
        float tensor with the broadcast shape of ``mask`` and
        ``(1, cls_num + 1, 1, 1)``.
    """
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)
    class_ids = torch.arange(cls_num + 1, device=mask.device)[None, :, None, None]
    hits = mask == class_ids
    return hits.float()
|
|
|
|
|
def masked_image(image, colored_mask, mask, alpha=0.7):
    """Overlay ``colored_mask`` onto ``image`` where ``mask`` is positive.

    Inside the mask a pixel becomes ``image * alpha + colored_mask * (1 - alpha)``.
    Outside it the original ``image`` value is kept. Note that ``alpha``
    weights the *image*, not the overlay.

    ``mask`` is thresholded (``> 0``) and stacked into 3 leading channels, so
    ``image`` / ``colored_mask`` are presumably channel-first ``(3, H, W)``
    arrays (TODO confirm against callers).
    """
    fg = np.stack([mask > 0] * 3, axis=0)
    blended = image * alpha + colored_mask * (1 - alpha)
    return blended * fg + image * (1 - fg)
|
|
|
|
|
def save_image(image, path):
    """Write a channel-first image to ``path`` using PIL.

    ``image`` is indexed as ``(C, H, W)`` (it is transposed to ``(H, W, C)``
    for ``Image.fromarray``). Its values are expected in ``[0, 1]``; they are
    scaled by 255, clipped to ``[0, 255]`` and truncated to uint8.
    """
    scaled = np.asarray(image) * 255.
    # Clip before casting: converting out-of-range floats to uint8 wraps
    # around or is platform-dependent, so a slight overshoot (e.g. 1.01) could
    # otherwise come out as a near-black pixel. In-range values are unchanged.
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    Image.fromarray(pixels.transpose((1, 2, 0))).save(path)
|
|
|
|
|
def _save_mask(mask, path, squeeze_idx=None):
    """Save a label array as a "P"-mode image coloured with ``_palette``.

    If ``squeeze_idx`` is given, each label ``k`` with ``1 <= k < len(squeeze_idx)``
    is first rewritten to ``squeeze_idx[k]``. Label 0 and any label outside
    that range become 0, and ``squeeze_idx[0]`` is never used. Rewritten ids
    are cast to uint8, so ids above 255 wrap.
    """
    labels = mask
    if squeeze_idx is not None:
        labels = mask * 0
        for local_id in range(1, len(squeeze_idx)):
            # Each local_id selects a disjoint region, so accumulating is exact.
            labels += ((mask == local_id) * squeeze_idx[local_id]).astype(np.uint8)
    palette_img = Image.fromarray(labels).convert('P')
    palette_img.putpalette(_palette)
    palette_img.save(path)
|
|
|
|
|
def save_mask(mask_tensor, path, squeeze_idx=None):
    """Save ``mask_tensor`` as a palettised image at ``path`` in the background.

    On the calling thread, the tensor is detached, moved to CPU and cast to a
    fresh uint8 numpy array (``astype`` copies). The caller may therefore
    reuse or modify ``mask_tensor`` immediately. ``squeeze_idx`` is passed by
    reference, so it should not be mutated until the save finishes. The
    remapping and file writing are done by ``_save_mask`` on a new thread.

    Returns:
        The started ``threading.Thread``. ``join()`` it when the file must
        exist before continuing. Exceptions raised while saving surface
        through ``threading.excepthook``; they are not re-raised here.
    """
    # detach() so masks still attached to an autograd graph don't make
    # .numpy() raise; it is a no-op for tensors without grad.
    mask = mask_tensor.detach().cpu().numpy().astype('uint8')
    worker = threading.Thread(target=_save_mask, args=[mask, path, squeeze_idx])
    worker.start()
    return worker
|
|
|
|
|
def flip_tensor(tensor, dim=0):
    """Return a copy of ``tensor`` with its entries reversed along ``dim``.

    Uses ``index_select`` with a descending int64 index built on the same
    device as ``tensor``. Negative ``dim`` values are accepted, as in
    ``Tensor.size`` / ``index_select``.
    """
    length = tensor.size(dim)
    descending = torch.arange(
        length - 1, -1, -1, dtype=torch.long, device=tensor.device
    )
    return tensor.index_select(dim, descending)
|
|
|
|
|
def shuffle_obj_mask(mask):
    """Randomly permute the object channels 1..obj_num-1 of each sample.

    ``mask`` is indexed as ``(batch, obj_num, H, W)``. Channel 0 stays in
    place (presumably background; confirm with callers). The remaining
    channels are shuffled with one ``torch.randperm(obj_num - 1)`` draw per
    sample from the default generator.

    The permutation is applied by exact indexing rather than multiplying by a
    permutation matrix. Masks of any dtype (float64, half, int, bool) are
    therefore supported, values are copied bit-exactly, and a NaN/inf in one
    channel cannot leak into the others through ``0 * inf``.

    Returns:
        tensor with the same shape, dtype and device as ``mask``.
    """
    bs, obj_num, _, _ = mask.size()
    new_masks = []
    for idx in range(bs):
        # dest[k] is the output channel that input channel k is moved to.
        dest = torch.cat([
            torch.zeros(1, dtype=torch.long),
            torch.randperm(obj_num - 1) + 1,
        ]).to(mask.device)
        # out[dest[k]] = in[k]  <=>  out = in[argsort(dest)]  (dest is a permutation)
        src = torch.argsort(dest)
        new_masks.append(mask[idx].index_select(0, src))
    return torch.stack(new_masks, dim=0)
|
|