# GenSim / cliport/agents/transporter.py
# Author: LeroyWaa -- commit 8fc2b4e ("add gensim code"), 19.5 kB
import os
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from cliport.tasks import cameras
from cliport.utils import utils
from cliport.models.core.attention import Attention
from cliport.models.core.transport import Transport
from cliport.models.streams.two_stream_attention import TwoStreamAttention
from cliport.models.streams.two_stream_transport import TwoStreamTransport
from cliport.models.streams.two_stream_attention import TwoStreamAttentionLat
from cliport.models.streams.two_stream_transport import TwoStreamTransportLat
import time
import IPython
class TransporterAgent(LightningModule):
    """Base Transporter agent: an attention net picks, a transport net places.

    Subclasses implement ``_build_model`` and must assign ``self.attention``
    and ``self.transport``. Optimization is manual
    (``automatic_optimization = False``): each sub-network has its own Adam
    optimizer in ``self._optimizers``, stepped directly by ``attn_criterion``
    and ``transport_criterion``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__()
        utils.set_seed(0)
        self.automatic_optimization = False
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # this is bad for PL :(
        self.name = name
        self.cfg = cfg
        # train_ds / test_ds are returned as-is by train_dataloader() /
        # val_dataloader(); their ``.dataset`` attribute is the underlying
        # dataset (act() uses test_ds.get_image).
        self.train_loader = train_ds
        self.test_loader = test_ds
        self.train_ds = train_ds.dataset
        self.test_ds = test_ds.dataset
        self.task = cfg['train']['task']
        self.total_steps = 0
        self.crop_size = 64
        self.n_rotations = cfg['train']['n_rotations']
        self.pix_size = 0.003125
        self.in_shape = (320, 160, 6)
        self.cam_config = cameras.RealSenseD415.CONFIG
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])
        self.val_repeats = cfg['train']['val_repeats']
        self.save_steps = cfg['train']['save_steps']
        self._build_model()

        # TODO: reduce the number of parameters here.
        # One optimizer per sub-network; both are stepped manually.
        self._optimizers = {
            'attn': torch.optim.Adam(self.attention.parameters(), lr=self.cfg['train']['lr']),
            'trans': torch.optim.Adam(self.transport.parameters(), lr=self.cfg['train']['lr']),
        }
        print("Agent: {}, Logging: {}".format(name, cfg['train']['log']))

    def configure_optimizers(self):
        """Hand no optimizers to Lightning.

        The optimizers are created in ``__init__`` (``self._optimizers``) and
        stepped manually inside ``attn_criterion`` / ``transport_criterion``.
        """
        return None

    def _build_model(self):
        """Create ``self.attention`` and ``self.transport``; subclasses override."""
        self.attention = None
        self.transport = None
        raise NotImplementedError()

    def forward(self, x):
        raise NotImplementedError()

    def cross_entropy_with_logits(self, pred, labels, reduction='mean'):
        """Cross-entropy between logits ``pred`` and target map ``labels``.

        Both tensors are flattened to ``(batch, -1)`` and the log-softmax is
        taken over all non-batch positions.

        Args:
            pred: logits with the same number of elements per sample as labels.
            labels: target distribution (one-hot in this class's criteria).
            reduction: ``'sum'`` or ``'mean'`` (mean over every flattened element).

        Raises:
            NotImplementedError: for any other ``reduction``.
        """
        # Lucas found that both sum and mean work equally well
        x = (-labels.view(len(labels), -1) * F.log_softmax(pred.view(len(labels), -1), -1))
        if reduction == 'sum':
            return x.sum()
        elif reduction == 'mean':
            return x.mean()
        else:
            raise NotImplementedError()

    def attn_forward(self, inp, softmax=True):
        """Run the attention (pick) network on ``inp['inp_img']``."""
        inp_img = inp['inp_img']
        output = self.attention.forward(inp_img, softmax=softmax)
        return output

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the pick loss for a batch ``frame`` (keys 'img', 'p0', 'p0_theta')."""
        inp_img = frame['img']
        p0, p0_theta = frame['p0'], frame['p0_theta']
        inp = {'inp_img': inp_img}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def attn_criterion(self, backprop, compute_err, inp, out, p, theta):
        """Pick loss, optional manual update step, and optional error metrics.

        Args:
            backprop: if True, run one manual step of the 'attn' optimizer.
            compute_err: if True, re-run attention with softmax and measure the
                argmax error on the first batch element only.
            inp: dict with 'inp_img'; its first three dims size the label map.
            out: attention logits, compared against a label permuted to
                ``(batch, n_rotations, dim1, dim2)``.
            p: per-sample pixel indices into label dims 1 and 2.
            theta: per-sample angles in radians (tensor or ndarray).

        Returns:
            ``(loss, err)``; ``err`` is ``{}`` unless ``compute_err`` is set.
        """
        # Get label: discretize each angle into one of n_rotations bins.
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()
        theta_i = theta / (2 * np.pi / self.attention.n_rotations)
        theta_i = np.int32(np.round(theta_i)) % self.attention.n_rotations
        inp_img = inp['inp_img'].float()

        label_size = inp_img.shape[:3] + (self.attention.n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=out.device)

        # TODO: vectorize this loop.
        for idx, p_i in enumerate(p):
            label[idx, int(p_i[0]), int(p_i[1]), theta_i[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        # Get loss.
        loss = self.cross_entropy_with_logits(out, label)

        # Backpropagate. Gradients are cleared *before* backward (same order as
        # transport_criterion) so nothing left over from another backward pass
        # is folded into this update.
        if backprop:
            attn_optim = self._optimizers['attn']
            attn_optim.zero_grad()
            self.manual_backward(loss)
            attn_optim.step()

        # Pixel and Rotation error (not used anywhere).
        err = {}
        if compute_err:
            with torch.no_grad():
                pick_conf = self.attn_forward(inp)
                pick_conf = pick_conf[0].permute(1, 2, 0)
                pick_conf = pick_conf.detach().cpu().numpy()
                p = p[0]
                theta = theta[0]

                # single batch
                argmax = np.argmax(pick_conf)
                argmax = np.unravel_index(argmax, shape=pick_conf.shape)
                p0_pix = argmax[:2]
                p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

                err = {
                    'dist': np.linalg.norm(np.array(p.detach().cpu().numpy()) - p0_pix, ord=1),
                    'theta': np.absolute((theta - p0_theta) % np.pi)
                }
        return loss, err

    def trans_forward(self, inp, softmax=True):
        """Run the transport (place) network on ``inp['inp_img']`` around ``inp['p0']``."""
        inp_img = inp['inp_img']
        p0 = inp['p0']

        output = self.transport.forward(inp_img, p0, softmax=softmax)
        return output

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the place loss for a batch ``frame`` (keys 'img', 'p0', 'p1', 'p1_theta')."""
        inp_img = frame['img'].float()
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']

        inp = {'inp_img': inp_img, 'p0': p0}
        output = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, output, p0, p1, p1_theta)
        return loss, err

    def transport_criterion(self, backprop, compute_err, inp, output, p, q, theta):
        """Place loss, optional manual update step, and optional error metrics.

        Args:
            backprop: if True, run one manual step of the 'trans' optimizer.
            compute_err: if True, re-run transport with softmax and measure the
                argmax error on the first batch element only.
            inp: dict with 'inp_img' (and 'p0' for the re-run).
            output: transport logits.
            p: pick pixels (unused here; kept for interface compatibility).
            q: per-sample target pixel indices (tensor); clamped to the label
                map, without modifying the caller's tensor.
            theta: per-sample angles in radians (tensor or ndarray).

        Returns:
            ``(err, loss)`` -- note the order differs from ``attn_criterion``.
        """
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()
        itheta = theta / (2 * np.pi / self.transport.n_rotations)
        itheta = np.int32(np.round(itheta)) % self.transport.n_rotations

        # Get one-hot pixel label map.
        inp_img = inp['inp_img']
        label_size = inp_img.shape[:3] + (self.transport.n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=output.device)

        # Clamp target pixels into the label map. Work on a copy so the caller's
        # tensor (frame['p1'] from the batch) is not modified in place.
        q = q.clone()
        q[:, 0] = torch.clamp(q[:, 0], 0, label.shape[1] - 1)
        q[:, 1] = torch.clamp(q[:, 1], 0, label.shape[2] - 1)
        # TODO: vectorize this loop.
        for idx, q_i in enumerate(q):
            label[idx, int(q_i[0]), int(q_i[1]), itheta[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        # Get loss.
        loss = self.cross_entropy_with_logits(output, label)

        if backprop:
            transport_optim = self._optimizers['trans']
            transport_optim.zero_grad()
            self.manual_backward(loss)
            transport_optim.step()

        # Pixel and Rotation error (not used anywhere).
        err = {}
        if compute_err:
            with torch.no_grad():
                place_conf = self.trans_forward(inp)
                # pick the first batch
                place_conf = place_conf[0]
                q = q[0]
                theta = theta[0]
                place_conf = place_conf.permute(1, 2, 0)
                place_conf = place_conf.detach().cpu().numpy()
                argmax = np.argmax(place_conf)
                argmax = np.unravel_index(argmax, shape=place_conf.shape)
                p1_pix = argmax[:2]
                p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

                err = {
                    'dist': np.linalg.norm(np.array(q.detach().cpu().numpy()) - p1_pix, ord=1),
                    'theta': np.absolute((theta - p1_theta) % np.pi)
                }
        self.transport.iters += 1
        return err, loss

    def training_step(self, batch, batch_idx):
        """Run one manual update of both networks and log the losses."""
        self.attention.train()
        self.transport.train()

        frame, _ = batch
        self.start_time = time.time()

        # Get training losses.
        step = self.total_steps + 1
        loss0, _ = self.attn_training_step(frame)

        self.start_time = time.time()

        # NOTE(review): when self.transport is an Attention this trains
        # self.attention a second time instead of self.transport -- confirm intent.
        if isinstance(self.transport, Attention):
            loss1, _ = self.attn_training_step(frame)
        else:
            loss1, _ = self.transport_training_step(frame)

        total_loss = loss0 + loss1
        self.total_steps = step
        self.start_time = time.time()
        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def check_save_iteration(self):
        """Save ``last.ckpt`` when ``total_steps + 1`` is a multiple of 100.

        Note: the interval is hard-coded; ``self.save_steps`` is not consulted.
        """
        global_step = self.total_steps
        if (global_step + 1) % 100 == 0:
            # save lastest checkpoint
            print(f"Saving last.ckpt Epoch: {self.trainer.current_epoch} | Global Step: {self.trainer.global_step}")
            self.save_last_checkpoint()

    def save_last_checkpoint(self):
        """Write the trainer checkpoint to ``<train_dir>/checkpoints/last.ckpt``."""
        checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
        ckpt_path = os.path.join(checkpoint_path, 'last.ckpt')
        self.trainer.save_checkpoint(ckpt_path)

    def validation_step(self, batch, batch_idx):
        """Average both losses over ``val_repeats`` passes (no updates) and report errors.

        The error entries come from the last repeat only.
        """
        self.attention.eval()
        self.transport.eval()

        loss0, loss1 = 0, 0
        assert self.val_repeats >= 1
        for i in range(self.val_repeats):
            frame, _ = batch
            l0, err0 = self.attn_training_step(frame, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.transport, Attention):
                l1, err1 = self.attn_training_step(frame, backprop=False, compute_err=True)
                loss1 += l1
            else:
                l1, err1 = self.transport_training_step(frame, backprop=False, compute_err=True)
                loss1 += l1
        loss0 /= self.val_repeats
        loss1 /= self.val_repeats
        val_total_loss = loss0 + loss1

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def training_epoch_end(self, all_outputs):
        """Re-seed with the next epoch index after each training epoch."""
        super().training_epoch_end(all_outputs)
        utils.set_seed(self.trainer.current_epoch + 1)

    def validation_epoch_end(self, all_outputs):
        """Aggregate per-batch validation outputs: mean losses, summed errors."""
        mean_val_total_loss = np.mean([v['val_loss'].item() for v in all_outputs])
        mean_val_loss0 = np.mean([v['val_loss0'].item() for v in all_outputs])
        mean_val_loss1 = np.mean([v['val_loss1'].item() for v in all_outputs])
        total_attn_dist_err = np.sum([v['val_attn_dist_err'].sum() for v in all_outputs])
        total_attn_theta_err = np.sum([v['val_attn_theta_err'].sum() for v in all_outputs])
        total_trans_dist_err = np.sum([v['val_trans_dist_err'].sum() for v in all_outputs])
        total_trans_theta_err = np.sum([v['val_trans_theta_err'].sum() for v in all_outputs])

        self.log('vl/attn/loss', mean_val_loss0)
        self.log('vl/trans/loss', mean_val_loss1)
        self.log('vl/loss', mean_val_total_loss)
        self.log('vl/total_attn_dist_err', total_attn_dist_err)
        self.log('vl/total_attn_theta_err', total_attn_theta_err)
        self.log('vl/total_trans_dist_err', total_trans_dist_err)
        self.log('vl/total_trans_theta_err', total_trans_theta_err)

        print("\nAttn Err - Dist: {:.2f}, Theta: {:.2f}".format(total_attn_dist_err, total_attn_theta_err))
        print("Transport Err - Dist: {:.2f}, Theta: {:.2f}".format(total_trans_dist_err, total_trans_theta_err))

        return dict(
            val_loss=mean_val_total_loss,
            val_loss0=mean_val_loss0,
            mean_val_loss1=mean_val_loss1,
            total_attn_dist_err=total_attn_dist_err,
            total_attn_theta_err=total_attn_theta_err,
            total_trans_dist_err=total_trans_dist_err,
            total_trans_theta_err=total_trans_theta_err,
        )

    def act(self, obs, info=None, goal=None):  # pylint: disable=unused-argument
        """Run inference and return best action given visual observations.

        The pick pixel/angle is the argmax of the attention output; the place
        pixel/angle is the argmax of the transport output conditioned on that
        pick pixel. Both are converted to poses using image channel 3 as the
        heightmap. ``info`` and ``goal`` are ignored.

        Returns:
            dict with 'pose0'/'pose1' as ``(xyz, quaternion xyzw)`` arrays and
            'pick'/'place' pixel coordinates.
        """
        # Get heightmap from RGB-D images.
        img = self.test_ds.get_image(obs)

        # Attention model forward pass.
        # NOTE(review): unlike place_conf below, pick_conf is not permuted;
        # this assumes attn_forward returns (H, W, rotations) for a single
        # unbatched image -- confirm against Attention.forward.
        pick_inp = {'inp_img': img}
        pick_conf = self.attn_forward(pick_inp)
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Transport model forward pass.
        place_inp = {'inp_img': img, 'p0': p0_pix}
        place_conf = self.trans_forward(place_inp)
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels to end effector poses.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
        """No-op: optimizer steps are taken manually inside the criteria."""
        pass

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.test_loader

    def load(self, model_path):
        """Load a Lightning checkpoint's ``state_dict`` and move to ``self.device_type``.

        ``map_location`` lets checkpoints saved on GPU load on a CPU-only host.
        """
        checkpoint = torch.load(model_path, map_location=self.device_type)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(device=self.device_type)
class OriginalTransporterAgent(TransporterAgent):
    """Transporter agent whose pick and place networks use the 'plain_resnet' stream."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        backbone = 'plain_resnet'
        # Constructor arguments common to the attention (pick) and
        # transport (place) networks.
        shared = dict(
            stream_fcn=(backbone, None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class ClipUNetTransporterAgent(TransporterAgent):
    """Transporter agent whose pick and place networks use the 'clip_unet' stream."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        backbone = 'clip_unet'
        # Constructor arguments common to the attention (pick) and
        # transport (place) networks.
        shared = dict(
            stream_fcn=(backbone, None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamClipUNetTransporterAgent(TransporterAgent):
    """Two-stream agent combining the 'plain_resnet' and 'clip_unet' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        streams = ('plain_resnet', 'clip_unet')
        # Constructor arguments common to the attention (pick) and
        # transport (place) networks.
        shared = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):
    """Lateral two-stream agent combining 'plain_resnet_lat' and 'clip_unet_lat'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        streams = ('plain_resnet_lat', 'clip_unet_lat')
        # Constructor arguments common to the attention (pick) and
        # transport (place) networks.
        shared = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):
    """Two-stream agent combining the 'plain_resnet' and 'clip_woskip' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        # TODO: lateral version
        streams = ('plain_resnet', 'clip_woskip')
        # Constructor arguments common to the attention (pick) and
        # transport (place) networks.
        shared = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):
    """Two-stream agent combining the 'plain_resnet' and 'rn50_bert_unet' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        # TODO: lateral version
        streams = ('plain_resnet', 'rn50_bert_unet')
        # Constructor arguments common to the attention (pick) and
        # transport (place) networks.
        shared = dict(
            stream_fcn=streams,
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )