MatchAnything / imcui /hloc /localize_inloc.py
XingyiHe's picture
init commit
3040ac4
import argparse
import pickle
from pathlib import Path
import cv2
import h5py
import numpy as np
import pycolmap
import torch
from scipy.io import loadmat
from tqdm import tqdm
from . import logger
from .utils.parsers import names_to_pair, parse_retrieval
def interpolate_scan(scan, kp):
    """Sample a per-pixel scan at sub-pixel keypoint locations.

    Each keypoint is sampled bilinearly first; wherever that yields NaN the
    nearest-neighbour sample is used instead, so that as many keypoints as
    possible receive a value.

    Args:
        scan: Array of shape ``(h, w, c)``.
        kp: Array of shape ``(N, 2)`` with ``(x, y)`` pixel coordinates. Every
            coordinate must lie strictly inside the scan, i.e. in
            ``(0, w - 1)`` for x and ``(0, h - 1)`` for y.

    Returns:
        A tuple ``(kp3d, valid)``: ``kp3d`` is an ``(N, c)`` array of sampled
        values and ``valid`` is an ``(N,)`` boolean array that is False for
        keypoints whose sample still contains a NaN.

    Raises:
        AssertionError: If a keypoint is on or outside the scan border.
    """
    rows, cols, _ = scan.shape
    # Rescale pixel coordinates to the [-1, 1] range used by grid_sample.
    grid_np = kp / np.array([[cols - 1, rows - 1]]) * 2 - 1
    assert np.all(grid_np > -1) and np.all(grid_np < 1)

    volume = torch.from_numpy(scan).permute(2, 0, 1)[None]
    grid = torch.from_numpy(grid_np)[None, None]

    def _sample(mode):
        # Returns a (c, N) tensor of values sampled with the given mode.
        sampled = torch.nn.functional.grid_sample(
            volume, grid, align_corners=True, mode=mode
        )
        return sampled[0, :, 0]

    # To maximize the number of points that have a value: bilinear first,
    # then nearest for the points where bilinear produced NaN.
    bilinear = _sample("bilinear")
    nearest = _sample("nearest")
    merged = torch.where(torch.isnan(bilinear), nearest, bilinear)
    has_nan = torch.isnan(merged).any(0)

    return merged.T.numpy(), (~has_nan).numpy()
def get_scan_pose(dataset_dir, rpath):
    """Load the transformation matrix stored for the scan of a database image.

    The last three ``/``-separated components of ``rpath`` are read as
    ``<floor_name>/<scan_id>/<image_name>``, and the first three characters of
    the image name are used as the building name. The matrix is read from
    ``<dataset_dir>/database/alignments/<floor_name>/transformations/
    <building_name>_trans_<scan_id>.txt``.

    Args:
        dataset_dir: Root directory of the dataset.
        rpath: ``/``-separated relative path of the database image.

    Returns:
        Array built from the 8th to 11th lines of the file, one row per line
        of space-separated numbers (named ``P_after_GICP`` in the original
        code; presumably a 4x4 pose after GICP alignment - confirm against
        the dataset format).
    """
    parts = rpath.split("/")
    image_name = parts[-1]
    scan_id = parts[-2]
    floor_name = parts[-3]
    building_name = image_name[:3]

    transform_file = (
        Path(dataset_dir)
        / "database/alignments"
        / floor_name
        / f"transformations/{building_name}_trans_{scan_id}.txt"
    )
    with open(transform_file) as handle:
        content = list(handle)

    # The matrix occupies the lines at 0-based indices 7 through 10.
    matrix_rows = [np.fromstring(content[idx], sep=" ") for idx in range(7, 11)]
    return np.array(matrix_rows)
def pose_from_cluster(dataset_dir, q, retrieved, feature_file, match_file, skip=None):
    """Estimate the pose of query image ``q`` from its retrieved database images.

    For each retrieved image, the 2D-2D matches with the query are turned into
    2D-3D correspondences: the image's ``XYZcut`` scan (from ``<r>.mat``) is
    sampled at the matched database keypoints with ``interpolate_scan`` and
    the result is transformed with the matrix from ``get_scan_pose``. The
    correspondences of all retrieved images are pooled and passed to
    ``pycolmap.absolute_pose_estimation``.

    Args:
        dataset_dir: Dataset root; ``q`` and the entries of ``retrieved`` are
            resolved relative to it. May be a ``Path`` or a string.
        q: Name of the query image; also its key in ``feature_file``.
        retrieved: Iterable of database image names.
        feature_file: Mapping such that ``feature_file[name]["keypoints"]``
            yields the keypoints of an image.
        match_file: Mapping such that ``match_file[pair]["matches0"]`` yields,
            for each query keypoint, the index of the matched database
            keypoint, or -1 when there is no match.
        skip: If truthy, database images with fewer than ``skip`` matches are
            ignored.

    Returns:
        A tuple ``(ret, all_mkpq, all_mkpr, all_mkp3d, all_indices,
        num_matches)``: the pycolmap result (with the camera config added
        under ``"cfg"``), the pooled query keypoints, database keypoints and
        3D points that have a valid scan sample, the index into ``retrieved``
        for each of them, and the total number of matches before the validity
        filter (skipped images excluded).

    Raises:
        FileNotFoundError: If the query image cannot be read.
        ValueError: If no database image contributed any matches (``retrieved``
            is empty or every image was skipped).
    """
    # Path(...) instead of ``dataset_dir / q`` so that a plain string works,
    # as it already does for the other paths built in this function.
    query_path = Path(dataset_dir, q)
    image = cv2.imread(str(query_path))
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read query image: {query_path}")
    height, width = image.shape[:2]

    # Principal point at the image centre.
    cx = 0.5 * width
    cy = 0.5 * height
    # NOTE(review): hard-coded focal length in pixels; presumably a 28 mm
    # (35 mm-equivalent) lens on 4032 px wide query images - confirm.
    focal_length = 4032.0 * 28.0 / 36.0

    all_mkpq = []
    all_mkpr = []
    all_mkp3d = []
    all_indices = []
    kpq = feature_file[q]["keypoints"].__array__()
    num_matches = 0

    for i, r in enumerate(retrieved):
        kpr = feature_file[r]["keypoints"].__array__()
        pair = names_to_pair(q, r)
        m = match_file[pair]["matches0"].__array__()
        v = m > -1

        if skip and (np.count_nonzero(v) < skip):
            continue

        mkpq, mkpr = kpq[v], kpr[m[v]]
        num_matches += len(mkpq)

        # Lift the matched database keypoints to 3D and apply the scan's
        # transformation (rotation block and translation column of Tr).
        scan_r = loadmat(Path(dataset_dir, r + ".mat"))["XYZcut"]
        mkp3d, valid = interpolate_scan(scan_r, mkpr)
        Tr = get_scan_pose(dataset_dir, r)
        mkp3d = (Tr[:3, :3] @ mkp3d.T + Tr[:3, -1:]).T

        all_mkpq.append(mkpq[valid])
        all_mkpr.append(mkpr[valid])
        all_mkp3d.append(mkp3d[valid])
        all_indices.append(np.full(np.count_nonzero(valid), i))

    if not all_mkpq:
        # np.concatenate would raise a generic ValueError on an empty list;
        # raise the same exception type with an actionable message instead.
        raise ValueError(
            f"No database image contributed matches for query {q}: "
            "the retrieval list is empty or every pair was skipped."
        )

    all_mkpq = np.concatenate(all_mkpq, 0)
    all_mkpr = np.concatenate(all_mkpr, 0)
    all_mkp3d = np.concatenate(all_mkp3d, 0)
    all_indices = np.concatenate(all_indices, 0)

    cfg = {
        "model": "SIMPLE_PINHOLE",
        "width": width,
        "height": height,
        "params": [focal_length, cx, cy],
    }
    # NOTE(review): 48.00 is passed positionally; presumably the maximum
    # reprojection error in pixels - confirm against the installed pycolmap.
    ret = pycolmap.absolute_pose_estimation(all_mkpq, all_mkp3d, cfg, 48.00)
    ret["cfg"] = cfg
    return ret, all_mkpq, all_mkpr, all_mkp3d, all_indices, num_matches
def main(dataset_dir, retrieval, features, matches, results, skip_matches=None):
    """Localize every query of a retrieval file and write poses and logs.

    For each query listed in ``retrieval``, ``pose_from_cluster`` is run on
    its retrieved database images. The estimated poses are written to
    ``results`` as one ``<image name> <qvec> <tvec>`` line per query, and the
    per-query details are pickled to ``<results>_logs.pkl``.

    Args:
        dataset_dir: Root directory of the dataset.
        retrieval: Path of the retrieval file read by ``parse_retrieval``.
        features: Path of the HDF5 file holding the local features.
        matches: Path of the HDF5 file holding the matches.
        results: Path of the text file the poses are written to.
        skip_matches: Forwarded to ``pose_from_cluster`` as ``skip``.

    Raises:
        AssertionError: If ``retrieval``, ``features`` or ``matches`` does not
            exist.
    """
    assert retrieval.exists(), retrieval
    assert features.exists(), features
    assert matches.exists(), matches

    retrieval_dict = parse_retrieval(retrieval)
    queries = list(retrieval_dict.keys())

    poses = {}
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
    # Open both HDF5 files in a ``with`` block so that the handles are closed
    # even when localization of a query raises.
    with h5py.File(features, "r", libver="latest") as feature_file, h5py.File(
        matches, "r", libver="latest"
    ) as match_file:
        for q in tqdm(queries):
            db = retrieval_dict[q]
            ret, mkpq, mkpr, mkp3d, indices, num_matches = pose_from_cluster(
                dataset_dir, q, db, feature_file, match_file, skip_matches
            )

            poses[q] = (ret["qvec"], ret["tvec"])
            logs["loc"][q] = {
                "db": db,
                "PnP_ret": ret,
                "keypoints_query": mkpq,
                "keypoints_db": mkpr,
                "3d_points": mkp3d,
                "indices_db": indices,
                "num_matches": num_matches,
            }

    logger.info(f"Writing poses to {results}...")
    with open(results, "w") as f:
        for q in queries:
            qvec, tvec = poses[q]
            qvec = " ".join(map(str, qvec))
            tvec = " ".join(map(str, tvec))
            # Only the file name of the query is written, not its full path.
            name = q.split("/")[-1]
            f.write(f"{name} {qvec} {tvec}\n")

    logs_path = f"{results}_logs.pkl"
    logger.info(f"Writing logs to {logs_path}...")
    with open(logs_path, "wb") as f:
        pickle.dump(logs, f)
    logger.info("Done!")
if __name__ == "__main__":
    # Command-line entry point: every path option is mandatory, while
    # --skip_matches is optional and defaults to None.
    cli = argparse.ArgumentParser()
    for option in ("dataset_dir", "retrieval", "features", "matches", "results"):
        cli.add_argument(f"--{option}", type=Path, required=True)
    cli.add_argument("--skip_matches", type=int)
    main(**vars(cli.parse_args()))