# Spaces:
# Runtime error
# Runtime error
import cliport.models as models | |
import cliport.models.core.fusion as fusion | |
from cliport.models.core.transport import Transport | |
class TwoStreamTransport(Transport):
    """Two Stream Transport (a.k.a Place) module.

    Builds two stream networks for the key branch (applied to ``in_tensor``)
    and two for the query branch (applied to ``crop``). Each pair of stream
    outputs is combined by a fusion module selected via
    ``cfg['train']['trans_stream_fusion_type']``.
    """

    def __init__(self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device):
        """Record the fusion type, then delegate construction to ``Transport``.

        Args:
            stream_fcn: pair ``(stream_one_fcn, stream_two_fcn)`` of keys into
                ``models.names``.
            in_shape, n_rotations, crop_size, preprocess, cfg, device:
                forwarded unchanged to ``Transport.__init__``.
        """
        # Must be set before super().__init__: _build_nets reads
        # self.fusion_type, and presumably the base constructor is what calls
        # _build_nets (TODO confirm against Transport).
        self.fusion_type = cfg['train']['trans_stream_fusion_type']
        super().__init__(stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device)

    def _build_nets(self):
        """Instantiate the key/query stream networks and their fusion modules."""
        stream_one_fcn, stream_two_fcn = self.stream_fcn
        stream_one_model = models.names[stream_one_fcn]
        stream_two_model = models.names[stream_two_fcn]

        # Key streams consume in_tensor (see transport()).
        self.key_stream_one = stream_one_model(self.in_shape, self.output_dim, self.cfg, self.device, self.preprocess)
        self.key_stream_two = stream_two_model(self.in_shape, self.output_dim, self.cfg, self.device, self.preprocess)

        # Query streams both consume the crop, so both are built with
        # kernel_shape. Stream two was previously given in_shape, which was
        # inconsistent with stream one.
        self.query_stream_one = stream_one_model(self.kernel_shape, self.kernel_dim, self.cfg, self.device, self.preprocess)
        self.query_stream_two = stream_two_model(self.kernel_shape, self.kernel_dim, self.cfg, self.device, self.preprocess)

        # NOTE(review): fusion_key is sized with kernel_dim although the key
        # streams are built with output_dim; presumably the two are equal in
        # Transport -- confirm before changing either.
        self.fusion_key = fusion.names[self.fusion_type](input_dim=self.kernel_dim)
        self.fusion_query = fusion.names[self.fusion_type](input_dim=self.kernel_dim)

        print(f"Transport FCN - Stream One: {stream_one_fcn}, Stream Two: {stream_two_fcn}, Stream Fusion: {self.fusion_type}")

    def transport(self, in_tensor, crop):
        """Run both streams on each input and fuse them.

        Returns:
            ``(logits, kernel)``: the fused key-stream outputs for
            ``in_tensor`` and the fused query-stream outputs for ``crop``.
        """
        logits = self.fusion_key(self.key_stream_one(in_tensor), self.key_stream_two(in_tensor))
        kernel = self.fusion_query(self.query_stream_one(crop), self.query_stream_two(crop))
        return logits, kernel
class TwoStreamTransportLat(TwoStreamTransport):
    """Two-stream Transport (a.k.a Place) variant with lateral connections.

    Stream one returns a pair ``(output, lateral)``. The lateral part is
    passed as an extra argument to stream two, and the two outputs are then
    fused. This applies to both the key and the query branch.
    """

    def __init__(self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device):
        # No additional state: construction is entirely handled by the parent.
        super().__init__(stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device)

    @staticmethod
    def _lateral_pass(first, second, fuse, x):
        """Run ``first`` on ``x``, feed its lateral output to ``second``, fuse both outputs."""
        primary, lateral = first(x)
        secondary = second(x, lateral)
        return fuse(primary, secondary)

    def transport(self, in_tensor, crop):
        """Return ``(logits, kernel)`` computed through laterally connected streams."""
        logits = self._lateral_pass(self.key_stream_one, self.key_stream_two, self.fusion_key, in_tensor)
        kernel = self._lateral_pass(self.query_stream_one, self.query_stream_two, self.fusion_query, crop)
        return logits, kernel