# Inference-only wrapper around a pretrained RAFT optical-flow network that is
# loaded from a local checkout under ~/.cache/torch/RAFT.
import argparse
import os
import sys

# Root of the local torch cache. expanduser('~') resolves to $HOME when it is
# set and falls back to the platform's home lookup instead of raising KeyError
# when it is not.
_TORCH_CACHE = os.path.join(os.path.expanduser('~'), '.cache/torch')
_RAFT_CORE = os.path.join(_TORCH_CACHE, 'RAFT/core')

# RAFT is imported from a source directory rather than an installed package,
# so its `core` directory is put on sys.path only for the duration of the
# imports below.
# NOTE: `utils` here is the top-level module found in the RAFT directory; once
# imported it stays cached in sys.modules under the name `utils`.
sys.path.insert(0, _RAFT_CORE)
try:
    from raft import RAFT
    from utils import flow_viz
finally:
    # Remove exactly the entry added above (not "whatever is first"), and do
    # it even when the import fails so sys.path is never left modified.
    try:
        sys.path.remove(_RAFT_CORE)
    except ValueError:
        # The entry was already removed by the imported code; nothing to undo.
        pass

import torch
from torch import nn

from cwm.utils import imagenet_unnormalize


class Args:
    """Static configuration object handed to the ``RAFT`` constructor.

    Settings are declared as class attributes; attributes assigned on an
    instance later override or extend them.
    """

    # Checkpoint loaded by RAFTInterface.
    model = os.path.join(_TORCH_CACHE, 'RAFT/models/raft-sintel.pth')
    small = False
    path = None
    mixed_precision = False
    alternate_corr = False

    def _settings(self):
        """Return all public settings as a dict (class defaults + instance values)."""
        merged = {}
        # Walk the MRO from base to derived so subclasses override bases.
        for klass in reversed(type(self).__mro__):
            merged.update(
                (name, value)
                for name, value in vars(klass).items()
                if not name.startswith('_')
            )
        # Instance attributes take precedence over class-level defaults.
        merged.update(vars(self))
        return merged

    def __iter__(self):
        """Yield ``(name, value)`` pairs for every setting.

        The settings live on the class, so iterating only the instance
        ``__dict__`` would yield nothing for a fresh instance.
        """
        yield from self._settings().items()

    def __contains__(self, name):
        """Return True if a setting called ``name`` exists.

        Without this, ``'name' in args`` would fall back to ``__iter__`` and
        compare the string against ``(name, value)`` tuples, which never match.
        NOTE(review): upstream RAFT appears to probe optional settings with
        ``'name' in args`` (argparse.Namespace style) -- confirm.
        """
        return name in self._settings()


class RAFTInterface(nn.Module):
    """Frozen, inference-only RAFT model exposed as an ``nn.Module``."""

    def __init__(self):
        super().__init__()
        args = Args()

        # The network is wrapped in DataParallel for loading and unwrapped via
        # `.module` afterwards -- presumably because the checkpoint was saved
        # from a DataParallel model; verify against the checkpoint keys.
        model = nn.DataParallel(RAFT(args))
        # SECURITY NOTE: torch.load is pickle-based; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(args.model, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

        self.model = model.module
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode of this module, keeping RAFT in eval mode.

        The wrapped network is frozen, so a ``.train()`` call on this module
        or on a parent must not switch it (and any mode-dependent layers it
        contains) back to training behaviour.
        """
        super().train(mode)
        self.model.eval()
        return self

    @staticmethod
    def prepare_inputs(x):
        """Rescale an image tensor to the 0-255 range passed to RAFT.

        The input range is guessed from the tensor's extrema:

        * all values in [0, 1]      -> multiplied by 255
        * any negative value        -> treated as imagenet-normalized:
          ``imagenet_unnormalize`` is applied, then multiplied by 255
        * otherwise                 -> returned unchanged

        Args:
            x: image tensor.

        Returns:
            The rescaled tensor (or ``x`` itself when no rescaling applies).
        """
        lowest, highest = x.min(), x.max()
        if highest <= 1.0 and lowest >= 0.:
            return x * 255.
        if lowest < 0:
            return imagenet_unnormalize(x) * 255.
        return x

    def forward(self, x0, x1, return_magnitude=False):
        """Estimate optical flow from ``x0`` to ``x1``.

        Args:
            x0: image 0, [B, C, H, W] (e.g. imagenet-normalized).
            x1: image 1, [B, C, H, W] (e.g. imagenet-normalized).
            return_magnitude: also return the per-pixel L2 norm of the flow.

        Returns:
            ``flow_up``, or ``(flow_up, flow_magnitude)`` when
            ``return_magnitude`` is true; ``flow_magnitude`` is [B, H, W].
        """
        # Bring both frames into the value range used for RAFT.
        x0 = self.prepare_inputs(x0)
        x1 = self.prepare_inputs(x1)

        with torch.no_grad():
            _, flow_up = self.model(x0, x1, iters=20, test_mode=True)

        if not return_magnitude:
            return flow_up

        flow_magnitude = flow_up.norm(p=2, dim=1)  # [B, H, W]
        return flow_up, flow_magnitude

    def viz(self, flow):
        """Render the first flow field of a batch with ``flow_viz.flow_to_image``.

        Args:
            flow: batched flow tensor; only ``flow[0]`` is visualized.

        Returns:
            Whatever ``flow_viz.flow_to_image`` returns (presumably an RGB
            image array -- confirm against RAFT's ``flow_viz``).
        """
        # Channels-last layout for the visualizer; detach so tensors that
        # still require grad can be converted to numpy.
        first_flow = flow[0].permute(1, 2, 0).detach().cpu().numpy()
        flow_rgb = flow_viz.flow_to_image(first_flow)
        return flow_rgb