|
|
|
""" |
|
@File : visualizer.py |
|
@Time : 2022/04/05 11:39:33 |
|
@Author : Shilong Liu |
|
@Contact : [email protected] |
|
""" |
|
|
|
import datetime |
|
import os |
|
|
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from matplotlib import transforms |
|
from matplotlib.collections import PatchCollection |
|
from matplotlib.patches import Polygon |
|
from pycocotools import mask as maskUtils |
|
|
|
|
|
def renorm(
    img: torch.FloatTensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
) -> torch.FloatTensor:
    """Undo a per-channel normalization by computing ``img * std + mean``.

    Args:
        img: tensor of shape (3, H, W) or (B, 3, H, W).
        mean: per-channel mean that was subtracted when normalizing.
        std: per-channel standard deviation that was divided by when normalizing.

    Returns:
        A new tensor with the same shape as ``img``; ``img`` is not modified.

    Raises:
        AssertionError: if ``img`` is not 3-D or 4-D, or if its channel axis
            does not have size 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    # The channel axis is 0 for (3, H, W) and 1 for (B, 3, H, W).
    channel_axis = img.dim() - 3
    assert img.size(channel_axis) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_axis,
        img.size(channel_axis),
        str(img.size()),
    )

    # Build the statistics on the same device as the image so that CUDA inputs
    # do not fail with a device mismatch. The default dtype is kept (as
    # ``torch.Tensor(...)`` did) so type promotion of the result is unchanged.
    stats_dtype = torch.get_default_dtype()
    mean_t = torch.as_tensor(mean, dtype=stats_dtype, device=img.device).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=stats_dtype, device=img.device).reshape(-1, 1, 1)

    # (C, 1, 1) broadcasts against both (C, H, W) and (B, C, H, W), so no
    # permutation of the image is required.
    return img * std_t + mean_t
|
|
|
|
|
class ColorMap:
    """Turn a 2-D uint8 map into an overlay whose last channel is the map itself."""

    def __init__(self, basergb=[255, 255, 0]):
        # Kept as an ndarray so it can be broadcast over the map in __call__.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Return ``basergb`` repeated at every pixel with ``attnmap`` appended.

        attnmap: array of shape (h, w) with dtype np.uint8.
        return: np.uint8 array of shape (h, w, len(basergb) + 1); the leading
            channels hold ``basergb`` and the last channel holds ``attnmap``.
        """
        assert attnmap.dtype == np.uint8
        rows, cols = attnmap.shape
        # Same colour at every pixel: (rows, cols, 3) for an RGB base colour.
        colour_plane = np.broadcast_to(self.basergb, (rows, cols) + self.basergb.shape)
        # The map becomes the trailing channel: (rows, cols, 1).
        alpha_plane = attnmap[..., np.newaxis]
        stacked = np.concatenate((colour_plane, alpha_plane), axis=-1)
        return stacked.astype(np.uint8)
|
|
|
|
|
def rainbow_text(x, y, ls, lc, **kw):
    """
    Write the strings in ``ls`` side by side on the current axes, drawing
    ``ls[i]`` in the color ``lc[i]``.

    Placement is horizontal only: every string is anchored at (x, y) and is
    shifted to the right by the accumulated rendered width of the strings
    before it. All extra keyword arguments are passed on to plt.text, so you
    can set the font size, family, etc.
    """
    placement = plt.gca().transData
    figure = plt.gcf()
    plt.show()

    for piece, piece_color in zip(ls, lc):
        label = plt.text(x, y, " " + piece + " ", color=piece_color, transform=placement, **kw)
        # The text has to be drawn once before its window extent is queried.
        label.draw(figure.canvas.get_renderer())
        extent = label.get_window_extent()
        # The next string starts where this one ends: offset this text's
        # transform by its rendered width.
        placement = transforms.offset_copy(label._transform, x=extent.width, units="dots")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class COCOVisualizer:
    """Draw targets (boxes, labels, caption, attention maps) and COCO-style
    annotations with matplotlib, and save the rendered image to disk."""

    def __init__(self, coco=None, tokenlizer=None) -> None:
        """
        coco: optional COCO API object; ``showAnns`` reads ``imgs`` and
            ``loadCats`` from it.
        tokenlizer: accepted for backward compatibility; it is not used.
        """
        self.coco = coco

    @staticmethod
    def _bbox_polygon(bbox_x, bbox_y, bbox_w, bbox_h):
        """Return the rectangle with top-left (bbox_x, bbox_y) and the given size as a Polygon."""
        corners = [
            [bbox_x, bbox_y],
            [bbox_x, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y + bbox_h],
            [bbox_x + bbox_w, bbox_y],
        ]
        return Polygon(np.array(corners).reshape((4, 2)))

    @staticmethod
    def _draw_box_label(ax, x, y, text, facecolor):
        """Write ``text`` in black at (x, y) on a semi-transparent box of ``facecolor``."""
        ax.text(
            x,
            y,
            text,
            color="black",
            bbox={"facecolor": facecolor, "alpha": 0.6, "pad": 1},
        )

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """
        Render ``img`` with the content of ``tgt`` drawn on top and save it as
        a PNG under ``savedir``.

        img: tensor(3, H, W), normalized; it is de-normalized with ``renorm``.
        tgt: dict or None. make sure they are all on cpu.
            'boxes' and 'size' are needed to draw boxes; 'image_id' is used in
            the file name (0 when missing or when tgt is None).
        caption: optional string that is prepended to the file name.
        dpi: resolution of the created figure.
        savedir: output directory; it is created when it does not exist.
        """
        fig = plt.figure(dpi=dpi)
        try:
            plt.rcParams["font.size"] = "5"
            ax = plt.gca()
            img = renorm(img).permute(1, 2, 0)
            ax.imshow(img)

            self.addtgt(tgt)

            if tgt is None or "image_id" not in tgt:
                image_id = 0
            else:
                image_id = tgt["image_id"]

            timestamp = str(datetime.datetime.now()).replace(" ", "-")
            if caption is None:
                savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
            else:
                savename = "{}/{}-{}-{}.png".format(savedir, caption, int(image_id), timestamp)
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)
        finally:
            # Always release the figure, even when drawing or saving failed.
            plt.close(fig)

    def addtgt(self, tgt):
        """
        Draw the content of ``tgt`` on the current matplotlib axes.

        tgt: dict or None. Keys that are read:
            'boxes': tensor (N, 4); each row is scaled by (W, H, W, H) and
                treated as (center_x, center_y, width, height).
            'size': (H, W), required together with 'boxes'.
            'strings_positive' + 'labels': optional per-box word lists and
                category ids, drawn as "<id>:<words>" at each box corner.
            'box_label': optional per-box labels, drawn at each box corner.
            'caption': optional axes title.
            'attn': optional (attn_map, basergb) tuple or a list of such
                tuples, drawn as overlays.
        The dict itself is left unmodified.
        """
        ax = plt.gca()

        if tgt is None or "boxes" not in tgt:
            # No boxes to draw; only the caption (if any) can be shown.
            # ``tgt`` may be None here, so guard the membership test.
            if tgt is not None and "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)
            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]

        # One random, fairly bright colour per box.
        color = []
        polygons = []
        boxes = []
        scale = torch.Tensor([W, H, W, H])
        for box in tgt["boxes"].cpu():
            unnormbbox = box * scale
            # Move from the box centre to its top-left corner.
            unnormbbox[:2] -= unnormbbox[2:] / 2
            [bbox_x, bbox_y, bbox_w, bbox_h] = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            polygons.append(self._bbox_polygon(bbox_x, bbox_y, bbox_w, bbox_h))
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            color.append(c)

        # A faint fill first, then the outlines on top.
        p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
        ax.add_collection(p)
        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            assert (
                len(tgt["strings_positive"]) == numbox
            ), f"{len(tgt['strings_positive'])} = {numbox}, "
            for idx, strlist in enumerate(tgt["strings_positive"]):
                cate_id = int(tgt["labels"][idx])
                _string = str(cate_id) + ":" + " ".join(strlist)
                bbox_x, bbox_y = boxes[idx][:2]
                self._draw_box_label(ax, bbox_x, bbox_y, _string, color[idx])

        if "box_label" in tgt:
            assert len(tgt["box_label"]) == numbox, f"{len(tgt['box_label'])} = {numbox}, "
            for idx, bl in enumerate(tgt["box_label"]):
                bbox_x, bbox_y = boxes[idx][:2]
                self._draw_box_label(ax, bbox_x, bbox_y, str(bl), color[idx])

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            # Accept a single (attn_map, basergb) tuple as well as a list of
            # them, without writing the wrapped list back into the caller's dict.
            attn_items = tgt["attn"]
            if isinstance(attn_items, tuple):
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Rescale to [0, 1); the epsilon avoids dividing by zero on a flat map.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                cm = ColorMap(basergb)
                heatmap = cm(attn_map)
                ax.imshow(heatmap)
        ax.set_axis_off()

    def showAnns(self, anns, draw_bbox=False):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :param draw_bbox (bool): also outline each annotation's 'bbox'
        :return: 0 when ``anns`` is empty, otherwise None
        :raises Exception: when the first annotation has none of the keys
            'segmentation', 'keypoints' or 'caption'
        """
        if len(anns) == 0:
            return 0
        if "segmentation" in anns[0] or "keypoints" in anns[0]:
            datasetType = "instances"
        elif "caption" in anns[0]:
            datasetType = "captions"
        else:
            raise Exception("datasetType not supported")

        if datasetType == "captions":
            for ann in anns:
                print(ann["caption"])
            return

        # Image metadata and categories live on the COCO object given to
        # __init__; fall back to ``self`` for subclasses that provide
        # ``imgs`` / ``loadCats`` themselves.
        coco = self.coco if self.coco is not None else self

        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            if "segmentation" in ann:
                if isinstance(ann["segmentation"], list):
                    # polygon: flat [x0, y0, x1, y1, ...] lists
                    for seg in ann["segmentation"]:
                        poly = np.array(seg).reshape((len(seg) // 2, 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # mask (RLE)
                    t = coco.imgs[ann["image_id"]]
                    if isinstance(ann["segmentation"]["counts"], list):
                        rle = maskUtils.frPyObjects(
                            [ann["segmentation"]], t["height"], t["width"]
                        )
                    else:
                        rle = [ann["segmentation"]]
                    m = maskUtils.decode(rle)
                    img = np.ones((m.shape[0], m.shape[1], 3))
                    if ann["iscrowd"] == 1:
                        color_mask = np.array([2.0, 166.0, 101.0]) / 255
                    if ann["iscrowd"] == 0:
                        color_mask = np.random.random((1, 3)).tolist()[0]
                    for i in range(3):
                        img[:, :, i] = color_mask[i]
                    # The decoded mask at half strength is the alpha channel.
                    ax.imshow(np.dstack((img, m * 0.5)))
            if "keypoints" in ann and isinstance(ann["keypoints"], list):
                # The skeleton indices are shifted down by one before they are
                # used to index the keypoint arrays.
                sks = np.array(coco.loadCats(ann["category_id"])[0]["skeleton"]) - 1
                kp = np.array(ann["keypoints"])
                # Keypoints are stored as flat (x, y, v) triplets.
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                plt.plot(
                    x[v > 0],
                    y[v > 0],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor="k",
                    markeredgewidth=2,
                )
                plt.plot(
                    x[v > 1],
                    y[v > 1],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor=c,
                    markeredgewidth=2,
                )

            if draw_bbox:
                [bbox_x, bbox_y, bbox_w, bbox_h] = ann["bbox"]
                polygons.append(self._bbox_polygon(bbox_x, bbox_y, bbox_w, bbox_h))
                color.append(c)

        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)
|
|