Spaces:
Runtime error
Runtime error
""" | |
""" | |
import torch | |
class BoundingBox:
    """A rectangular bounding box that selects a region of a square 2-D grid.

    The box is stored as integer pixel bounds ``left``/``right`` (indexing
    dim 2 of a 4-d tensor) and ``top``/``bottom`` (indexing dim 1). The
    intervals are half-open, so they are used directly as slice bounds:
    ``tensor[:, top:bottom, left:right, :]``.
    """

    def __init__(self, resolution, box_ratios, margin=0.0):
        """
        Args:
            resolution(int): side length of the 2d spatial input; the same
                value scales both the horizontal and the vertical ratios.
            box_ratios(List[float]): ``[left, top, right, bottom]`` given as
                fractions of ``resolution`` (0.0 is the first pixel, 1.0 is
                the exclusive end).
            margin(float): fraction of ``resolution`` added on every side
                (subtracted from left/top, added to right/bottom).
        Raises:
            AssertionError: if top >= bottom or left >= right.
        """
        assert (
            box_ratios[1] < box_ratios[3]
        ), "the boundary top ratio should be less than bottom"
        assert (
            box_ratios[0] < box_ratios[2]
        ), "the boundary left ratio should be less than right"
        self.left, self.right = self._pixel_span(
            box_ratios[0], box_ratios[2], margin, resolution
        )
        self.top, self.bottom = self._pixel_span(
            box_ratios[1], box_ratios[3], margin, resolution
        )
        # Both are >= 1 and always match the span that the slicing methods use.
        self.height = self.bottom - self.top
        self.width = self.right - self.left

    @staticmethod
    def _pixel_span(start_ratio, end_ratio, margin, resolution):
        """Convert a ratio pair into clamped, non-empty integer pixel bounds.

        Bounds are clamped to ``[0, resolution]`` so that a margin pushing the
        box past the grid edge is cut off. An unclamped negative value would be
        read by Python slicing as an offset from the end. A span that
        truncates to zero (or negative) size is widened to exactly one pixel.
        """
        start = int((start_ratio - margin) * resolution)
        end = int((end_ratio + margin) * resolution)
        start = min(max(start, 0), resolution)
        end = min(max(end, 0), resolution)
        if end <= start:
            # Keep the single pixel inside the grid even when start hit the edge.
            start = max(min(start, resolution - 1), 0)
            end = start + 1
        return start, end

    def sliced_tensor_in_bbox(self, tensor: torch.Tensor) -> torch.Tensor:
        """Slice the tensor to the bbox area.

        Args:
            tensor(torch.Tensor): the original tensor in 4d; dims 1 and 2 are
                sliced by the box rows and columns respectively.
        Returns:
            (torch.Tensor): the reduced tensor inside bbox (a view, not a copy)
        """
        return tensor[:, self.top : self.bottom, self.left : self.right, :]

    def mask_reweight_out_bbox(
        self, tensor: torch.Tensor, value: float = 0.0
    ) -> torch.Tensor:
        """Scale every element outside the bbox by ``value``.

        Args:
            tensor(torch.Tensor): the original tensor in 4d
            value(float): reweighting factor, default 0.0
        Returns:
            (torch.Tensor): a new reweighted tensor; elements inside the bbox
                are multiplied by 1
        """
        # ones_like already matches the device; multiplying by a Python float
        # keeps the original dtype promotion for integer inputs.
        mask = torch.ones_like(tensor) * value
        mask[:, self.top : self.bottom, self.left : self.right, :] = 1
        return tensor * mask

    def mask_reweight_in_bbox(
        self, tensor: torch.Tensor, value: float = 0.0
    ) -> torch.Tensor:
        """Scale every element inside the bbox by ``value``.

        Args:
            tensor(torch.Tensor): the original tensor in 4d
            value(float): reweighting factor, default 0.0
        Returns:
            (torch.Tensor): a new reweighted tensor; elements outside the bbox
                are multiplied by 1
        """
        mask = torch.ones_like(tensor)
        mask[:, self.top : self.bottom, self.left : self.right, :] = value
        return tensor * mask

    def __str__(self):
        """Return ``Box(L:%d, R:%d, T:%d, B:%d)`` for readable logging."""
        return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"

    def __repr__(self):
        """Return the same compact form as ``__str__``."""
        return f"Box(L:{self.left}, R:{self.right}, T:{self.top}, B:{self.bottom})"

    # Backward-compatible alias for the original misspelled method name.
    __rerp__ = __repr__
if __name__ == "__main__":
    # Demo: a box covering the top-left quarter (the "second quadrant")
    # of a 32x32 grid, given as [left, top, right, bottom] ratios.
    demo_ratios = [0.0, 0.0, 0.5, 0.5]
    demo_box = BoundingBox(resolution=32, box_ratios=demo_ratios)
    print(demo_box)
    # Expected output: Box(L:0, R:16, T:0, B:16)