# ObjCtrl-2.5D / objctrl_2_5d / utils / objmask_util.py
# wzhouxiff
# init
# 38e3f9b
import os
from scipy.interpolate import griddata as interp_grid
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
import torch
from packaging import version as pver
import torch.nn.functional as F
def trajectory_to_camera_poses_v1(traj, intrinsics, sample_n_frames, zc=1.0):
    '''
    Turn a 2D pixel trajectory into translation-only world-to-camera poses.

    Every trajectory point is back-projected to camera space at depth ``zc``
    and the pose of frame ``i`` is an identity rotation whose translation is
    the offset of point ``i`` from point 0 (so frame 0 is the identity).

    traj: [n_point, 2] numpy array, pixel (x, y) per frame
    intrinsics: 2D numpy array; only row 0 is read, as [fx, fy, cx, cy]
    sample_n_frames: number of poses to produce (frame 0 is always produced)
    zc: depth per trajectory point. A scalar is repeated for every point; a
        list / tuple / numpy array supplies one depth per point. The depths
        are cast to ``traj.dtype``, so an integer ``traj`` truncates them.

    return: [n_frame, 4, 4] float32 array of w2c matrices

    raises TypeError if ``zc`` is not numeric, IndexError if the trajectory
    has fewer points than ``sample_n_frames``.
    '''
    zc_arr = np.asarray(zc)
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if zc_arr.dtype.kind not in 'biuf':
        raise TypeError('zc should be a float or int or a list of float or int')
    if zc_arr.ndim == 0:
        zc = np.full(traj.shape[0], zc_arr, dtype=traj.dtype)
    else:
        zc = zc_arr.astype(traj.dtype)

    fx, fy = intrinsics[0, 0], intrinsics[0, 1]
    cx, cy = intrinsics[0, 2], intrinsics[0, 3]

    # Pinhole back-projection of each pixel at its depth.
    xc = (traj[:, 0] - cx) * zc / fx
    yc = (traj[:, 1] - cy) * zc / fy

    # Translations are expressed relative to the first trajectory point.
    Tx = xc - xc[0]
    Ty = yc - yc[0]
    Tz = zc - zc[0]

    n_frames = max(sample_n_frames, 1)
    if n_frames > traj.shape[0]:
        raise IndexError(
            f'trajectory has {traj.shape[0]} points but {n_frames} frames were requested')

    traj_w2c = np.tile(np.eye(4, dtype=np.float32), (n_frames, 1, 1))
    traj_w2c[1:, 0, 3] = Tx[1:n_frames]
    traj_w2c[1:, 1, 3] = Ty[1:n_frames]
    traj_w2c[1:, 2, 3] = Tz[1:n_frames]

    return traj_w2c  # [n_frame, 4, 4]
def Unprojected(image_curr: np.array,
                depth_curr: np.array,
                RTs: np.array,
                H: int = 320, W: int = 576,
                K: np.array = None,
                dtype: np.dtype = np.float32):
    '''
    Warp one RGB-D frame into the views of the following camera poses.

    The pixels of ``image_curr`` are lifted to 3D with ``depth_curr`` and the
    first pose, re-projected into every later pose, and the scattered colors
    are resampled onto the regular pixel grid by linear interpolation
    (pixels that receive no data are filled with 0).

    image_curr: [H, W, c], float, 0-1
    depth_curr: [H, W], float32, in meters
    RTs: [num_frames, 3, 4], the camera poses, w2c. Only the top three rows
         of each pose are used, so [num_frames, 4, 4] is accepted as well.
    H, W: image height and width in pixels
    K: 3x3 camera intrinsic matrix (required despite the None default)
    dtype: dtype used for the pixel grid, poses, depth and colors

    return: list of num_frames - 1 arrays of shape [H, W, c], one per pose
            after the first (empty when only one pose is given)

    raises ValueError if ``K`` is None.
    '''
    if K is None:
        raise ValueError('K (3x3 camera intrinsic matrix) must be provided')

    # Regular pixel grid; reused for the unprojection and as resampling target.
    x, y = np.meshgrid(np.arange(W, dtype=dtype), np.arange(H, dtype=dtype), indexing='xy')  # pixels
    grid = np.stack((x, y), axis=-1).reshape(-1, 2)

    RTs = RTs[:, :3, :].astype(dtype)
    depth_curr = depth_curr.astype(dtype)
    pts_colors = image_curr.reshape(H*W, -1).astype(dtype)  # [H*W, c], values in [0, 1]

    R0, T0 = RTs[0, :, :3], RTs[0, :, 3:4]
    R0_inv = np.linalg.inv(R0)

    # Pixel -> camera: scale homogeneous pixels by depth, then undo the intrinsics.
    pts_coord_cam = np.matmul(np.linalg.inv(K), np.stack((x*depth_curr, y*depth_curr, 1*depth_curr), axis=0).reshape(3, -1))
    # Camera -> world by inverting the first w2c pose.
    pts_coord_world = R0_inv.dot(pts_coord_cam) - R0_inv.dot(T0)

    images = []
    for i in tqdm(range(1, RTs.shape[0])):
        R, T = RTs[i, :, :3], RTs[i, :, 3:4]

        # World -> camera i -> homogeneous pixel coordinates.
        pts_coord_cam2 = R.dot(pts_coord_world) + T
        pixel_coord_cam2 = np.matmul(K, pts_coord_cam2)
        u = pixel_coord_cam2[0] / pixel_coord_cam2[2]
        v = pixel_coord_cam2[1] / pixel_coord_cam2[2]

        # Keep points in front of the camera that land inside the image.
        valid_idx = np.where(np.logical_and.reduce((pixel_coord_cam2[2] > 0,
                                                    u >= 0,
                                                    u <= W-1,
                                                    v >= 0,
                                                    v <= H-1)))[0]
        pixel_xy = np.stack((u[valid_idx], v[valid_idx]), axis=-1)  # [n_valid, 2]

        image2 = interp_grid(pixel_xy, pts_colors[valid_idx], grid, method='linear', fill_value=0).reshape(H, W, -1)
        images.append(image2)

    if images:
        print(f'Total {len(images)} images, each image shape: {images[0].shape}')
    else:
        # A single pose yields no warped frame; images[0] would not exist.
        print('Total 0 images')
    return images
class Camera:
    '''
    Pinhole camera parsed from a flat parameter entry.

    entry[1:5] holds fx, fy, cx, cy and entry[7:] holds the 12 values of a
    3x4 world-to-camera matrix (row-major). The matrix is stored padded to
    4x4 as ``w2c_mat`` together with its inverse ``c2w_mat``.
    '''

    def __init__(self, entry):
        self.fx, self.fy, self.cx, self.cy = entry[1:5]

        # Pad [R|t] with the homogeneous row so that it can be inverted.
        extrinsic = np.eye(4)
        extrinsic[:3, :] = np.array(entry[7:]).reshape(3, 4)

        self.w2c_mat = extrinsic
        self.c2w_mat = np.linalg.inv(extrinsic)
def get_relative_pose(cam_params, zero_t_first_frame):
    '''
    Re-express camera-to-world poses relative to the first camera.

    cam_params: sequence of objects exposing 4x4 ``w2c_mat`` and ``c2w_mat``
    zero_t_first_frame: if truthy the first pose becomes the identity;
        otherwise it is an identity rotation translated along -y by the
        distance of the first camera from the origin.

    return: [n_frame, 4, 4] float32 array of relative c2w poses
    '''
    first_w2c = cam_params[0].w2c_mat
    first_c2w = cam_params[0].c2w_mat

    offset = 0 if zero_t_first_frame else np.linalg.norm(first_c2w[:3, 3])

    # Pose assigned to the first frame in the new reference system.
    anchor_c2w = np.eye(4)
    anchor_c2w[1, 3] = -offset

    # Maps absolute poses into the frame anchored at the first camera.
    to_relative = anchor_c2w @ first_w2c

    poses = [anchor_c2w]
    poses.extend(to_relative @ cam.c2w_mat for cam in cam_params[1:])
    return np.array(poses, dtype=np.float32)
def custom_meshgrid(*args):
    '''
    Call torch.meshgrid with matrix ('ij') indexing on any torch version.

    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    '''
    # The `indexing` keyword is only passed from torch 1.10 on; older
    # releases are called without it.
    is_legacy_torch = pver.parse(torch.__version__) < pver.parse('1.10')
    extra_kwargs = {} if is_legacy_torch else {'indexing': 'ij'}
    return torch.meshgrid(*args, **extra_kwargs)
def ray_condition(K, c2w, H, W, device, flip_flag=None):
    '''
    Build per-pixel Plücker ray embeddings for a batch of camera views.

    K: [B, V, 4] intrinsics as (fx, fy, cx, cy)
    c2w: [B, V, 4, 4] camera-to-world matrices
    H, W: image height and width in pixels
    device: device on which the pixel grid is created
    flip_flag: optional [V] bool tensor; views marked True use a
        horizontally mirrored pixel grid

    return: (plucker, rays_o, rays_d)
        plucker: [B, V, H, W, 6], (o x d, d) for every pixel
        rays_o:  [B, V, H*W, 3] ray origins (camera centres)
        rays_d:  [B, V, H*W, 3] unit ray directions in world space
    '''
    B, V = K.shape[:2]

    # Pixel grid in 'ij' order (row index varies slowest), built by
    # broadcasting so it does not depend on version-specific meshgrid defaults.
    rows = torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype)
    cols = torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype)
    # +0.5 samples the centre of each pixel.
    i = cols[None, :].expand(H, W).reshape(1, 1, H * W).expand(B, V, H * W) + 0.5  # [B, V, HxW]
    j = rows[:, None].expand(H, W).reshape(1, 1, H * W).expand(B, V, H * W) + 0.5  # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Mirrored views only change the column coordinate; rows are identical.
        cols_flip = torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        i_flip = cols_flip[None, :].expand(H, W).reshape(1, 1, H * W).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Back-project pixels onto the z = 1 plane of the camera frame.
    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    # Rotate directions into the world frame; origins are the camera centres.
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3

    # dim must be explicit: without it torch.cross uses the FIRST size-3
    # dimension, which is wrong whenever B, V or H*W equals 3.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker, rays_o, rays_d
def RT2Plucker(RT, num_frames, sample_size, fx, fy, cx, cy):
    '''
    Convert a sequence of camera poses into a Plücker embedding.

    RT: [num_frames, 3, 4]
    num_frames: number of poses read from RT
    sample_size: (H, W) of the embedding
    fx, fy, cx, cy: intrinsics shared by all frames; they are multiplied by
        W (fx, cx) and H (fy, cy), so presumably normalized - TODO confirm

    return: (plucker_embedding [6, V, H, W], rays_o, rays_d), the last two
            exactly as returned by ray_condition
    '''
    height, width = sample_size[0], sample_size[1]

    cam_params = [
        Camera([0, fx, fy, cx, cy, 0, 0, RT[frame_idx].reshape(-1)])
        for frame_idx in range(num_frames)
    ]
    print(cam_params[0].c2w_mat.shape)

    # Scale the intrinsics by the output resolution, one row per frame.
    scaled_rows = [
        [cam.fx * width, cam.fy * height, cam.cx * width, cam.cy * height]
        for cam in cam_params
    ]
    intrinsics = torch.as_tensor(np.asarray(scaled_rows, dtype=np.float32))[None]
    print(intrinsics.shape)

    # Fixed configuration: poses relative to the first frame, which sits at
    # the origin, and no horizontally flipped views.
    relative_pose = True
    zero_t_first_frame = True
    use_flip = False

    if not relative_pose:
        c2w_poses = np.array([cam.c2w_mat for cam in cam_params], dtype=np.float32)
    else:
        c2w_poses = get_relative_pose(cam_params, zero_t_first_frame)
    c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]

    flip_flag = torch.zeros(num_frames, dtype=torch.bool, device=c2w.device)
    plucker_embedding, rays_o, rays_d = ray_condition(
        intrinsics, c2w, height, width, device='cpu', flip_flag=flip_flag)

    # [1, V, H, W, 6] -> [V, H, W, 6] -> [6, V, H, W]
    plucker_embedding = plucker_embedding[0].permute(3, 0, 1, 2).contiguous()
    return plucker_embedding, rays_o, rays_d
def _clamp_to_frame(point, H, W):
    # Round an (x, y) point to integer pixels and cap it at the last column/row.
    x = min(int(round(point[0])), W - 1)
    y = min(int(round(point[1])), H - 1)
    return x, y


def visualized_trajectories(images, trajectories, save_path, save_each_frame=False):
    '''
    Draw growing trajectories on a frame sequence and save it as a GIF.

    On frame i every trajectory is drawn up to its point i: line segments
    between consecutive points, an arrow head on the newest segment, and a
    dot on the trajectory's own starting point.

    images: [n_frame, H, W, 3], numpy, 0-255
    trajectories: [[n_frame, 2]], list[numpy], x, y
    save_path: str, end with .gif
    save_each_frame: if True, additionally write every frame as
        <save_path without extension>/<index:05d>.png
    '''
    pil_image = []
    H, W = images.shape[1], images.shape[2]
    n_frame = images.shape[0]

    for i in range(n_frame):
        image = cv2.UMat(images[i].astype(np.uint8))

        for traj in trajectories:
            line_data = traj[:i+1]
            if len(line_data) == 0:
                continue  # nothing to draw for an empty trajectory

            points = [_clamp_to_frame(p, H, W) for p in line_data]
            last = len(points) - 1

            for j in range(1, len(points)):
                (x0, y0), (x1, y1) = points[j-1], points[j]
                if x0 == x1 and y0 == y1:
                    continue  # zero-length segment (would also give an infinite tip length)
                if j == last:
                    # Keep the arrow head about 10 px long whatever the segment length.
                    line_length = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
                    arrow_head_length = 10
                    tip_length = arrow_head_length / line_length
                    cv2.arrowedLine(image, (x0, y0), (x1, y1), (255, 0, 0), 6, tipLength=tip_length)
                else:
                    cv2.line(image, (x0, y0), (x1, y1), (255, 0, 0), 6)

            # Mark the start of THIS trajectory, drawn last so it stays on top.
            cv2.circle(image, points[0], 1, (0, 255, 0), 8)

        image = cv2.UMat.get(image)
        pil_image.append(Image.fromarray(image))

    pil_image[0].save(save_path, save_all=True, append_images=pil_image[1:], loop=0, duration=100)

    # save each frame
    if save_each_frame:
        # Strip only the file extension; '.gif' elsewhere in the path is kept.
        img_save_root = os.path.splitext(save_path)[0]
        os.makedirs(img_save_root, exist_ok=True)
        for i, img in enumerate(pil_image):
            img.save(os.path.join(img_save_root, f'{i:05d}.png'))
def roll_with_ignore_multidim(arr, shifts):
    '''
    Shift an array along its leading axes, one shift per axis.

    arr: numpy array
    shifts: iterable of ints; shifts[k] is applied to axis k through
        roll_with_ignore

    return: the shifted array; the input array is left untouched
    '''
    # Work on a copy so the caller's array is never aliased by the result.
    shifted = np.copy(arr)
    for dim, offset in enumerate(shifts):
        shifted = roll_with_ignore(shifted, offset, dim)
    return shifted
def roll_with_ignore(arr, shift, axis):
    '''
    Shift an array along one axis without wrap-around.

    Unlike np.roll, elements pushed past the end are dropped and the vacated
    positions are filled with zeros.

    arr: numpy array
    shift: int, positive moves data towards higher indices, negative towards
        lower indices; 0 leaves the data in place
    axis: int axis to shift along; negative values count from the last axis.
        An axis outside the array's dimensions shifts nothing.

    return: a new array of the same shape and dtype (never ``arr`` itself)
    '''
    ndim = arr.ndim
    # Normalise negative axes; they used to be silently ignored.
    if -ndim <= axis < 0:
        axis += ndim

    # No shift requested, or no such axis: return an unshifted copy so the
    # result never aliases the input.
    if shift == 0 or not 0 <= axis < ndim:
        return arr.copy()

    src = [slice(None)] * ndim
    dst = [slice(None)] * ndim
    if shift > 0:
        src[axis] = slice(None, -shift)
        dst[axis] = slice(shift, None)
    else:
        src[axis] = slice(-shift, None)
        dst[axis] = slice(None, shift)

    result = np.zeros_like(arr)
    result[tuple(dst)] = arr[tuple(src)]
    return result
def dilate_mask_pytorch(mask, kernel_size=2):
    '''
    Dilate a binary mask with a square all-ones kernel.

    mask: torch.Tensor, shape [b, c, h, w]; every channel is dilated
        independently
    kernel_size: int, side length of the square kernel

    return: torch.Tensor with the dtype of ``mask`` and values 0 / 1.
        The padding is kernel_size // 2 on every side, so an odd kernel keeps
        [h, w] while an even kernel (including the default 2) yields
        [h + 1, w + 1].
    '''
    channels = mask.shape[1]
    # One single-channel kernel per input channel (depthwise convolution), so
    # masks with c > 1 work and channels never leak into each other.
    kernel = torch.ones((channels, 1, kernel_size, kernel_size), dtype=mask.dtype, device=mask.device)

    # A pixel is set if any pixel under the kernel is set.
    dilated_mask = F.conv2d(mask, kernel, padding=kernel_size//2, groups=channels)

    # Ensure the output is still a binary mask (0 and 1)
    dilated_mask = (dilated_mask > 0).to(mask.dtype)

    return dilated_mask
def smooth_mask(mask, kernel_size=3):
    '''
    Smooth a binary mask with a box (mean) filter followed by a 0.5 threshold.

    A pixel is kept when more than half of the kernel window around it
    (zero-padded at the borders) is set, which removes isolated pixels and
    rounds off corners.

    mask: torch.Tensor, shape [b, c, h, w]; every channel is smoothed
        independently
    kernel_size: int, side length of the square averaging kernel

    return: torch.Tensor with the dtype of ``mask`` and values 0 / 1.
        The padding is kernel_size // 2 on every side, so an odd kernel keeps
        [h, w] while an even kernel yields [h + 1, w + 1].
    '''
    channels = mask.shape[1]
    # Uniform averaging kernel, one per channel (depthwise convolution), so
    # masks with c > 1 work and channels are not mixed.
    kernel = torch.ones((channels, 1, kernel_size, kernel_size), dtype=mask.dtype, device=mask.device) / (kernel_size * kernel_size)

    # Fraction of set pixels inside each window.
    smoothed_mask = F.conv2d(mask, kernel, padding=kernel_size//2, groups=channels)

    # Ensure the output is still a binary mask (0 and 1)
    smoothed_mask = (smoothed_mask > 0.5).to(mask.dtype)

    return smoothed_mask