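# Inference demo for LEO's SequentialGrounder on preprocessed ScanNet scans.
# A scan's point cloud, instance labels, and per-object features are loaded
# from ./assets, a LEO-style prompt is built from the task description, and
# the model grounds each action step to a scene object (and, in predict
# mode, also generates the plan text).
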
import json
import os

import numpy as np
import torch
from torch.utils.data import default_collate

from leo.model import SequentialGrounder
from leo.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence

ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint/leo')

int2cat = json.load(open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8"))
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))

role_prompt = "You are an AI visual assistant situated in a 3D scene. "\
              "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "\
              "You should properly respond to the USER's instruction according to the given visual information. "
# role_prompt = " "
egoview_prompt = "Ego-view image:"
objects_prompt = "Objects (including you) in the scene:"
task_prompt = "USER: {instruction} ASSISTANT:"


def get_prompt(instruction):
    return {
        'prompt_before_obj': role_prompt,
        'prompt_middle_1': egoview_prompt,
        'prompt_middle_2': objects_prompt,
        'prompt_after_obj': task_prompt.format(instruction=instruction),
    }
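

# Build the language-side inputs for one task item: the prompt dict plus,
# when action steps are provided, the ground-truth action sequence joined
# with the '<s>' separator under 'output_gt'.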
def get_lang(task_item):
    task_description = task_item['task_description']
    sentence = task_description
    data_dict = get_prompt(task_description)
    # scan_id = task_item['scan_id']
    if 'action_steps' in task_item:
        action_steps = task_item['action_steps']
        # tgt_object_id = [int(action['target_id']) for action in action_steps]
        # tgt_object_name = [action['label'] for action in action_steps]
        for action in action_steps:
            sentence += ' ' + action['action']
        data_dict['output_gt'] = ' '.join([action['action'] + ' <s>' for action in action_steps])
    # return scan_id, tgt_object_id, tgt_object_name, sentence, data_dict
    return data_dict
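

# Load one preprocessed scan from assets/inputs/<scan_id> (point cloud,
# instance labels, precomputed per-object features) and convert it into the
# per-object input dict expected by the model.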
def load_data(scan_id):
    one_scan = {}
    # load scan
    pcd_data = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(ASSET_DIR, f'inputs/{scan_id}', f'{scan_id}_inst.pth'))
    points, colors, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    colors = colors / 127.5 - 1
    pcds = np.concatenate([points, colors], 1)
    one_scan['pcds'] = pcds
    one_scan['instance_labels'] = instance_labels
    one_scan['inst_to_label'] = inst_to_label
    # convert to gt object
    obj_pcds = []
    inst_ids = []
    inst_labels = []
    bg_indices = np.full((points.shape[0], ), 1, dtype=np.bool_)
    for inst_id in inst_to_label.keys():
        if inst_to_label[inst_id] in cat2int.keys():
            mask = instance_labels == inst_id
            if np.sum(mask) == 0:
                continue
            obj_pcds.append(pcds[mask])
            inst_ids.append(inst_id)
            inst_labels.append(cat2int[inst_to_label[inst_id]])
            if inst_to_label[inst_id] not in ['wall', 'floor', 'ceiling']:
                bg_indices[mask] = False
    one_scan['obj_pcds'] = obj_pcds
    one_scan['inst_labels'] = inst_labels
    one_scan['inst_ids'] = inst_ids
    one_scan['bg_pcds'] = pcds[bg_indices]
    # calculate box for matching
    obj_center = []
    obj_box_size = []
    for obj_pcd in obj_pcds:
        _c, _b = convert_pc_to_box(obj_pcd)
        obj_center.append(_c)
        obj_box_size.append(_b)
    one_scan['obj_loc'] = obj_center
    one_scan['obj_box'] = obj_box_size
    # load point feat
    feat_pth = os.path.join(ASSET_DIR, f'inputs/{scan_id}', 'obj_feats.pth')
    one_scan['obj_feats'] = torch.load(feat_pth, map_location='cpu')
    # convert to pq3d input
    obj_labels = one_scan['inst_labels']  # N
    obj_pcds = one_scan['obj_pcds']
    obj_ids = one_scan['inst_ids']
    # object filter
    excluded_labels = ['wall', 'floor', 'ceiling']

    def keep_obj(i, obj_label):
        category = int2cat[obj_label]
        # filter out background
        if category in excluded_labels:
            return False
        # filter out objects not mentioned in the sentence
        return True

    selected_obj_idxs = [i for i, obj_label in enumerate(obj_labels) if keep_obj(i, obj_label)]
    # crop objects to max_obj_len and reorganize ids ? # TODO
    obj_labels = [obj_labels[i] for i in selected_obj_idxs]
    obj_pcds = [obj_pcds[i] for i in selected_obj_idxs]
    # subsample points
    obj_pcds = np.array([obj_pcd[np.random.choice(len(obj_pcd), size=1024,
                                                  replace=len(obj_pcd) < 1024)] for obj_pcd in obj_pcds])
    obj_fts, obj_locs, obj_boxes, rot_matrix = obj_processing_post(obj_pcds, rot_aug=False)
    data_dict = {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_boxes": obj_boxes,
        "obj_pad_masks": torch.ones((len(obj_locs)), dtype=torch.bool),  # used for padding in collate
        "obj_ids": torch.LongTensor([obj_ids[i] for i in selected_obj_idxs])
    }
    # convert point feature
    data_dict['obj_feats'] = one_scan['obj_feats'].squeeze(0)
    useful_keys = ['tgt_object_id', 'scan_id', 'obj_labels', 'data_idx',
                   'obj_fts', 'obj_locs', 'obj_pad_masks', 'obj_ids',
                   'source', 'prompt_before_obj', 'prompt_middle_1',
                   'prompt_middle_2', 'prompt_after_obj', 'output_gt', 'obj_feats']
    for k in list(data_dict.keys()):
        if k not in useful_keys:
            del data_dict[k]
    # add new keys because of leo
    data_dict['img_fts'] = torch.zeros(3, 224, 224)
    data_dict['img_masks'] = torch.LongTensor([0]).bool()
    data_dict['anchor_locs'] = torch.zeros(3)
    data_dict['anchor_orientation'] = torch.zeros(4)
    data_dict['anchor_orientation'][-1] = 1  # xyzw
    # convert to leo format
    data_dict['obj_masks'] = data_dict['obj_pad_masks']
    del data_dict['obj_pad_masks']
    return data_dict
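

# Wrap a single sample into a batch of size one: pad the variable-length
# per-object tensors with pad_sequence and default-collate the remaining keys.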
def form_batch(data_dict):
    batch = [data_dict]
    new_batch = {}
    # pad
    padding_keys = ['obj_fts', 'obj_locs', 'obj_masks', 'obj_labels', 'obj_ids']
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        padded_tensor = pad_sequence(tensors, pad=0)
        new_batch[k] = padded_tensor
    # # list
    # list_keys = ['tgt_object_id']
    # for k in list_keys:
    #     new_batch[k] = [sample.pop(k) for sample in batch]
    # default collate
    new_batch.update(default_collate(batch))
    return new_batch
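

# Run the full pipeline for one scan and task: load the scan, attach the
# prompt, batch it, and forward through SequentialGrounder. With
# predict_mode=False the grounded object id for each action step is returned;
# with predict_mode=True the generated plan text is returned as well.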
def inference(scan_id, task, predict_mode=False):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = 'cpu'  # fine for predict_mode=False, and for both modes in a local Gradio demo preview
    data_dict = load_data(scan_id)
    data_dict.update(get_lang(task))
    data_dict = form_batch(data_dict)
    for key, value in data_dict.items():
        if isinstance(value, torch.Tensor):
            data_dict[key] = value.to(device)
    model = SequentialGrounder(predict_mode)
    load_msg = model.load_state_dict(torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu'), strict=False)
    model.to(device)
    data_dict = model(data_dict)
    if not predict_mode:
        # calculate result id
        result_id_list = [data_dict['obj_ids'][0][torch.argmax(data_dict['ground_logits'][i]).item()]
                          for i in range(len(data_dict['ground_logits']))]
    else:
        # calculate language
        # tgt_object_id = data_dict['tgt_object_id']
        if data_dict['ground_logits'] is None:
            og_pred = []
        else:
            og_pred = torch.argmax(data_dict['ground_logits'], dim=1)
            grd_batch_ind_list = data_dict['grd_batch_ind_list']
        response_pred = []
        for i in range(1):  # len(tgt_object_id)
            # target_sequence = list(tgt_object_id[i].cpu().numpy())
            predict_sequence = []
            if og_pred is not None:
                for j in range(len(og_pred)):
                    if grd_batch_ind_list[j] == i:
                        predict_sequence.append(og_pred[j].item())
            obj_ids = data_dict['obj_ids']
            response_pred.append({
                'predict_object_id': [obj_ids[i][o].item() for o in predict_sequence],
                'pred_plan_text': data_dict['output_txt'][i]
            })
    return response_pred if predict_mode else result_id_list


if __name__ == '__main__':
    inference("scene0050_00", {
        "task_description": "Find the chair and move it to the table.",
        "action_steps": [
            {
                "target_id": "1",
                "label": "chair",
                "action": "Find the chair."
            },
            {
                "target_id": "2",
                "label": "table",
                "action": "Move the chair to the table."
            }
        ],
        "scan_id": "scene0050_00"
    }, predict_mode=True)
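
    # A minimal sketch of inspecting the output (illustrative only; 'task_item'
    # stands for the task dict passed above, and the checkpoint and scan assets
    # referenced by CKPT_DIR and ASSET_DIR must be present locally):
    #
    #   preds = inference("scene0050_00", task_item, predict_mode=True)
    #   for p in preds:
    #       print(p['predict_object_id'], p['pred_plan_text'])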