import torch
from torch import nn

import cwm.eval.Segmentation.utils as utils
from external.raft_interface import RAFTInterface


class SegmentExtractor(nn.Module):
    """Extract segments from a pre-trained CWM model via counterfactual motion.

    The extractor samples "moving" and "static" positions, asks the CWM model
    for a counterfactual prediction in which the moving position is displaced,
    and estimates the resulting optical flow with RAFT.

    NOTE(review): the iterative refinement (Step 3 in ``forward``) is not
    implemented yet; ``num_segments`` and ``iters`` are stored but unused.
    """

    def __init__(self, num_segments=1, iters=4, motion_range=4):
        """
        Args:
            num_segments: number of segments to extract (currently unused).
            iters: number of refinement iterations (currently unused).
            motion_range: maximum absolute per-axis displacement of the
                randomly sampled counterfactual movement.
        """
        # nn.Module.__init__ must run before submodules are assigned and before
        # the module is called; without it, assigning ``flow_interface`` or
        # invoking ``self(...)`` raises AttributeError.
        super().__init__()
        self.num_segments = num_segments
        self.iters = iters
        self.motion_range = motion_range
        self.flow_interface = RAFTInterface()

    def get_sampling_dist(self, x, model):
        """Return the distribution from which moving positions are sampled.

        Not implemented yet. Raising here surfaces the problem immediately
        instead of failing later on ``1 - None`` inside ``forward``.
        """
        raise NotImplementedError(
            "SegmentExtractor.get_sampling_dist is not implemented; "
            "pass `sampling_dist` to forward() explicitly."
        )

    def forward(self, x, model, sampling_dist=None, mask=None):
        """Sample a counterfactual motion and compute the resulting flow.

        Args:
            x: a batch of imagenet-normalized image tensors. Originally
                documented as [B, 3, H, W], but ``x[:, :, 0]`` below indexes a
                third axis that looks like a frame axis, e.g. [B, 3, T, H, W]
                -- TODO confirm.
            model: a pre-trained CWM model exposing ``get_counterfactual``.
            sampling_dist: distribution over positions used to sample moving
                points; ``1 - sampling_dist`` is used for static points.
                Computed with ``get_sampling_dist`` when omitted.
            mask: forwarded unchanged as the second positional argument of
                ``model.get_counterfactual``. The original code referenced an
                undefined ``mask`` here -- TODO confirm the expected value.

        Returns:
            The flow produced by ``self.flow_interface`` from ``x[:, :, 0]``
            and the counterfactual prediction.
        """
        # Use an explicit None check: truth-testing a multi-element tensor
        # raises "Boolean value of Tensor ... is ambiguous".
        if sampling_dist is None:
            sampling_dist = self.get_sampling_dist(x, model)

        batch_size = x.shape[0]

        ## Step 1: sample initial moving and static locations from the distribution
        moving_pos = utils.sample_positions_from_dist(num=1, dist=sampling_dist)  # [B, num, 2]
        static_pos = utils.sample_positions_from_dist(num=1, dist=(1 - sampling_dist))  # [B, num, 2]
        # torch.randint's upper bound is exclusive; ``+ 1`` makes the range
        # symmetric so each component lies in [-motion_range, motion_range].
        # Created on x's device so it can be combined with the model inputs.
        movement = torch.randint(
            -self.motion_range,
            self.motion_range + 1,
            (batch_size, 1, 2),
            device=x.device,
        )  # [B, 1, 2]

        ## Step 2: compute initial flow maps
        pred = model.get_counterfactual(
            x, mask, moving_pos=moving_pos, static_pos=static_pos, movement=movement
        )
        flow = self.flow_interface(x[:, :, 0], pred)

        ## Step 3: iterate to add more moving and static motions
        # TODO: not implemented; ``self.iters`` / ``self.num_segments`` are
        # meant to drive this refinement loop.

        return flow