import os

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule

from cliport.tasks import cameras
from cliport.utils import utils
from cliport.models.core.attention import Attention
from cliport.models.core.transport import Transport
from cliport.models.streams.two_stream_attention import TwoStreamAttention, TwoStreamAttentionLat
from cliport.models.streams.two_stream_transport import TwoStreamTransport, TwoStreamTransportLat


class TransporterAgent(LightningModule):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__()
        utils.set_seed(0)

        self.automatic_optimization = False
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # this is bad for PL :(
        self.name = name
        self.cfg = cfg
        self.train_loader = train_ds
        self.test_loader = test_ds
        self.train_ds = train_ds.dataset
        self.test_ds = test_ds.dataset

        self.task = cfg['train']['task']
        self.total_steps = 0
        self.crop_size = 64
        self.n_rotations = cfg['train']['n_rotations']

        self.pix_size = 0.003125
        self.in_shape = (320, 160, 6)
        self.cam_config = cameras.RealSenseD415.CONFIG
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])

        self.val_repeats = cfg['train']['val_repeats']
        self.save_steps = cfg['train']['save_steps']

        self._build_model()
        # TODO: reduce the number of parameters here.
        self._optimizers = {
            'attn': torch.optim.Adam(self.attention.parameters(), lr=self.cfg['train']['lr']),
            'trans': torch.optim.Adam(self.transport.parameters(), lr=self.cfg['train']['lr'])
        }
        print("Agent: {}, Logging: {}".format(name, cfg['train']['log']))

    def _build_model(self):
        self.attention = None
        self.transport = None
        raise NotImplementedError()

    def forward(self, x):
        raise NotImplementedError()

    def cross_entropy_with_logits(self, pred, labels, reduction='mean'):
        # Lucas found that both sum and mean work equally well.
        x = (-labels.view(len(labels), -1) * F.log_softmax(pred.view(len(labels), -1), -1))
        if reduction == 'sum':
            return x.sum()
        elif reduction == 'mean':
            return x.mean()
        else:
            raise NotImplementedError()

    def attn_forward(self, inp, softmax=True):
        inp_img = inp['inp_img']
        output = self.attention.forward(inp_img, softmax=softmax)
        return output

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        inp_img = frame['img']
        p0, p0_theta = frame['p0'], frame['p0_theta']

        inp = {'inp_img': inp_img}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def attn_criterion(self, backprop, compute_err, inp, out, p, theta):
        # Get label: discretize theta into a rotation bin index.
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()
        theta_i = theta / (2 * np.pi / self.attention.n_rotations)
        theta_i = np.int32(np.round(theta_i)) % self.attention.n_rotations

        # One-hot pixel label map.
        inp_img = inp['inp_img'].float()
        label_size = inp_img.shape[:3] + (self.attention.n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=out.device)
        # TODO: remove this for-loop later.
        for idx, p_i in enumerate(p):
            label[idx, int(p_i[0]), int(p_i[1]), theta_i[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        # Get loss.
        loss = self.cross_entropy_with_logits(out, label)

        # Backpropagate.
        if backprop:
            attn_optim = self._optimizers['attn']
            self.manual_backward(loss)
            attn_optim.step()
            attn_optim.zero_grad()

        # Pixel and rotation error (not used anywhere).
        err = {}
        if compute_err:
            with torch.no_grad():
                pick_conf = self.attn_forward(inp)
                pick_conf = pick_conf[0].permute(1, 2, 0)
                pick_conf = pick_conf.detach().cpu().numpy()

                # Evaluate only the first datapoint in the batch.
                p = p[0]
                theta = theta[0]

                argmax = np.argmax(pick_conf)
                argmax = np.unravel_index(argmax, shape=pick_conf.shape)
                p0_pix = argmax[:2]
                p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

            err = {
                'dist': np.linalg.norm(np.array(p.detach().cpu().numpy()) - p0_pix, ord=1),
                'theta': np.absolute((theta - p0_theta) % np.pi)
            }
        return loss, err
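
    # Worked example of the rotation discretization above (illustrative numbers
    # only, not from the original code): with n_rotations = 36, the bin width is
    # 2 * pi / 36 ~= 0.1745 rad, so theta = 0.5 rad maps to bin
    # round(0.5 / 0.1745) % 36 = 3, and the one-hot label is written into
    # rotation channel 3 at the picking pixel. The attention stream here uses
    # n_rotations = 1, so every theta collapses to bin 0.
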
    def trans_forward(self, inp, softmax=True):
        inp_img = inp['inp_img']
        p0 = inp['p0']
        output = self.transport.forward(inp_img, p0, softmax=softmax)
        return output

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        inp_img = frame['img'].float()
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']

        inp = {'inp_img': inp_img, 'p0': p0}
        output = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, output, p0, p1, p1_theta)
        return loss, err

    def transport_criterion(self, backprop, compute_err, inp, output, p, q, theta):
        # Discretize theta into a rotation bin index.
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()
        itheta = theta / (2 * np.pi / self.transport.n_rotations)
        itheta = np.int32(np.round(itheta)) % self.transport.n_rotations

        # Get one-hot pixel label map.
        inp_img = inp['inp_img']
        label_size = inp_img.shape[:3] + (self.transport.n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=output.device)

        # Clamp place pixels to the image bounds before writing the labels.
        # TODO: remove this for-loop later.
        q[:, 0] = torch.clamp(q[:, 0], 0, label.shape[1] - 1)
        q[:, 1] = torch.clamp(q[:, 1], 0, label.shape[2] - 1)
        for idx, q_i in enumerate(q):
            label[idx, int(q_i[0]), int(q_i[1]), itheta[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        # Get loss.
        loss = self.cross_entropy_with_logits(output, label)

        if backprop:
            transport_optim = self._optimizers['trans']
            transport_optim.zero_grad()
            self.manual_backward(loss)
            transport_optim.step()

        # Pixel and rotation error (not used anywhere).
        err = {}
        if compute_err:
            with torch.no_grad():
                place_conf = self.trans_forward(inp)

                # Evaluate only the first datapoint in the batch.
                place_conf = place_conf[0]
                q = q[0]
                theta = theta[0]

                place_conf = place_conf.permute(1, 2, 0)
                place_conf = place_conf.detach().cpu().numpy()
                argmax = np.argmax(place_conf)
                argmax = np.unravel_index(argmax, shape=place_conf.shape)
                p1_pix = argmax[:2]
                p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

            err = {
                'dist': np.linalg.norm(np.array(q.detach().cpu().numpy()) - p1_pix, ord=1),
                'theta': np.absolute((theta - p1_theta) % np.pi)
            }
        self.transport.iters += 1
        return err, loss
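
    # Shape summary for the two criterion functions above (sizes are
    # illustrative, assuming the default (320, 160) heightmap): for a batch of
    # B images and R rotations, the dense label is built as (B, 320, 160, R),
    # permuted to (B, R, 320, 160) to match the model output, and both tensors
    # are flattened to (B, R * 320 * 160) inside cross_entropy_with_logits
    # before the softmax cross-entropy is computed.
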
    def training_step(self, batch, batch_idx):
        self.attention.train()
        self.transport.train()
        frame, _ = batch

        # Get training losses.
        step = self.total_steps + 1
        loss0, err0 = self.attn_training_step(frame)
        if isinstance(self.transport, Attention):
            loss1, err1 = self.attn_training_step(frame)
        else:
            loss1, err1 = self.transport_training_step(frame)
        total_loss = loss0 + loss1
        self.total_steps = step

        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def check_save_iteration(self):
        global_step = self.total_steps
        if (global_step + 1) % 100 == 0:
            # Save the latest checkpoint.
            print(f"Saving last.ckpt Epoch: {self.trainer.current_epoch} | Global Step: {self.trainer.global_step}")
            self.save_last_checkpoint()

    def save_last_checkpoint(self):
        checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
        ckpt_path = os.path.join(checkpoint_path, 'last.ckpt')
        self.trainer.save_checkpoint(ckpt_path)

    def validation_step(self, batch, batch_idx):
        self.attention.eval()
        self.transport.eval()

        loss0, loss1 = 0, 0
        assert self.val_repeats >= 1
        for _ in range(self.val_repeats):
            frame, _ = batch
            l0, err0 = self.attn_training_step(frame, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.transport, Attention):
                l1, err1 = self.attn_training_step(frame, backprop=False, compute_err=True)
            else:
                l1, err1 = self.transport_training_step(frame, backprop=False, compute_err=True)
            loss1 += l1
        loss0 /= self.val_repeats
        loss1 /= self.val_repeats
        val_total_loss = loss0 + loss1

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def training_epoch_end(self, all_outputs):
        super().training_epoch_end(all_outputs)
        utils.set_seed(self.trainer.current_epoch + 1)

    def validation_epoch_end(self, all_outputs):
        mean_val_total_loss = np.mean([v['val_loss'].item() for v in all_outputs])
        mean_val_loss0 = np.mean([v['val_loss0'].item() for v in all_outputs])
        mean_val_loss1 = np.mean([v['val_loss1'].item() for v in all_outputs])
        total_attn_dist_err = np.sum([v['val_attn_dist_err'].sum() for v in all_outputs])
        total_attn_theta_err = np.sum([v['val_attn_theta_err'].sum() for v in all_outputs])
        total_trans_dist_err = np.sum([v['val_trans_dist_err'].sum() for v in all_outputs])
        total_trans_theta_err = np.sum([v['val_trans_theta_err'].sum() for v in all_outputs])

        self.log('vl/attn/loss', mean_val_loss0)
        self.log('vl/trans/loss', mean_val_loss1)
        self.log('vl/loss', mean_val_total_loss)
        self.log('vl/total_attn_dist_err', total_attn_dist_err)
        self.log('vl/total_attn_theta_err', total_attn_theta_err)
        self.log('vl/total_trans_dist_err', total_trans_dist_err)
        self.log('vl/total_trans_theta_err', total_trans_theta_err)

        print("\nAttn Err - Dist: {:.2f}, Theta: {:.2f}".format(total_attn_dist_err, total_attn_theta_err))
        print("Transport Err - Dist: {:.2f}, Theta: {:.2f}".format(total_trans_dist_err, total_trans_theta_err))

        return dict(
            val_loss=mean_val_total_loss,
            val_loss0=mean_val_loss0,
            val_loss1=mean_val_loss1,
            total_attn_dist_err=total_attn_dist_err,
            total_attn_theta_err=total_attn_theta_err,
            total_trans_dist_err=total_trans_dist_err,
            total_trans_theta_err=total_trans_theta_err,
        )

    def act(self, obs, info=None, goal=None):  # pylint: disable=unused-argument
        """Run inference and return the best action given visual observations."""
        # Get heightmap from RGB-D images.
        img = self.test_ds.get_image(obs)

        # Attention model forward pass.
        pick_inp = {'inp_img': img}
        pick_conf = self.attn_forward(pick_inp)
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Transport model forward pass.
        place_inp = {'inp_img': img, 'p0': p0_pix}
        place_conf = self.trans_forward(place_inp)
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels to end-effector poses.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i,
                       second_order_closure, on_tpu, using_native_amp, using_lbfgs):
        # No-op: both optimizers are stepped manually in the criterion functions.
        pass

    def configure_optimizers(self):
        # No-op: both optimizers are built manually in __init__ (see self._optimizers).
        pass

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.test_loader

    def load(self, model_path):
        self.load_state_dict(torch.load(model_path)['state_dict'])
        self.to(device=self.device_type)
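
# Example of the action dict returned by TransporterAgent.act() (values are
# hypothetical, for illustration only):
#
#   {'pose0': (array([0.46, -0.08, 0.03]), array([0., 0., -0.125, 0.992])),
#    'pose1': (array([0.52,  0.11, 0.03]), array([0., 0.,  0.247, 0.969])),
#    'pick': (112, 53),
#    'place': (87, 140)}
#
# 'pose0'/'pose1' are (position_xyz, quaternion_xyzw) pick and place poses in
# world coordinates; 'pick'/'place' are the corresponding heightmap pixels.
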

class OriginalTransporterAgent(TransporterAgent):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        stream_fcn = 'plain_resnet'
        self.attention = Attention(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = Transport(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class ClipUNetTransporterAgent(TransporterAgent):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        stream_fcn = 'clip_unet'
        self.attention = Attention(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = Transport(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class TwoStreamClipUNetTransporterAgent(TransporterAgent):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'clip_unet'
        self.attention = TwoStreamAttention(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TwoStreamTransport(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )

class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        stream_one_fcn = 'plain_resnet_lat'
        stream_two_fcn = 'clip_unet_lat'
        self.attention = TwoStreamAttentionLat(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TwoStreamTransportLat(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        # TODO: lateral version
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'clip_woskip'
        self.attention = TwoStreamAttention(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TwoStreamTransport(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )


class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        # TODO: lateral version
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'rn50_bert_unet'
        self.attention = TwoStreamAttention(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TwoStreamTransport(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
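
# A minimal construction sketch (not part of the original file; the cfg keys
# listed are the ones this module reads, while the RavensDataset import and
# constructor arguments are assumptions that depend on the caller's setup):
#
#   from torch.utils.data import DataLoader
#   from cliport.dataset import RavensDataset
#
#   cfg = {'train': {'task': 'stack-block-pyramid', 'n_rotations': 36,
#                    'lr': 1e-4, 'val_repeats': 1, 'save_steps': [1000],
#                    'log': False, 'train_dir': './exps'}}
#   train_loader = DataLoader(RavensDataset('data/train', cfg), batch_size=1)
#   val_loader = DataLoader(RavensDataset('data/val', cfg), batch_size=1)
#   agent = OriginalTransporterAgent('transporter', cfg, train_loader, val_loader)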