# Scraped source-viewer header: author liangfeng, commit b92a792 ("clean up"),
# viewer links raw | history | blame, file size 2.26 kB.
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
from typing import Tuple
import numpy as np
import torch
from .clip import load as clip_load
from detectron2.utils.comm import get_local_rank, synchronize
def expand_box(
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    expand_ratio: float = 1.0,
    max_h: int = None,
    max_w: int = None,
):
    """Scale the box ``(x1, y1, x2, y2)`` about its centre by ``expand_ratio``.

    When ``max_h`` is given, the vertical edges are clamped to ``[0, max_h - 1]``.
    When ``max_w`` is given, the horizontal edges are clamped to ``[0, max_w - 1]``.
    Each coordinate is finally converted with ``int()``, which truncates toward zero.

    Returns:
        A list ``[left, top, right, bottom]`` of Python ints.
    """
    center_x = 0.5 * (x1 + x2)
    center_y = 0.5 * (y1 + y2)
    half_w = 0.5 * ((x2 - x1) * expand_ratio)
    half_h = 0.5 * ((y2 - y1) * expand_ratio)

    left, right = center_x - half_w, center_x + half_w
    top, bottom = center_y - half_h, center_y + half_h

    if max_h is not None:
        top = max(0, top)
        bottom = min(max_h - 1, bottom)
    if max_w is not None:
        left = max(0, left)
        right = min(max_w - 1, right)

    return list(map(int, (left, top, right, bottom)))
def mask2box(mask: torch.Tensor):
    """Return the tight bounding box of the non-zero region of a 2-D mask.

    The mask is indexed as ``mask[y, x]``. Summing over dim 0 gives one value
    per column (x); summing over dim 1 gives one value per row (y).

    Args:
        mask: 2-D tensor; a column/row counts as occupied when its sum is non-zero.

    Returns:
        ``(x1, y1, x2, y2)`` as 0-d integer tensors. ``x2``/``y2`` are exclusive
        (one past the last occupied column/row). Returns ``None`` when no column
        is occupied.
    """
    # Naive projection onto each axis. Use torch.nonzero for both axes; the
    # previous np.nonzero call only worked because NumPy falls back to Tensor.nonzero().
    occupied_cols = torch.nonzero(mask.sum(dim=0))[:, 0]
    if len(occupied_cols) == 0:
        return None
    occupied_rows = torch.nonzero(mask.sum(dim=1))[:, 0]
    x1, x2 = occupied_cols.min(), occupied_cols.max()
    y1, y2 = occupied_rows.min(), occupied_rows.max()
    return x1, y1, x2 + 1, y2 + 1
def crop_with_mask(
    image: torch.Tensor,
    mask: torch.Tensor,
    bbox: torch.Tensor,
    fill: Tuple[float, float, float] = (0, 0, 0),
    expand_ratio: float = 1.0,
):
    """Crop ``image`` around ``bbox`` and paint pixels outside ``mask`` with ``fill``.

    The crop window is ``bbox`` scaled about its centre by ``expand_ratio``
    (via ``expand_box``) and then clipped to the image bounds. Inside the
    window the result is ``image * mask + (1 - mask) * fill``, so fractional
    mask values blend image and fill.

    Args:
        image: 3-D tensor unpacked as ``(C, H, W)``.
        mask: 2-D tensor aligned with the image's ``(H, W)``. Boolean masks are
            accepted and treated as 0/1 in ``image``'s dtype for the blend.
        bbox: four values ``(x1, y1, x2, y2)``, e.g. the output of ``mask2box``.
        fill: per-channel fill value. It must broadcast against the channel
            dimension, i.e. one value per channel or a single value.
        expand_ratio: scale applied to the box width/height before cropping.

    Returns:
        ``(region, region_mask)``. ``region`` is the blended crop, and
        ``region_mask`` is ``mask`` cropped to the same window with a leading
        singleton dim and its original dtype.
    """
    l, t, r, b = expand_box(*bbox, expand_ratio)
    _, h, w = image.shape
    # expand_box is called without max_h/max_w, so clip to the image here.
    l = max(l, 0)
    t = max(t, 0)
    r = min(r, w)
    b = min(b, h)

    region = image[:, t:b, l:r]
    region_mask = mask[None, t:b, l:r]
    # `1 - mask` is not defined for bool tensors; blend with a numeric copy instead.
    blend = region_mask.to(image.dtype) if region_mask.dtype == torch.bool else region_mask
    # Shape (len(fill), 1, 1) broadcasts over the crop exactly like a full fill image
    # would. It avoids allocating one and tolerates an empty window.
    fill_values = image.new_tensor(fill).view(-1, 1, 1)
    return region * blend + (1 - blend) * fill_values, region_mask
def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
    """Load a CLIP model on CPU, staggering the load across local ranks.

    Local rank 0 calls ``clip_load`` first while the other ranks wait at a
    ``synchronize()`` barrier. The remaining ranks load after that barrier, and
    a second barrier lines every rank up again. The intent, per the original
    comment, is that only rank 0 downloads the weights.

    Args:
        model: model name/path forwarded to ``clip_load``.
        mask_prompt_depth: forwarded to ``clip_load``.
        frozen: when True, set ``requires_grad = False`` on every parameter.

    Returns:
        The first element of the pair returned by ``clip_load``.
    """

    def _load_clip():
        loaded, _preprocess = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
        return loaded

    on_first_rank = get_local_rank() == 0

    # Rank 0 loads first, and presumably populates the download cache.
    if on_first_rank:
        clip_model = _load_clip()
    synchronize()
    # All other ranks load only after rank 0 has passed the barrier.
    if not on_first_rank:
        clip_model = _load_clip()
    synchronize()

    if frozen:
        for weight in clip_model.parameters():
            weight.requires_grad = False
    return clip_model