samtrack / aot /utils /image.py
aikenml's picture
Upload folder using huggingface_hub
c985ba4
raw
history blame
6.08 kB
import numpy as np
from PIL import Image
import torch
import threading
# Flat colour table handed to ``Image.putpalette``: 256 entries laid out as
# [R0, G0, B0, R1, G1, B1, ...], 768 integers in total.
# Entries 0-21 are hand-picked colours (entry 0 is black); entries 22-255
# are grey levels whose three channels all equal the entry index.
_palette = [
    0, 0, 0,        # 0
    128, 0, 0,      # 1
    0, 128, 0,      # 2
    128, 128, 0,    # 3
    0, 0, 128,      # 4
    128, 0, 128,    # 5
    0, 128, 128,    # 6
    128, 128, 128,  # 7
    64, 0, 0,       # 8
    191, 0, 0,      # 9
    64, 128, 0,     # 10
    191, 128, 0,    # 11
    64, 0, 128,     # 12
    191, 0, 128,    # 13
    64, 128, 128,   # 14
    191, 128, 128,  # 15
    0, 64, 0,       # 16
    128, 64, 0,     # 17
    0, 191, 0,      # 18
    128, 191, 0,    # 19
    0, 64, 128,     # 20
    128, 64, 128,   # 21
] + [level for level in range(22, 256) for _ in range(3)]
def label2colormap(label):
    """Turn a 2-D label array into an RGB colour image.

    Each label is cast to ``uint8`` and its bits are spread over the high
    bits of the three colour channels, so small label ids get strongly
    different colours.

    Args:
        label: 2-D NumPy array of label ids (anything ``astype(np.uint8)``
            accepts). A non-2-D input fails when the shape is unpacked.

    Returns:
        ``uint8`` array of shape ``(rows, cols, 3)`` holding R, G, B in the
        last axis.
    """
    ids = label.astype(np.uint8)
    height, width = ids.shape
    colormap = np.zeros((height, width, 3), dtype=np.uint8)

    # Label bits 0/1/2 drive the top bit (128) of R/G/B, bits 3/4/5 drive
    # the next bit (64), and bits 6/7 drive the 32-bit of R/G only.
    red = ((ids & 1) << 7) | ((ids & 8) << 3) | ((ids & 64) >> 1)
    green = ((ids & 2) << 6) | ((ids & 16) << 2) | ((ids & 128) >> 2)
    blue = ((ids & 4) << 5) | ((ids & 32) << 1)

    colormap[..., 0] = red
    colormap[..., 1] = green
    colormap[..., 2] = blue
    return colormap
def one_hot_mask(mask, cls_num):
    """Expand an integer label mask into per-class binary planes.

    Args:
        mask: torch tensor of class ids. A 3-D tensor gets a singleton axis
            inserted at position 1; other ranks are used as given and must
            broadcast against a ``(1, cls_num + 1, 1, 1)`` index tensor.
        cls_num: highest class id; planes are produced for ids
            ``0 .. cls_num`` inclusive.

    Returns:
        Float tensor where plane ``k`` along axis 1 is 1.0 wherever ``mask``
        equals ``k`` and 0.0 elsewhere.
    """
    if mask.dim() == 3:
        mask = mask.unsqueeze(1)
    class_ids = torch.arange(0, cls_num + 1, device=mask.device)
    # Broadcasting the ids along axis 1 compares every pixel to every class.
    return mask.eq(class_ids.view(1, -1, 1, 1)).float()
def masked_image(image, colored_mask, mask, alpha=0.7):
    """Overlay a coloured mask on an image wherever the mask is positive.

    Args:
        image: image array; it is combined with a 3-plane mask, so it is
            presumably channel-first ``(3, H, W)`` — confirm against callers.
        colored_mask: colour rendering of the mask, broadcast-compatible
            with ``image``.
        mask: label array; pixels with ``mask > 0`` count as foreground.
        alpha: weight of ``image`` inside the foreground; the coloured mask
            gets ``1 - alpha``.

    Returns:
        Array equal to the alpha blend inside the foreground and to
        ``image`` everywhere else.
    """
    # Boolean foreground, replicated to three planes along a new leading axis.
    foreground = np.repeat(np.expand_dims(mask > 0, axis=0), 3, axis=0)
    blended = image * alpha + colored_mask * (1 - alpha)
    return blended * foreground + image * (1 - foreground)
def save_image(image, path):
    """Write a channel-first image to ``path`` via PIL.

    Args:
        image: array scaled by 255 and cast to ``uint8`` before saving, so
            it is presumably a float image in ``[0, 1]`` — TODO confirm.
            No clipping is applied before the cast. Axis 0 is moved last,
            i.e. the input is laid out as ``(C, H, W)``.
        path: destination passed straight to ``Image.save``.
    """
    pixels = np.uint8(image * 255.)
    # PIL expects height x width x channels.
    height_width_channels = pixels.transpose((1, 2, 0))
    Image.fromarray(height_width_channels).save(path)
def _save_mask(mask, path, squeeze_idx=None):
    """Save a label mask as a palettised image.

    Args:
        mask: NumPy label array accepted by ``Image.fromarray``.
        path: destination passed to ``Image.save``.
        squeeze_idx: optional indexable lookup. When given, pixels whose
            value is ``i`` (for ``i`` from 1 to ``len(squeeze_idx) - 1``)
            are rewritten to ``squeeze_idx[i]``; position 0 is skipped and
            every pixel that matches no position ends up as 0.
    """
    if squeeze_idx is not None:
        # Start from an all-zero array with the same shape/dtype as the mask.
        restored = mask * 0
        for local_id in range(1, len(squeeze_idx)):
            hits = mask == local_id
            # Each pixel matches at most one local id, so the running sum
            # acts as an assignment of the original object id.
            restored += (hits * squeeze_idx[local_id]).astype(np.uint8)
        mask = restored

    paletted = Image.fromarray(mask).convert('P')
    paletted.putpalette(_palette)
    paletted.save(path)
def save_mask(mask_tensor, path, squeeze_idx=None):
    """Save a mask tensor to disk on a background thread.

    The tensor is copied to the CPU and converted to a ``uint8`` NumPy array
    up front; the actual write is handed to ``_save_mask`` on a new thread
    that is started but not joined, so this call returns without waiting
    for the file to be written.

    Args:
        mask_tensor: torch tensor holding the label mask.
        path: destination forwarded to ``_save_mask``.
        squeeze_idx: optional id lookup forwarded to ``_save_mask``.
    """
    mask_array = mask_tensor.cpu().numpy().astype('uint8')
    writer = threading.Thread(
        target=_save_mask,
        args=(mask_array, path, squeeze_idx),
    )
    writer.start()
def flip_tensor(tensor, dim=0):
    """Return ``tensor`` with the order of entries along ``dim`` reversed.

    Args:
        tensor: torch tensor to flip.
        dim: axis along which to reverse.

    Returns:
        A new tensor produced by ``index_select`` with descending indices.
    """
    length = tensor.size(dim)
    # Indices length-1, length-2, ..., 0 on the same device as the input.
    descending = torch.arange(length - 1, -1, -1, device=tensor.device).long()
    return tensor.index_select(dim, descending)
def shuffle_obj_mask(mask):
    """Randomly permute the non-first object channels of each sample.

    For every sample a permutation matrix is built whose first row is fixed
    (so channel 0 stays in place) and whose remaining rows are shuffled with
    ``torch.randperm``; the channels are then mixed with that matrix.

    Args:
        mask: 4-D torch tensor unpacked as ``(batch, obj_num, H, W)``.

    Returns:
        Tensor of the same shape with channels ``1 .. obj_num - 1`` reordered
        independently per sample.
    """
    batch_size, obj_num, _, _ = mask.size()
    shuffled = []
    if batch_size > 0:
        # The identity is only read (indexing below copies), so one instance
        # can serve every sample.
        identity = torch.eye(obj_num, device=mask.device)
    for sample_idx in range(batch_size):
        # One randperm draw per sample, over the non-first rows only.
        shuffled_rows = identity[1:][torch.randperm(obj_num - 1)]
        permutation = torch.cat([identity[0:1], shuffled_rows], dim=0)
        shuffled.append(
            torch.einsum('nm,nhw->mhw', permutation, mask[sample_idx]))
    return torch.stack(shuffled, dim=0)