|
|
|
|
|
|
|
from dust3r.training import get_args_parser, train, load_model |
|
from dust3r.pose_eval import eval_pose_estimation, pose_estimation_custom |
|
from dust3r.depth_eval import eval_mono_depth_estimation |
|
import croco.utils.misc as misc |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
import numpy as np |
|
import os |
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, then either run one of the
    # evaluation modes (any --mode starting with 'eval') or start training.
    args = get_args_parser().parse_args()

    if args.mode.startswith('eval'):
        # Validate the mode before any setup, so that a mistyped eval mode
        # fails immediately instead of initialising distributed mode, loading
        # the model, doing nothing and still exiting with status 0.
        eval_modes = ('eval_pose', 'eval_pose_custom', 'eval_depth')
        if args.mode not in eval_modes:
            raise SystemExit(
                f"Unknown evaluation mode {args.mode!r}; "
                f"expected one of: {', '.join(eval_modes)}"
            )

        misc.init_distributed_mode(args)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Seed torch and numpy with the base seed plus the value reported by
        # misc.get_rank() for this process.
        seed = args.seed + misc.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        cudnn.benchmark = args.cudnn_benchmark

        # load_model returns a pair; only the first element is used here.
        model, _ = load_model(args, device)
        os.makedirs(args.output_dir, exist_ok=True)

        if args.mode == 'eval_pose':
            # The evaluator returns five values; the last two are not needed
            # by this script, but the unpacking still checks the arity.
            ate_mean, rpe_trans_mean, rpe_rot_mean, _, _ = eval_pose_estimation(
                args, model, device, save_dir=args.output_dir
            )
            print(f'ATE mean: {ate_mean}, RPE trans mean: {rpe_trans_mean}, RPE rot mean: {rpe_rot_mean}')
        elif args.mode == 'eval_pose_custom':
            pose_estimation_custom(args, model, device, save_dir=args.output_dir)
        elif args.mode == 'eval_depth':
            eval_mono_depth_estimation(args, model, device)

        # Evaluation never falls through to training. SystemExit is raised
        # directly because the `exit` helper is injected by the `site` module
        # and is not guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(0)

    train(args)
|
|