# GenSim/cliport/models/core/transport.py
import numpy as np
import cliport.models as models
from cliport.utils import utils
import torch
import torch.nn as nn
import torch.nn.functional as F
class Transport(nn.Module):
def __init__(self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device):
"""Transport (a.k.a Place) module."""
super().__init__()
self.iters = 0
self.stream_fcn = stream_fcn
self.n_rotations = n_rotations
self.crop_size = crop_size # crop size must be N*16 (e.g. 96)
self.preprocess = preprocess
self.cfg = cfg
self.device = device
self.batchnorm = self.cfg['train']['batchnorm']
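        # Pad the two spatial dims by half the crop size so that a
        # crop_size x crop_size window centered on the pick point stays in bounds.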
self.pad_size = int(self.crop_size / 2)
self.padding = np.zeros((3, 2), dtype=int)
self.padding[:2, :] = self.pad_size
in_shape = np.array(in_shape)
in_shape = tuple(in_shape)
self.in_shape = in_shape
# Crop before network (default from Transporters CoRL 2020).
self.kernel_shape = (self.crop_size, self.crop_size, self.in_shape[2])
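        # Subclasses may set output_dim / kernel_dim before calling
        # super().__init__(); otherwise default to 3 feature channels.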
if not hasattr(self, 'output_dim'):
self.output_dim = 3
if not hasattr(self, 'kernel_dim'):
self.kernel_dim = 3
self.rotator = utils.ImageRotator(self.n_rotations)
self._build_nets()
def _build_nets(self):
stream_one_fcn, _ = self.stream_fcn
model = models.names[stream_one_fcn]
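        # The key stream embeds the full (padded) scene; the query stream embeds
        # the rotated crops whose features act as correlation kernels in correlate().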
self.key_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device)
self.query_resnet = model(self.kernel_shape, self.kernel_dim, self.cfg, self.device)
print(f"Transport FCN: {stream_one_fcn}")
    def correlate(self, in0, in1, softmax):
        """Cross-correlate scene features (in0) with rotated crop kernels (in1)."""
        # Use the crop embeddings as conv kernels over the scene embedding,
        # then resize back to the scene resolution and strip the padding.
        output = F.conv2d(in0, in1, padding=(self.pad_size, self.pad_size))
        output = F.interpolate(output, size=(in0.shape[-2], in0.shape[-1]), mode='bilinear')
        output = output[:, :, self.pad_size:-self.pad_size, self.pad_size:-self.pad_size]
output_shape = output.shape
        # Hack around batch size 1: conv2d correlates every kernel with every
        # scene in the batch, so select the i-th block of channel_num output
        # channels for the i-th batch element.
        channel_num = in1.shape[0] // in0.shape[0]
        output = torch.stack([output[i, i * channel_num:(i + 1) * channel_num] for i in range(len(output))], dim=0)
if softmax:
output = output.reshape((len(output), -1))
output = F.softmax(output, dim=-1)
            output = output.reshape(len(output), channel_num, output_shape[2], output_shape[3])
return output
def transport(self, in_tensor, crop):
logits = self.key_resnet(in_tensor)
kernel = self.query_resnet(crop)
return logits, kernel
def forward(self, inp_img, p, softmax=True):
"""Forward pass."""
img_unprocessed = np.pad(inp_img, self.padding, mode='constant')
input_data = img_unprocessed
in_shape = input_data.shape
        if len(in_shape) == 3:
            in_shape = (1,) + in_shape
input_data = input_data.reshape(in_shape) # [B W H D]
in_tensor = torch.from_numpy(input_data).to(dtype=torch.float, device=self.device)
# Rotation pivot.
pv = p + self.pad_size # np.array([p[0], p[1]])
# Crop before network (default from Transporters CoRL 2020).
hcrop = self.pad_size
in_tensor = in_tensor.permute(0, 3, 1, 2) # [B D W H]
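        # Tile the scene once per rotation, rotate each copy about the pick
        # point, then cut out a crop_size x crop_size window around it.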
crop = in_tensor.repeat(self.n_rotations, 1, 1, 1)
crop = self.rotator(crop, pivot=pv)
crop = torch.cat(crop, dim=0)
crop = crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop]
logits, kernel = self.transport(in_tensor, crop)
# TODO(Mohit): Crop after network. Broken for now.
# in_tensor = in_tensor.permute(0, 3, 1, 2)
# logits, crop = self.transport(in_tensor)
# crop = crop.repeat(self.n_rotations, 1, 1, 1)
# crop = self.rotator(crop, pivot=pv)
# crop = torch.cat(crop, dim=0)
# kernel = crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop]
# kernel = crop[:, :, p[0]:(p[0] + self.crop_size), p[1]:(p[1] + self.crop_size)]
return self.correlate(logits, kernel, softmax)
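

if __name__ == "__main__":
    # Minimal sketch (not part of the original module): exercises only the
    # correlate() idea on dummy tensors -- per-rotation crop features used as
    # conv kernels over scene features -- without building the ResNet streams.
    # The sizes below are illustrative assumptions, not the training defaults.
    n_rotations, channels, crop_size = 4, 3, 32
    pad = crop_size // 2
    scene_feat = torch.randn(1, channels, 96, 96)  # padded scene features [B C H W]
    crop_kernels = torch.randn(n_rotations, channels, crop_size, crop_size)
    heatmaps = F.conv2d(scene_feat, crop_kernels, padding=(pad, pad))
    print(heatmaps.shape)  # [1, n_rotations, H', W']: one placement map per rotation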