MatchAnything / imcui /hloc /matchers /matchanything.py
XingyiHe's picture
init commit
3040ac4
import sys
from pathlib import Path
import numpy as np
import PIL
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
import os
from .. import DEVICE, MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel
sys.path.append(str(Path(__file__).parent / "../../third_party"))
sys.path.append(str(Path(__file__).parent / "../../third_party/MatchAnything"))
from MatchAnything.src.lightning.lightning_loftr import PL_LoFTR
from MatchAnything.src.config.default import get_cfg_defaults
class MatchAnything(BaseModel):
    """Matcher wrapper that runs a MatchAnything model (ELoFTR or RoMa variant).

    The variant is selected by ``conf['model_name']``. Its config file and its
    ``<model_name>.ckpt`` weights are read from ``third_party/MatchAnything``
    relative to this file.
    """

    required_inputs = [
        "image0",
        "image1",
    ]

    def _init(self, conf):
        """Build the network described by ``conf`` and load its checkpoint.

        Reads ``conf['model_name']`` (``'matchanything_eloftr'`` or
        ``'matchanything_roma'``), ``conf['match_threshold']`` and, for the
        ELoFTR variant when the dataset NPE name is ``'megadepth'``,
        ``conf['img_resize']``.

        Raises:
            NotImplementedError: if ``conf['model_name']`` is not one of the
                two supported variants.
        """
        self.conf = conf
        model_name = conf['model_name']
        repo_dir = Path(__file__).parent / "../../third_party" / 'MatchAnything'
        config = get_cfg_defaults()

        if model_name == 'matchanything_eloftr':
            config.merge_from_file(str(repo_dir / 'configs/models/eloftr_model.py'))
            # Config overwrite:
            if config.LOFTR.COARSE.ROPE:
                assert config.DATASET.NPE_NAME is not None
            # A None NPE name never equals 'megadepth', so no separate None check is needed.
            if config.DATASET.NPE_NAME == 'megadepth':
                config.LOFTR.COARSE.NPE = [832, 832, conf['img_resize'], conf['img_resize']]
        elif model_name == 'matchanything_roma':
            config.merge_from_file(str(repo_dir / 'configs/models/roma_model.py'))
            running_on_cpu = str(DEVICE) == 'cpu'
            # Report through the module logger instead of printing to stdout.
            logger.info(f"MatchAnything RoMa device: {DEVICE} (cpu={running_on_cpu})")
            if running_on_cpu:
                # Half precision / autocast are switched off when running on CPU.
                config.LOFTR.FP16 = False
                config.ROMA.MODEL.AMP = False
        else:
            raise NotImplementedError(f"Unsupported MatchAnything model: {model_name!r}")

        config.METHOD = model_name
        config.LOFTR.MATCH_COARSE.THR = conf["match_threshold"]
        model_path = repo_dir / 'weights' / "{}.ckpt".format(model_name)
        self.net = PL_LoFTR(config, pretrained_ckpt=model_path, test_mode=True).matcher
        self.net.eval().to(DEVICE)
        logger.info(f"Loading {conf['model_name']} model done")

    @staticmethod
    def _to_uint8_hwc(image):
        """Convert an input image tensor to a channel-last uint8 array.

        NOTE(review): assumes ``image`` squeezes to a channel-first (C, H, W)
        array with values nominally in [0, 1] - confirm against the caller.
        """
        arr = image.cpu().numpy().squeeze() * 255
        arr = arr.transpose(1, 2, 0)
        # Round instead of truncating, and clip so values marginally outside
        # [0, 255] cannot wrap around in the uint8 cast.
        return np.clip(np.rint(arr), 0, 255).astype("uint8")

    @staticmethod
    def _coarse_mask(mask):
        """Downsample one boolean padding mask by a factor of 8.

        Takes a 2-D NumPy mask and returns a bool tensor on ``DEVICE`` with a
        leading axis of size 1.
        """
        mask = torch.from_numpy(mask).to(DEVICE)
        coarse = F.interpolate(mask[None][None].float(),
                               scale_factor=0.125,
                               mode='nearest',
                               recompute_scale_factor=False)
        return coarse[0].bool()

    def _forward(self, data):
        """Match ``data['image0']`` against ``data['image1']``.

        Both images are converted to grayscale, resized so each side is a
        multiple of 32, zero-padded to a square and passed to the network
        together with the original tensors and their original sizes.

        Returns:
            dict with ``keypoints0`` / ``keypoints1`` (the network's
            ``mkpts0_f`` / ``mkpts1_f``, multiplied by the resize scale
            factors for the ELoFTR variant) and ``mconf``, all moved to CPU.
        """
        # Get original images:
        img0 = self._to_uint8_hwc(data["image0"])
        img1 = self._to_uint8_hwc(data["image1"])
        img0_size, img1_size = np.array(img0.shape[:2]), np.array(img1.shape[:2])
        img0_gray = np.array(Image.fromarray(img0).convert("L"))
        img1_gray = np.array(Image.fromarray(img1).convert("L"))

        img0_gray, hw0_new, mask0 = resize(img0_gray, df=32)
        img1_gray, hw1_new, mask1 = resize(img1_gray, df=32)

        batch = {
            'image0': torch.from_numpy(img0_gray)[None][None] / 255.,
            'image1': torch.from_numpy(img1_gray)[None][None] / 255.,
            'image0_rgb_origin': data['image0'],
            'image1_rgb_origin': data['image1'],
            'origin_img_size0': torch.from_numpy(img0_size)[None],
            'origin_img_size1': torch.from_numpy(img1_size)[None],
        }

        if mask0 is not None and mask1 is not None:
            # Each mask is downsampled on its own: the two images may be padded
            # to different square sizes, in which case they cannot be stacked.
            batch["mask0"] = self._coarse_mask(mask0)
            batch["mask1"] = self._coarse_mask(mask1)

        batch = dict_to_cuda(batch, device=DEVICE)
        # The network writes its predictions into ``batch``.
        self.net(batch)

        mkpts0 = batch['mkpts0_f'].cpu()
        mkpts1 = batch['mkpts1_f'].cpu()
        mconf = batch['mconf'].cpu()

        if self.conf['model_name'] == 'matchanything_eloftr':
            # resize() returns [h_scale, w_scale]; reorder to [w_scale, h_scale]
            # (presumably keypoints are stored as (x, y) - TODO confirm).
            mkpts0 *= torch.tensor(hw0_new)[[1, 0]]
            mkpts1 *= torch.tensor(hw1_new)[[1, 0]]

        pred = {
            "keypoints0": mkpts0,
            "keypoints1": mkpts1,
            "mconf": mconf,
        }
        return pred
def resize(img, resize=None, df=8, padding=True):
    """Resize ``img`` with Lanczos resampling and optionally pad it to a square.

    The target size comes from ``process_resize`` (honouring ``resize`` and
    the divisibility factor ``df``). The resized image is returned as
    float32.

    Returns:
        ``(image, [h_scale, w_scale], mask)`` where the scales are
        original / resized side lengths, and ``mask`` is the padding mask
        from ``pad_bottom_right`` (``None`` when ``padding`` is false).
    """
    src_h, src_w = img.shape[0], img.shape[1]
    new_w, new_h = process_resize(src_w, src_h, resize=resize, df=df, resize_no_larger_than=False)

    resized = resize_image(img, (new_w, new_h), interp="pil_LANCZOS").astype('float32')
    scales = [src_h / resized.shape[0], src_w / resized.shape[1]]

    if not padding:
        return resized, scales, None

    # Pad to a square whose side is the longer of the two target sides.
    padded, valid_mask = pad_bottom_right(resized, max(new_h, new_w), ret_mask=True)
    return padded, scales, valid_mask
def process_resize(w, h, resize=None, df=None, resize_no_larger_than=False):
    """Compute the target ``(width, height)`` for an image of size ``w`` x ``h``.

    Args:
        w, h: original width and height.
        resize: ``None`` (keep size), a one-element sequence (``[-1]`` keeps
            the size, otherwise the longer side is scaled to that value), or
            a two-element ``[width, height]`` sequence used as-is.
        df: if given, both sides are rounded down to a multiple of ``df``.
        resize_no_larger_than: when true, ``resize`` is ignored for images
            whose longer side does not exceed ``max(resize)``.

    Returns:
        ``(w_new, h_new)``.
    """
    target_w, target_h = w, h

    if resize is not None:
        assert(len(resize) > 0 and len(resize) <= 2)
        keep_original = resize_no_larger_than and (max(h, w) <= max(resize))
        if not keep_original:
            single = len(resize) == 1
            if single and resize[0] > -1:
                # Scale so that the longer side becomes resize[0].
                factor = resize[0] / max(h, w)
                target_w = int(round(w * factor))
                target_h = int(round(h * factor))
            elif single and resize[0] == -1:
                pass  # -1 means "leave the size unchanged"
            else:
                target_w, target_h = resize[0], resize[1]

    if df is not None:
        # Round each side down to the nearest multiple of df.
        target_w = int(target_w // df * df)
        target_h = int(target_h // df * df)

    return target_w, target_h
def resize_image(image, size, interp):
    """Resize ``image`` to ``size`` (width, height) with the named backend.

    ``interp`` is ``'cv2_<name>'`` (resolved to ``cv2.INTER_<NAME>``) or
    ``'pil_<name>'`` (resolved to ``PIL.Image.<NAME>``). On the PIL path the
    image is cast to uint8 first and the result is cast back to the input
    dtype.

    Raises:
        ValueError: if ``interp`` has neither prefix.
    """
    if interp.startswith('cv2_'):
        flag = getattr(cv2, 'INTER_' + interp[len('cv2_'):].upper())
        src_h, src_w = image.shape[:2]
        # Area interpolation is replaced by linear when either side grows.
        if flag == cv2.INTER_AREA and (src_w < size[0] or src_h < size[1]):
            flag = cv2.INTER_LINEAR
        return cv2.resize(image, size, interpolation=flag)

    if interp.startswith('pil_'):
        resample = getattr(PIL.Image, interp[len('pil_'):].upper())
        pil_img = PIL.Image.fromarray(image.astype(np.uint8))
        pil_img = pil_img.resize(size, resample=resample)
        return np.asarray(pil_img, dtype=image.dtype)

    raise ValueError(
        f'Unknown interpolation {interp}.')
def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad the last two axes of ``inp`` to a ``pad_size`` x ``pad_size`` square.

    The input is placed in the top-left corner; padding is added at the
    bottom and on the right.

    Args:
        inp: 2-D ``(H, W)`` or 3-D ``(C, H, W)`` NumPy array.
        pad_size: side of the output square; a Python or NumPy integer that
            is at least ``max(H, W)``.
        ret_mask: when true, also return a 2-D boolean mask that is ``True``
            over the region occupied by the input.

    Returns:
        ``(padded, mask)``; ``mask`` is ``None`` unless ``ret_mask`` is true.

    Raises:
        AssertionError: if ``pad_size`` is not an integer or is too small.
        NotImplementedError: if ``inp`` is neither 2-D nor 3-D.
    """
    # NumPy integers (e.g. sizes derived from array shapes) are accepted too.
    assert isinstance(pad_size, (int, np.integer)) and pad_size >= max(inp.shape[-2:]), \
        f"{pad_size} < {max(inp.shape[-2:])}"
    if inp.ndim not in (2, 3):
        raise NotImplementedError()

    pad_size = int(pad_size)
    rows, cols = inp.shape[-2:]

    # Leading (channel) axis, if any, is preserved; only the last two axes grow.
    padded = np.zeros(inp.shape[:-2] + (pad_size, pad_size), dtype=inp.dtype)
    padded[..., :rows, :cols] = inp

    mask = None
    if ret_mask:
        # The mask is identical for every channel, so build it in 2-D directly.
        mask = np.zeros((pad_size, pad_size), dtype=bool)
        mask[:rows, :cols] = True
    return padded, mask
def dict_to_cuda(data_dict, device='cuda'):
    """Return a copy of ``data_dict`` with its tensors moved to ``device``.

    Tensor values are moved with ``.to(device)``; nested dicts are handled
    recursively and lists are delegated to ``list_to_cuda``. Any other value
    is copied over unchanged. The input dict itself is not modified.
    """
    moved = {}
    for key, value in data_dict.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        elif isinstance(value, dict):
            value = dict_to_cuda(value, device=device)
        elif isinstance(value, list):
            value = list_to_cuda(value, device=device)
        moved[key] = value
    return moved
def list_to_cuda(data_list, device='cuda'):
    """Return a copy of ``data_list`` with its tensors moved to ``device``.

    Tensor items are moved with ``.to(device)``; nested lists are handled
    recursively and dicts are delegated to ``dict_to_cuda``. Any other item
    is copied over unchanged. The input list itself is not modified.
    """
    moved = []
    for item in data_list:
        if isinstance(item, torch.Tensor):
            # Honour the requested device (matching dict_to_cuda) rather than
            # always calling .cuda(), which ignores ``device`` and fails on
            # hosts without CUDA.
            moved.append(item.to(device))
        elif isinstance(item, dict):
            moved.append(dict_to_cuda(item, device=device))
        elif isinstance(item, list):
            moved.append(list_to_cuda(item, device=device))
        else:
            moved.append(item)
    return moved