Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from pytorch_lightning import LightningModule | |
from cliport.tasks import cameras | |
from cliport.utils import utils | |
from cliport.models.core.attention import Attention | |
from cliport.models.core.transport import Transport | |
from cliport.models.streams.two_stream_attention import TwoStreamAttention | |
from cliport.models.streams.two_stream_transport import TwoStreamTransport | |
from cliport.models.streams.two_stream_attention import TwoStreamAttentionLat | |
from cliport.models.streams.two_stream_transport import TwoStreamTransportLat | |
import time | |
import IPython | |
class TransporterAgent(LightningModule):
    """Base pick-and-place agent made of an attention and a transport module.

    Subclasses must override ``_build_model`` to populate ``self.attention``
    and ``self.transport``.  Optimization is manual
    (``automatic_optimization = False``): each criterion steps its own Adam
    optimizer from ``self._optimizers``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        """Store configuration, build the sub-models and create optimizers.

        Args:
            name: Agent name (stored and printed).
            cfg: Nested config; reads ``cfg['train']`` keys ``task``,
                ``n_rotations``, ``val_repeats``, ``save_steps``, ``lr``
                and ``log``.
            train_ds: Object returned by ``train_dataloader``; must expose
                a ``.dataset`` attribute.
            test_ds: Object returned by ``val_dataloader``; must expose a
                ``.dataset`` attribute.
        """
        super().__init__()
        utils.set_seed(0)

        self.automatic_optimization = False
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # this is bad for PL :(
        self.name = name
        self.cfg = cfg
        self.train_loader = train_ds
        self.test_loader = test_ds

        self.train_ds = train_ds.dataset
        self.test_ds = test_ds.dataset

        self.task = cfg['train']['task']
        self.total_steps = 0
        self.crop_size = 64
        self.n_rotations = cfg['train']['n_rotations']

        self.pix_size = 0.003125
        self.in_shape = (320, 160, 6)
        self.cam_config = cameras.RealSenseD415.CONFIG
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])

        self.val_repeats = cfg['train']['val_repeats']
        self.save_steps = cfg['train']['save_steps']

        self._build_model()

        # One optimizer per sub-model; they are stepped manually inside
        # attn_criterion / transport_criterion.
        learning_rate = self.cfg['train']['lr']
        self._optimizers = {
            'attn': torch.optim.Adam(self.attention.parameters(), lr=learning_rate),
            'trans': torch.optim.Adam(self.transport.parameters(), lr=learning_rate),
        }
        print("Agent: {}, Logging: {}".format(name, cfg['train']['log']))

    def _build_model(self):
        """Create ``self.attention`` and ``self.transport`` (abstract here)."""
        self.attention = None
        self.transport = None
        raise NotImplementedError()

    def forward(self, x):
        """Not implemented; use ``attn_forward`` / ``trans_forward`` / ``act``."""
        raise NotImplementedError()

    def cross_entropy_with_logits(self, pred, labels, reduction='mean'):
        """Cross entropy between flattened logits and a (soft) label map.

        Both tensors are flattened to ``(len(labels), -1)`` and a
        log-softmax is taken over the flattened axis of ``pred``.

        Args:
            pred: Logits with the same number of elements as ``labels``.
            labels: Target distribution; its first dimension is the batch.
            reduction: ``'sum'`` or ``'mean'``.  ``'mean'`` averages over
                every element of the flattened product, not per sample.

        Raises:
            NotImplementedError: For any other ``reduction`` value.
        """
        # Lucas found that both sum and mean work equally well
        batch = len(labels)
        x = -labels.view(batch, -1) * F.log_softmax(pred.view(batch, -1), -1)
        if reduction == 'sum':
            return x.sum()
        if reduction == 'mean':
            return x.mean()
        raise NotImplementedError()

    def attn_forward(self, inp, softmax=True):
        """Run the attention module on ``inp['inp_img']``."""
        inp_img = inp['inp_img']
        return self.attention.forward(inp_img, softmax=softmax)

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        """Forward the attention module on a frame and evaluate its loss.

        Reads ``frame['img']``, ``frame['p0']`` and ``frame['p0_theta']``.

        Returns:
            ``(loss, err)`` as produced by ``attn_criterion``.
        """
        inp_img = frame['img']
        p0, p0_theta = frame['p0'], frame['p0_theta']

        inp = {'inp_img': inp_img}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def attn_criterion(self, backprop, compute_err, inp, out, p, theta):
        """Compute the attention loss, optionally stepping its optimizer.

        Args:
            backprop: If true, back-propagate and step ``_optimizers['attn']``.
            compute_err: If true, also compute pixel/rotation error for the
                first sample of the batch.
            inp: Dict holding ``'inp_img'``.
            out: Attention logits for ``inp``.
            p: Per-sample pick pixel coordinates (indexable as ``p_i[0]``,
                ``p_i[1]``).
            theta: Per-sample pick rotation; tensors are moved to NumPy.

        Returns:
            ``(loss, err)`` where ``err`` is ``{}`` unless ``compute_err``.
        """
        # Get label.
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()

        n_rotations = self.attention.n_rotations
        theta_i = theta / (2 * np.pi / n_rotations)
        theta_i = np.int32(np.round(theta_i)) % n_rotations

        inp_img = inp['inp_img'].float()
        label_size = inp_img.shape[:3] + (n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=out.device)

        # remove this for-loop later
        for idx, p_i in enumerate(p):
            label[idx, int(p_i[0]), int(p_i[1]), theta_i[idx]] = 1
        # Rotation axis is moved next to the batch axis before the loss;
        # presumably this matches the layout of `out` — TODO confirm.
        label = label.permute((0, 3, 1, 2)).contiguous()

        # Get loss.
        loss = self.cross_entropy_with_logits(out, label)

        # Backpropagate.
        if backprop:
            attn_optim = self._optimizers['attn']
            self.manual_backward(loss)
            attn_optim.step()
            attn_optim.zero_grad()

        # Pixel and Rotation error (not used anywhere).
        err = {}
        if compute_err:
            with torch.no_grad():
                pick_conf = self.attn_forward(inp)
                pick_conf = pick_conf[0].permute(1, 2, 0)
                pick_conf = pick_conf.detach().cpu().numpy()
                p = p[0]
                theta = theta[0]

                # single batch
                argmax = np.argmax(pick_conf)
                argmax = np.unravel_index(argmax, shape=pick_conf.shape)

                p0_pix = argmax[:2]
                p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

                err = {
                    'dist': np.linalg.norm(np.array(p.detach().cpu().numpy()) - p0_pix, ord=1),
                    'theta': np.absolute((theta - p0_theta) % np.pi)
                }
        return loss, err

    def trans_forward(self, inp, softmax=True):
        """Run the transport module on ``inp['inp_img']`` around ``inp['p0']``."""
        inp_img = inp['inp_img']
        p0 = inp['p0']
        return self.transport.forward(inp_img, p0, softmax=softmax)

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        """Forward the transport module on a frame and evaluate its loss.

        Reads ``frame['img']``, ``frame['p0']``, ``frame['p1']`` and
        ``frame['p1_theta']``.

        Returns:
            ``(loss, err)`` — note that ``transport_criterion`` itself
            returns the pair in the opposite order.
        """
        inp_img = frame['img'].float()
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']

        inp = {'inp_img': inp_img, 'p0': p0}
        output = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, output, p0, p1, p1_theta)
        return loss, err

    def transport_criterion(self, backprop, compute_err, inp, output, p, q, theta):
        """Compute the transport loss, optionally stepping its optimizer.

        Args:
            backprop: If true, back-propagate and step ``_optimizers['trans']``.
            compute_err: If true, also compute pixel/rotation error for the
                first sample of the batch.
            inp: Dict holding ``'inp_img'`` and ``'p0'``.
            output: Transport logits for ``inp``.
            p: Pick pixels (unused here; kept for interface compatibility).
            q: Per-sample place pixel tensor; values are clamped to the label
                bounds on a private copy, so the caller's tensor is untouched.
            theta: Per-sample place rotation; tensors are moved to NumPy.

        Returns:
            ``(err, loss)`` — error dict first, loss second.
        """
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()

        n_rotations = self.transport.n_rotations
        itheta = theta / (2 * np.pi / n_rotations)
        itheta = np.int32(np.round(itheta)) % n_rotations

        # Get one-hot pixel label map.
        inp_img = inp['inp_img']
        label_size = inp_img.shape[:3] + (n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=output.device)

        # Clamp on a copy: the original in-place write modified the batch
        # tensor owned by the caller (frame['p1']).
        q = q.clone()
        q[:, 0] = torch.clamp(q[:, 0], 0, label.shape[1] - 1)
        q[:, 1] = torch.clamp(q[:, 1], 0, label.shape[2] - 1)

        # remove this for-loop later
        for idx, q_i in enumerate(q):
            label[idx, int(q_i[0]), int(q_i[1]), itheta[idx]] = 1
        label = label.permute((0, 3, 1, 2)).contiguous()

        # Get loss.
        loss = self.cross_entropy_with_logits(output, label)

        if backprop:
            transport_optim = self._optimizers['trans']
            transport_optim.zero_grad()
            self.manual_backward(loss)
            transport_optim.step()

        # Pixel and Rotation error (not used anywhere).
        err = {}
        if compute_err:
            with torch.no_grad():
                place_conf = self.trans_forward(inp)

                # pick the first batch
                place_conf = place_conf[0]
                q = q[0]
                theta = theta[0]
                place_conf = place_conf.permute(1, 2, 0)
                place_conf = place_conf.detach().cpu().numpy()
                argmax = np.argmax(place_conf)
                argmax = np.unravel_index(argmax, shape=place_conf.shape)
                p1_pix = argmax[:2]
                p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

                err = {
                    'dist': np.linalg.norm(np.array(q.detach().cpu().numpy()) - p1_pix, ord=1),
                    'theta': np.absolute((theta - p1_theta) % np.pi)
                }
        self.transport.iters += 1
        return err, loss

    def training_step(self, batch, batch_idx):
        """Train both sub-models on one batch and log their losses.

        ``batch`` is unpacked as ``(frame, _)``.  The optimizers are stepped
        inside the criteria; the returned dict carries the summed loss.
        """
        self.attention.train()
        self.transport.train()

        frame, _ = batch
        self.start_time = time.time()

        # Get training losses.
        step = self.total_steps + 1
        loss0, err0 = self.attn_training_step(frame)

        self.start_time = time.time()

        # When the "transport" model is itself an Attention module it is
        # trained through the attention path.
        if isinstance(self.transport, Attention):
            loss1, err1 = self.attn_training_step(frame)
        else:
            loss1, err1 = self.transport_training_step(frame)

        total_loss = loss0 + loss1
        self.total_steps = step
        self.start_time = time.time()
        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def check_save_iteration(self):
        """Save ``last.ckpt`` whenever ``total_steps + 1`` is a multiple of 100."""
        global_step = self.total_steps

        if (global_step + 1) % 100 == 0:
            # save latest checkpoint
            print(f"Saving last.ckpt Epoch: {self.trainer.current_epoch} | Global Step: {self.trainer.global_step}")
            self.save_last_checkpoint()

    def save_last_checkpoint(self):
        """Write ``<train_dir>/checkpoints/last.ckpt`` through the trainer."""
        checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
        ckpt_path = os.path.join(checkpoint_path, 'last.ckpt')
        self.trainer.save_checkpoint(ckpt_path)

    def validation_step(self, batch, batch_idx):
        """Evaluate both sub-models on one batch without back-propagation.

        The losses are averaged over ``val_repeats`` passes on the same
        batch; the reported errors come from the last pass only.
        """
        self.attention.eval()
        self.transport.eval()

        loss0, loss1 = 0, 0
        assert self.val_repeats >= 1
        for _ in range(self.val_repeats):
            frame, _ = batch
            l0, err0 = self.attn_training_step(frame, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.transport, Attention):
                l1, err1 = self.attn_training_step(frame, backprop=False, compute_err=True)
            else:
                l1, err1 = self.transport_training_step(frame, backprop=False, compute_err=True)
            loss1 += l1
        loss0 /= self.val_repeats
        loss1 /= self.val_repeats
        val_total_loss = loss0 + loss1

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def training_epoch_end(self, all_outputs):
        """Re-seed with ``current_epoch + 1`` after each training epoch."""
        super().training_epoch_end(all_outputs)
        utils.set_seed(self.trainer.current_epoch + 1)

    def validation_epoch_end(self, all_outputs):
        """Aggregate per-batch validation outputs, log and print them.

        Losses are averaged across batches; errors are summed.
        """
        mean_val_total_loss = np.mean([v['val_loss'].item() for v in all_outputs])
        mean_val_loss0 = np.mean([v['val_loss0'].item() for v in all_outputs])
        mean_val_loss1 = np.mean([v['val_loss1'].item() for v in all_outputs])
        total_attn_dist_err = np.sum([v['val_attn_dist_err'].sum() for v in all_outputs])
        total_attn_theta_err = np.sum([v['val_attn_theta_err'].sum() for v in all_outputs])
        total_trans_dist_err = np.sum([v['val_trans_dist_err'].sum() for v in all_outputs])
        total_trans_theta_err = np.sum([v['val_trans_theta_err'].sum() for v in all_outputs])

        self.log('vl/attn/loss', mean_val_loss0)
        self.log('vl/trans/loss', mean_val_loss1)
        self.log('vl/loss', mean_val_total_loss)
        self.log('vl/total_attn_dist_err', total_attn_dist_err)
        self.log('vl/total_attn_theta_err', total_attn_theta_err)
        self.log('vl/total_trans_dist_err', total_trans_dist_err)
        self.log('vl/total_trans_theta_err', total_trans_theta_err)

        print("\nAttn Err - Dist: {:.2f}, Theta: {:.2f}".format(total_attn_dist_err, total_attn_theta_err))
        print("Transport Err - Dist: {:.2f}, Theta: {:.2f}".format(total_trans_dist_err, total_trans_theta_err))

        # NOTE(review): the transport loss key is 'mean_val_loss1' while the
        # attention one is 'val_loss0'; kept as-is for existing consumers.
        return dict(
            val_loss=mean_val_total_loss,
            val_loss0=mean_val_loss0,
            mean_val_loss1=mean_val_loss1,
            total_attn_dist_err=total_attn_dist_err,
            total_attn_theta_err=total_attn_theta_err,
            total_trans_dist_err=total_trans_dist_err,
            total_trans_theta_err=total_trans_theta_err,
        )

    def act(self, obs, info=None, goal=None):  # pylint: disable=unused-argument
        """Run inference and return best action given visual observations."""
        # Get heightmap from RGB-D images.
        img = self.test_ds.get_image(obs)

        # Attention model forward pass.
        pick_inp = {'inp_img': img}
        pick_conf = self.attn_forward(pick_inp)
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Transport model forward pass.
        place_inp = {'inp_img': img, 'p0': p0_pix}
        place_conf = self.trans_forward(place_inp)
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels to end effector poses.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
        """No-op: optimizers are stepped manually inside the criteria."""
        pass

    def configure_optimizers(self):
        """Return ``None``; the agent owns its optimizers in ``_optimizers``.

        An earlier definition in this class returned ``self._optimizers`` but
        was shadowed by this one, so ``None`` has always been the effective
        result.
        """
        return None

    def train_dataloader(self):
        """Return the training loader passed to the constructor."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader passed to the constructor."""
        return self.test_loader

    def load(self, model_path):
        """Load ``state_dict`` from a checkpoint file and move to ``device_type``.

        ``map_location`` remaps the stored tensors to ``self.device_type`` so
        a checkpoint written on a GPU machine can be restored on a CPU-only
        host.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device_type)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(device=self.device_type)
class OriginalTransporterAgent(TransporterAgent):
    """Agent whose attention and transport modules both use 'plain_resnet'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Instantiate single-stream Attention and Transport modules."""
        # Keyword arguments common to both modules.
        shared = dict(
            stream_fcn=('plain_resnet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # The attention module is built with a single rotation.
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class ClipUNetTransporterAgent(TransporterAgent):
    """Agent whose attention and transport modules both use 'clip_unet'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Instantiate single-stream Attention and Transport modules."""
        # Keyword arguments common to both modules.
        shared = dict(
            stream_fcn=('clip_unet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # The attention module is built with a single rotation.
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamClipUNetTransporterAgent(TransporterAgent):
    """Two-stream agent pairing 'plain_resnet' with 'clip_unet'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Instantiate TwoStreamAttention and TwoStreamTransport modules."""
        # Keyword arguments common to both modules.
        shared = dict(
            stream_fcn=('plain_resnet', 'clip_unet'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # The attention module is built with a single rotation.
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):
    """Two-stream agent pairing 'plain_resnet_lat' with 'clip_unet_lat'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Instantiate TwoStreamAttentionLat and TwoStreamTransportLat modules."""
        # Keyword arguments common to both modules.
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_unet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # The attention module is built with a single rotation.
        self.attention = TwoStreamAttentionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):
    """Two-stream agent pairing 'plain_resnet' with 'clip_woskip'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Instantiate TwoStreamAttention and TwoStreamTransport modules."""
        # TODO: lateral version
        # Keyword arguments common to both modules.
        shared = dict(
            stream_fcn=('plain_resnet', 'clip_woskip'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # The attention module is built with a single rotation.
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):
    """Two-stream agent pairing 'plain_resnet' with 'rn50_bert_unet'."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Instantiate TwoStreamAttention and TwoStreamTransport modules."""
        # TODO: lateral version
        # Keyword arguments common to both modules.
        shared = dict(
            stream_fcn=('plain_resnet', 'rn50_bert_unet'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # The attention module is built with a single rotation.
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )