rahulvenkk
app.py updated
6dfcb0f
raw
history blame
1.97 kB
import sys
import os
sys.path.insert(0, os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/core'))
from raft import RAFT
from utils import flow_viz
sys.path = sys.path[1:] # remove the first path to RAFT
import torch
from cwm.utils import imagenet_unnormalize
from torch import nn
import argparse
class Args:
    """Settings bundle passed to the ``RAFT`` constructor.

    Settings are declared as class attributes, so every instance starts from
    the same defaults; assigning an attribute on an instance overrides (or
    adds) a setting for that instance only.

    Iterating an instance yields ``(name, value)`` pairs for every setting,
    and ``name in args`` reports whether a setting with that name exists.
    """

    # Checkpoint file that RAFTInterface loads with ``torch.load``.
    model = os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/models/raft-sintel.pth')
    # The remaining options are read by RAFT itself.
    # NOTE(review): their exact meaning is defined upstream — confirm there.
    small = False
    path = None
    mixed_precision = False
    alternate_corr = False

    def _settings(self):
        """Return a dict of class-level defaults overlaid with instance values."""
        # ``vars(self)`` alone misses the defaults declared on the class, so
        # start from the public, non-callable class attributes.
        merged = {
            name: value
            for name, value in vars(type(self)).items()
            if not name.startswith('_') and not callable(value)
        }
        merged.update(vars(self))
        return merged

    def __iter__(self):
        """Yield ``(name, value)`` for every setting, defaults included."""
        yield from self._settings().items()

    def __contains__(self, name):
        """Return True when ``name`` is a known setting.

        Without this method ``in`` falls back to ``__iter__``, whose
        ``(name, value)`` tuples never compare equal to a bare name.
        NOTE(review): upstream RAFT appears to probe optional settings with
        ``'name' in args`` — confirm against the RAFT sources.
        """
        return name in self._settings()
class RAFTInterface(nn.Module):
    """Inference-only wrapper around a pretrained RAFT optical-flow network.

    The network is built from ``Args``, loaded from the checkpoint named by
    ``Args.model``, switched to eval mode, and has gradients disabled on all
    of its parameters.
    """

    def __init__(self):
        super().__init__()
        config = Args()
        # The state dict is loaded through a DataParallel wrapper and the
        # wrapped network is then pulled back out via ``.module``.
        # NOTE(review): presumably the checkpoint keys carry the "module."
        # prefix of a DataParallel-saved model — confirm.
        wrapped = torch.nn.DataParallel(RAFT(config))
        weights = torch.load(config.model, map_location=torch.device('cpu'))
        wrapped.load_state_dict(weights)
        self.model = wrapped.module
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad_(False)

    @staticmethod
    def prepare_inputs(x):
        """Rescale ``x`` to the 0-255 range that is fed to RAFT.

        The input range is guessed from the tensor's extrema:
        values inside [0, 1] are multiplied by 255; any negative value is
        taken to mean ImageNet normalization, which is undone before
        multiplying by 255; anything else is returned unchanged.
        """
        highest = x.max()
        lowest = x.min()
        if highest <= 1.0 and lowest >= 0.:  # already in [0, 1]
            return x * 255.
        if lowest < 0:  # imagenet normalized
            return imagenet_unnormalize(x) * 255.
        return x

    def forward(self, x0, x1, return_magnitude=False):
        """Estimate the flow between two images.

        Args:
            x0: image 0, [B, C, H, W]; any range handled by
                ``prepare_inputs`` (e.g. imagenet-normalized).
            x1: image 1, [B, C, H, W]; same conventions as ``x0``.
            return_magnitude: when true, also return the L2 norm of the flow
                taken over dim 1.

        Returns:
            The second output of the RAFT model, or the pair
            ``(flow, magnitude)`` with magnitude shaped [B, H, W] when
            ``return_magnitude`` is set.
        """
        frame0 = self.prepare_inputs(x0)
        frame1 = self.prepare_inputs(x1)

        with torch.no_grad():
            # Only the second model output is used; the first is discarded.
            _unused, flow_up = self.model(frame0, frame1, iters=20, test_mode=True)
            if not return_magnitude:
                return flow_up
            return flow_up, flow_up.norm(p=2, dim=1)  # magnitude: [B, H, W]

    def viz(self, flow):
        """Render the first flow field of a batch with ``flow_viz.flow_to_image``.

        Only ``flow[0]`` is used; it is moved to channels-last layout and
        converted to a CPU NumPy array before rendering.
        """
        first_flow = flow[0].permute(1, 2, 0)
        return flow_viz.flow_to_image(first_flow.cpu().numpy())