File size: 1,971 Bytes
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import sys
import os
sys.path.insert(0, os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/core'))
from raft import RAFT
from utils import flow_viz
sys.path = sys.path[1:] # remove the first path to RAFT
import torch
from cwm.utils import imagenet_unnormalize
from torch import nn
import argparse


class Args:
    """Static RAFT configuration, used in place of an argparse namespace.

    Options are declared as class attributes (defaults). Attributes later set
    on an instance override those defaults. Iterating yields ``(name, value)``
    pairs for every option, so ``dict(Args())`` gives the full configuration.
    ``name in args`` checks option names, as ``argparse.Namespace`` does.
    """
    model = os.path.join(os.environ['HOME'], '.cache/torch', 'RAFT/models/raft-sintel.pth')
    small = False
    path = None
    mixed_precision = False
    alternate_corr = False

    def _fields(self):
        """Return ``{name: value}``: public class-level defaults, then instance attributes."""
        fields = {}
        # Walk the MRO from the root so subclass defaults override base ones.
        for klass in reversed(type(self).__mro__):
            for name, value in vars(klass).items():
                # Skip dunders/private names and methods; keep plain option values.
                if not name.startswith('_') and not callable(value):
                    fields[name] = value
        # Instance attributes win over class defaults (same set the old __iter__ yielded).
        fields.update(vars(self))
        return fields

    def __iter__(self):
        """Yield ``(name, value)`` pairs for every option, defaults included."""
        yield from self._fields().items()

    def __contains__(self, name):
        """Return True if ``name`` is an option, either as a default or set on the instance.

        Without this method, ``in`` would fall back to ``__iter__``. It would
        then compare ``name`` against ``(name, value)`` tuples and never match.
        """
        return name in self._fields()

class RAFTInterface(nn.Module):
    """Frozen, inference-only wrapper around a pretrained RAFT optical-flow model.

    On construction it:
    * loads the checkpoint named by ``Args.model`` onto the CPU;
    * unwraps the model from ``DataParallel``;
    * switches it to eval mode and disables gradients for every parameter.
    """

    # NOTE(review): upstream RAFT expects H and W to be multiples of 8 (its demo
    # pads inputs with ``InputPadder``). Inputs of other sizes are
    # replicate-padded here, and the predicted flow is cropped back afterwards.
    _PAD_MULTIPLE = 8

    def __init__(self):
        super().__init__()
        args = Args()
        # Wrap in DataParallel so the state-dict keys line up (presumably the
        # checkpoint was saved with a ``module.`` prefix), then keep the inner module.
        model = torch.nn.DataParallel(RAFT(args))
        # NOTE(review): torch.load unpickles the file; only point Args.model at a
        # trusted checkpoint.
        model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu')))
        self.model = model.module
        self.model.eval()

        for p in self.model.parameters():
            p.requires_grad = False

    @staticmethod
    def prepare_inputs(x):
        """Rescale ``x`` to the 0-255 range this wrapper feeds to RAFT.

        The input format is guessed from the value range:
        * all values in [0, 1]: treated as unit-range and multiplied by 255;
        * any negative value: treated as ImageNet-normalized, so it is
          un-normalized and then multiplied by 255;
        * otherwise: assumed to be 0-255 already and returned unchanged.

        NOTE(review): this is ambiguous for a 0-255 image whose max is <= 1
        (e.g. nearly black). Such an image is treated as unit-range.
        """
        if x.max() <= 1.0 and x.min() >= 0.:  # range(0, 1)
            x = x * 255.
        elif x.min() < 0:  # imagenet normalized
            x = imagenet_unnormalize(x)
            x = x * 255.

        return x

    @classmethod
    def _pad_to_multiple(cls, x):
        """Replicate-pad the last two dims of ``x`` up to multiples of ``_PAD_MULTIPLE``.

        The padding is split between the two sides of each dim, with the odd
        pixel going to the right/bottom.

        Returns:
            ``(padded, pad)``, where ``pad`` is ``(left, right, top, bottom)``
            for ``_unpad``. When no padding is needed, ``x`` is returned
            untouched and ``pad`` is all zeros.
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % cls._PAD_MULTIPLE
        pad_w = (-width) % cls._PAD_MULTIPLE
        pad = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
        if pad_h or pad_w:
            # Replicate padding may not be implemented for integer dtypes.
            if not torch.is_floating_point(x):
                x = x.float()
            x = torch.nn.functional.pad(x, pad, mode='replicate')
        return x, pad

    @staticmethod
    def _unpad(flow, pad):
        """Crop ``flow`` back to the spatial size it had before ``_pad_to_multiple``."""
        left, right, top, bottom = pad
        height, width = flow.shape[-2:]
        return flow[..., top:height - bottom, left:width - right]

    def forward(self, x0, x1, return_magnitude=False):
        """Run RAFT (20 refinement iterations, test mode) on the image pair ``x0``, ``x1``.

        Args:
            x0: image batch [B, C, H, W]. It may be unit-range, 0-255 or
                ImageNet-normalized (see ``prepare_inputs``). H and W need not
                be multiples of 8.
            x1: second image batch, same shape as ``x0``.
            return_magnitude: if true, also return the per-pixel L2 norm of
                the flow over the channel dim.

        Returns:
            ``flow_up`` cropped to the input's H, W (presumably [B, 2, H, W]).
            If ``return_magnitude`` is true, returns ``(flow_up, magnitude)``
            with ``magnitude`` of shape [B, H, W].
        """
        # Everything here is inference. Preparing the inputs under no_grad
        # avoids building an autograd graph for them.
        with torch.no_grad():
            x0 = self.prepare_inputs(x0)
            x1 = self.prepare_inputs(x1)
            x0, pad = self._pad_to_multiple(x0)
            x1, _ = self._pad_to_multiple(x1)
            _, flow_up = self.model(x0, x1, iters=20, test_mode=True)
            flow_up = self._unpad(flow_up, pad)

        if return_magnitude:
            flow_magnitude = flow_up.norm(p=2, dim=1)  # [B, H, W]
            return flow_up, flow_magnitude

        return flow_up

    def viz(self, flow):
        """Render the first flow field in batch ``flow`` with ``flow_viz.flow_to_image``.

        ``flow[0]`` is permuted from channel-first to channel-last, detached,
        moved to the CPU and converted to numpy before rendering. The return
        value is whatever ``flow_to_image`` produces (presumably an HxWx3 RGB
        array; confirm against ``flow_viz``).
        """
        flow_rgb = flow_viz.flow_to_image(flow[0].permute(1, 2, 0).detach().cpu().numpy())
        return flow_rgb