# Page-header residue from the web file viewer (not part of the module):
#   rahulvenkk | app.py updated | 6dfcb0f | raw | history blame | 1.37 kB
import torch
from torch import nn
import cwm.eval.Segmentation.utils as utils
from external.raft_interface import RAFTInterface
class SegmentExtractor(nn.Module):
    """Probe a CWM model with counterfactual motion and estimate the induced flow.

    The visible implementation samples one "moving" and one "static" location
    from a sampling distribution, asks the model for a counterfactual
    prediction under a random displacement, and runs a RAFT flow interface on
    the result. The iterative refinement (step 3) is not implemented yet.
    """

    def __init__(self, num_segments=1, iters=4, motion_range=4):
        """Store the hyper-parameters and build the flow interface.

        Args:
            num_segments: stored on the instance; not read by the code
                visible here.
            iters: stored on the instance; not read by the code visible here.
            motion_range: bound used by ``forward`` when drawing the random
                integer displacement.
        """
        # nn.Module must be initialised before attributes/sub-modules are
        # assigned; without this the instance cannot register sub-modules
        # or be called.
        super().__init__()
        self.num_segments = num_segments
        self.iters = iters
        self.motion_range = motion_range
        self.flow_interface = RAFTInterface()

    def get_sampling_dist(self, x, model):
        """Return the sampling distribution for ``x``.

        Placeholder: not implemented yet, so this currently returns ``None``.
        """
        return None

    def forward(self, x, model, sampling_dist=None, mask=None):
        """
        x: [B, 3, H, W] a batch of imagenet-normalized image tensor
        model: a pre-trained CWM model
        sampling_dist: optional distribution used to sample the moving
            location; ``1 - sampling_dist`` is used for the static location.
            When omitted, ``get_sampling_dist(x, model)`` is used.
        mask: forwarded unchanged to ``model.get_counterfactual``. The
            previous code referenced an undefined name here.
            NOTE(review): confirm the expected mask format and whether
            ``None`` is acceptable to the model.

        Returns None for now: step 3 is not implemented and the computed
        flow is not yet returned.

        Raises:
            ValueError: if no sampling distribution is available.

        NOTE(review): the body indexes ``x[:, :, 0]``, which on a 4-D
        [B, 3, H, W] tensor selects the first image row; confirm whether x
        is actually expected to carry an extra (e.g. frame) dimension.
        """
        # Use an identity check: truth-testing a multi-element tensor raises.
        if sampling_dist is None:
            sampling_dist = self.get_sampling_dist(x, model)
        if sampling_dist is None:
            raise ValueError(
                "sampling_dist was not provided and get_sampling_dist() "
                "returned None"
            )

        batch_size = x.shape[0]

        ## Step 1: sample initial moving and static locations from the distribution
        moving_pos = utils.sample_positions_from_dist(num=1, dist=sampling_dist)  # [B, num, 2]
        static_pos = utils.sample_positions_from_dist(num=1, dist=(1 - sampling_dist))  # [B, num, 2]
        # torch.randint's upper bound is exclusive, so each displacement lies
        # in [-motion_range, motion_range - 1].
        # NOTE(review): confirm the asymmetric range is intended.
        movement = torch.randint(
            -self.motion_range, self.motion_range, (batch_size, 1, 2)
        )  # [B, 1, 2]

        ## Step 2: compute initial flow maps
        pred = model.get_counterfactual(
            x,
            mask,
            moving_pos=moving_pos,
            static_pos=static_pos,
            movement=movement,
        )
        flow = self.flow_interface(x[:, :, 0], pred)

        ## Step 3: iterate to add more moving and static motions
        # TODO: not implemented yet; `flow` is computed but not returned.