Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import sys | |
import os | |
sys.path.insert(0, os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/core')) | |
from raft import RAFT | |
from utils import flow_viz | |
sys.path = sys.path[1:] # remove the first path to RAFT | |
import torch | |
from cwm.utils import imagenet_unnormalize | |
from torch import nn | |
import argparse | |
class Args:
    """Minimal stand-in for the argparse namespace passed to ``RAFT(args)``.

    Options are declared as class attributes.  Attributes later assigned on an
    instance are layered on top of them, both for iteration and membership.
    """

    # Path of the pretrained RAFT Sintel checkpoint (loaded by RAFTInterface).
    model = os.path.join(
        os.environ['HOME'], '.cache/torch', 'RAFT/models/raft-sintel.pth'
    )
    small = False
    path = None
    mixed_precision = False
    alternate_corr = False

    def _options(self):
        """Return all public options: class defaults overlaid with instance attributes."""
        options = {}
        # Walk base classes first so subclasses can override declared options.
        for klass in reversed(type(self).__mro__):
            for name, value in vars(klass).items():
                if not name.startswith('_') and not callable(value):
                    options[name] = value
        options.update(vars(self))
        return options

    def __iter__(self):
        """Yield ``(name, value)`` pairs, so ``dict(Args())`` lists every option."""
        # The options live on the class, so iterating ``self.__dict__`` alone
        # (the previous behaviour) yielded nothing for a fresh instance.
        yield from self._options().items()

    def __contains__(self, name):
        """Support ``'option' in args`` by option name, like ``argparse.Namespace``.

        NOTE(review): presumably RAFT probes optional flags with ``in``; without
        this, ``in`` fell back to ``__iter__`` and compared against tuples.
        """
        return name in self._options()
class RAFTInterface(nn.Module):
    """Frozen, inference-only wrapper around a pretrained RAFT optical-flow model.

    The checkpoint at ``Args.model`` is loaded onto the CPU.  The wrapped model
    is put in eval mode and all of its parameters are frozen
    (``requires_grad = False``).
    """

    def __init__(self):
        super().__init__()
        args = Args()
        # Load through a DataParallel wrapper and then unwrap it: presumably the
        # checkpoint's keys carry DataParallel's ``module.`` prefix — confirm.
        model = torch.nn.DataParallel(RAFT(args))
        model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu')))
        self.model = model.module
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True):
        """Set this wrapper's training flag while keeping RAFT itself in eval mode.

        ``nn.Module.train`` recurses into submodules.  Without this override, a
        parent module calling ``.train()`` would flip the frozen RAFT model back
        into training mode, where any batch-norm / dropout layers it contains
        would change behaviour.
        """
        super().train(mode)
        self.model.eval()
        return self

    @staticmethod
    def prepare_inputs(x: torch.Tensor) -> torch.Tensor:
        """Bring an image batch into the 0-255 range that is fed to RAFT.

        The input range is guessed from its values:
        - all values within [0, 1]: scaled by 255;
        - any negative value: treated as ImageNet-normalised, un-normalised
          with ``imagenet_unnormalize``, then scaled by 255;
        - otherwise: assumed to already be in [0, 255] and returned unchanged.

        Declared static because it reads no instance state.  The original lacked
        ``self``, so ``self.prepare_inputs(x)`` raised ``TypeError``.
        """
        if x.max() <= 1.0 and x.min() >= 0.:  # range [0, 1]
            return x * 255.
        if x.min() < 0:  # ImageNet-normalised
            return imagenet_unnormalize(x) * 255.
        return x

    def forward(self, x0, x1, return_magnitude=False, *, iters: int = 20):
        """Estimate optical flow between frame ``x0`` and frame ``x1`` with RAFT.

        Args:
            x0: first image batch, [B, C, H, W].  It may be in [0, 1],
                ImageNet-normalised, or already in [0, 255] (see ``prepare_inputs``).
            x1: second image batch, same layout and conventions as ``x0``.
            return_magnitude: if True, also return the per-pixel L2 norm of the
                flow, taken over dim 1.
            iters: number of RAFT refinement iterations (default 20).

        Returns:
            ``flow_up`` (presumably [B, 2, H, W] — confirm against RAFT).  When
            ``return_magnitude`` is True, returns ``(flow_up, magnitude)`` with
            ``magnitude`` of shape [B, H, W].

        NOTE(review): no padding is applied here.  RAFT normally expects H and W
        to be multiples of 8 — confirm that callers guarantee this.
        """
        x0 = self.prepare_inputs(x0)
        x1 = self.prepare_inputs(x1)
        with torch.no_grad():
            # test_mode=True: only the second (upsampled) output is used.
            _, flow_up = self.model(x0, x1, iters=iters, test_mode=True)
        if return_magnitude:
            flow_magnitude = flow_up.norm(p=2, dim=1)  # [B, H, W]
            return flow_up, flow_magnitude
        return flow_up

    def viz(self, flow):
        """Render the first flow field of a batch as an image via ``flow_viz.flow_to_image``."""
        # Take batch item 0 and move channels last ([C, H, W] -> [H, W, C]).
        # Detach so ``.numpy()`` also works on tensors that track gradients.
        flow_hwc = flow[0].detach().permute(1, 2, 0).cpu().numpy()
        return flow_viz.flow_to_image(flow_hwc)