# SG3D-Demo / leo / inference.py
# Author: zfzhang-thu
# Last commit: "initialize cuda" (782b3bb)
import json
import os
import torch
import numpy as np
from leo.model import SequentialGrounder
from leo.utils import LabelConverter, convert_pc_to_box, obj_processing_post, pad_sequence
from torch.utils.data import default_collate
# Asset and checkpoint locations are resolved relative to the current
# working directory, so the script must be launched from the project root.
ASSET_DIR = os.path.join(os.getcwd(), 'assets')
CKPT_DIR = os.path.join(os.getcwd(), 'checkpoint/leo')

# int2cat: sequence of category names, indexed by integer category id.
# The file is opened in a ``with`` block so the handle is closed right after
# parsing instead of being left to the garbage collector.
with open(os.path.join(ASSET_DIR, "meta/scannetv2_raw_categories.json"), 'r', encoding="utf-8") as _categories_file:
    int2cat = json.load(_categories_file)
# cat2int: inverse lookup, category name -> integer category id.
cat2int = {w: i for i, w in enumerate(int2cat)}
label_converter = LabelConverter(os.path.join(ASSET_DIR, "meta/scannetv2-labels.combined.tsv"))

# Prompt segments that are placed around the visual tokens (see get_prompt).
role_prompt = (
    "You are an AI visual assistant situated in a 3D scene. "
    "You can perceive (1) an ego-view image (accessible when necessary) and (2) the objects (including yourself) in the scene (always accessible). "
    "You should properly respond to the USER's instruction according to the given visual information. "
)
#role_prompt = " "
egoview_prompt = "Ego-view image:"
objects_prompt = "Objects (including you) in the scene:"
# ``{instruction}`` is filled in by get_prompt via str.format.
task_prompt = "USER: {instruction} ASSISTANT:"
def get_prompt(instruction):
    """Build the four prompt segments for a single user instruction.

    Args:
        instruction: text substituted for ``{instruction}`` in ``task_prompt``.

    Returns:
        dict with the keys ``prompt_before_obj``, ``prompt_middle_1``,
        ``prompt_middle_2`` and ``prompt_after_obj``. The first three are the
        module-level prompt constants; the last is the formatted task prompt.
    """
    filled_task_prompt = task_prompt.format(instruction=instruction)
    return dict(
        prompt_before_obj=role_prompt,
        prompt_middle_1=egoview_prompt,
        prompt_middle_2=objects_prompt,
        prompt_after_obj=filled_task_prompt,
    )
def get_lang(task_item):
    """Turn a task description into the language part of the model input.

    Args:
        task_item: mapping with a ``'task_description'`` string and,
            optionally, ``'action_steps'`` - an iterable of mappings that each
            carry an ``'action'`` string.

    Returns:
        The prompt dict from ``get_prompt(task_description)``. When
        ``'action_steps'`` is present, an extra ``'output_gt'`` entry holds
        every step's action followed by ``' <s>'``, joined by single spaces.

    Raises:
        KeyError: if ``'task_description'`` is missing, or a step has no
            ``'action'`` entry.
    """
    data_dict = get_prompt(task_item['task_description'])
    if 'action_steps' in task_item:
        # The previous version also accumulated a ``sentence`` string here
        # that was never read or returned; that dead code has been dropped.
        data_dict['output_gt'] = ' '.join(
            action['action'] + ' <s>' for action in task_item['action_steps']
        )
    return data_dict
def load_data(scan_id):
    """Load one scan from disk and pack its objects into a model input dict.

    Reads ``{scan_id}_pcd.pth``, ``{scan_id}_inst.pth`` and ``obj_feats.pth``
    from ``ASSET_DIR/inputs/{scan_id}``, splits the point cloud into
    per-instance point sets, drops wall/floor/ceiling instances, randomly
    subsamples every remaining object to 1024 points and runs
    ``obj_processing_post`` on the result.

    The subsampling uses the global NumPy RNG without seeding, so repeated
    calls may pick different points.

    Args:
        scan_id: name of the scan directory under ``ASSET_DIR/inputs``.

    Returns:
        dict with the keys ``scan_id``, ``obj_fts``, ``obj_locs``,
        ``obj_labels``, ``obj_ids``, ``obj_feats``, ``img_fts``,
        ``img_masks``, ``anchor_locs``, ``anchor_orientation`` and
        ``obj_masks`` (in that order).
    """
    background_names = ('wall', 'floor', 'ceiling')
    scan_dir = os.path.join(ASSET_DIR, f'inputs/{scan_id}')

    # --- raw scan ---------------------------------------------------------
    pcd_data = torch.load(os.path.join(scan_dir, f'{scan_id}_pcd.pth'))
    inst_to_label = torch.load(os.path.join(scan_dir, f'{scan_id}_inst.pth'))
    xyz, rgb, instance_labels = pcd_data[0], pcd_data[1], pcd_data[-1]
    # presumably maps 0..255 colour values into [-1, 1] - TODO confirm input range
    rgb = rgb / 127.5 - 1
    # assumes xyz and rgb are 2-D arrays with one row per point - TODO confirm
    pcds = np.concatenate([xyz, rgb], 1)

    # --- split into ground-truth object instances ---------------------------
    all_obj_pcds, inst_ids, inst_labels = [], [], []
    bg_indices = np.ones((xyz.shape[0],), dtype=np.bool_)
    for inst_id, label_name in inst_to_label.items():
        if label_name not in cat2int:
            continue
        mask = instance_labels == inst_id
        if np.sum(mask) == 0:
            # instance is listed but owns no points
            continue
        all_obj_pcds.append(pcds[mask])
        inst_ids.append(inst_id)
        inst_labels.append(cat2int[label_name])
        if label_name not in background_names:
            # points of a foreground instance are not background
            bg_indices[mask] = False
    bg_pcds = pcds[bg_indices]

    # --- boxes for matching ---------------------------------------------------
    centers_and_sizes = [convert_pc_to_box(obj_pcd) for obj_pcd in all_obj_pcds]

    # --- precomputed point features -------------------------------------------
    obj_feats = torch.load(os.path.join(scan_dir, 'obj_feats.pth'), map_location='cpu')

    # Per-scan record; only a few entries are consumed further down.
    one_scan = {
        'pcds': pcds,
        'instance_labels': instance_labels,
        'inst_to_label': inst_to_label,
        'obj_pcds': all_obj_pcds,
        'inst_labels': inst_labels,
        'inst_ids': inst_ids,
        'bg_pcds': bg_pcds,
        'obj_loc': [center for center, _ in centers_and_sizes],
        'obj_box': [size for _, size in centers_and_sizes],
        'obj_feats': obj_feats,
    }

    # --- object filter: drop background categories ------------------------
    selected_obj_idxs = [
        idx for idx, label_id in enumerate(inst_labels)
        if int2cat[label_id] not in background_names
    ]
    # crop objects to max_obj_len and reorganize ids ? # TODO
    obj_labels = [inst_labels[idx] for idx in selected_obj_idxs]
    kept_pcds = [all_obj_pcds[idx] for idx in selected_obj_idxs]
    obj_ids = [inst_ids[idx] for idx in selected_obj_idxs]

    # --- subsample every object to exactly 1024 points ----------------------
    sampled = []
    for obj_pcd in kept_pcds:
        n_points = len(obj_pcd)
        # sample with replacement only when the object has too few points
        picked = np.random.choice(n_points, size=1024, replace=n_points < 1024)
        sampled.append(obj_pcd[picked])
    sampled_pcds = np.array(sampled)

    # The box and rotation outputs are not forwarded to the model input.
    obj_fts, obj_locs, _obj_boxes, _rot_matrix = obj_processing_post(sampled_pcds, rot_aug=False)

    return {
        "scan_id": scan_id,
        "obj_fts": obj_fts.float(),
        "obj_locs": obj_locs.float(),
        "obj_labels": torch.LongTensor(obj_labels),
        "obj_ids": torch.LongTensor(obj_ids),
        # convert point feature
        "obj_feats": one_scan['obj_feats'].squeeze(0),
        # placeholder image / anchor inputs expected by leo
        "img_fts": torch.zeros(3, 224, 224),
        "img_masks": torch.zeros(1, dtype=torch.bool),
        "anchor_locs": torch.zeros(3),
        "anchor_orientation": torch.tensor([0.0, 0.0, 0.0, 1.0]),  # xyzw
        # all-True mask, one entry per object; used for padding in collate
        "obj_masks": torch.ones(len(obj_locs), dtype=torch.bool),
    }
def form_batch(data_dict):
    """Wrap a single sample into a batch of size one.

    The variable-length object entries are padded with ``pad_sequence``
    (``pad=0``); every remaining entry goes through ``default_collate``.

    Args:
        data_dict: one sample containing at least the keys ``obj_fts``,
            ``obj_locs``, ``obj_masks``, ``obj_labels`` and ``obj_ids``.
            The mapping is not modified.

    Returns:
        A new dict with the padded object entries followed by the collated
        remaining entries.

    Raises:
        KeyError: if one of the padded keys is missing from ``data_dict``.
    """
    # Work on a shallow copy: the keys popped below must not disappear from
    # the caller's dict.
    batch = [dict(data_dict)]
    new_batch = {}

    # pad
    padding_keys = ['obj_fts', 'obj_locs', 'obj_masks', 'obj_labels', 'obj_ids']
    for k in padding_keys:
        tensors = [sample.pop(k) for sample in batch]
        new_batch[k] = pad_sequence(tensors, pad=0)

    # default collate for everything that was not padded above
    new_batch.update(default_collate(batch))
    return new_batch
def inference(scan_id, task, predict_mode=False):
    """Run the sequential grounder on one scan and one task.

    Loads the scan with ``load_data``, adds the language inputs from
    ``get_lang``, batches the sample, builds a ``SequentialGrounder`` from the
    checkpoint in ``CKPT_DIR`` and runs a single forward pass.

    Args:
        scan_id: scan name, forwarded to ``load_data``.
        task: task mapping, forwarded to ``get_lang``.
        predict_mode: forwarded to ``SequentialGrounder``; also selects the
            return format (see below).

    Returns:
        * ``predict_mode`` falsy: a list with one entry per row of
          ``ground_logits``; each entry is the element of the first sample's
          ``obj_ids`` at the arg-max index of that row.
        * ``predict_mode`` truthy: a one-element list holding a dict with
          ``'predict_object_id'`` (list of object ids) and
          ``'pred_plan_text'`` (the model's ``output_txt`` for the sample).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = 'cpu' # ok for predict_mode=False, and both for Gradio demo local preview

    data_dict = load_data(scan_id)
    data_dict.update(get_lang(task))
    data_dict = form_batch(data_dict)
    # Move tensors to the target device; non-tensor entries stay as they are.
    data_dict = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
    }

    model = SequentialGrounder(predict_mode)
    state_dict = torch.load(os.path.join(CKPT_DIR, 'pytorch_model.bin'), map_location='cpu')
    # strict=False: tolerate key mismatches between checkpoint and model.
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    # NOTE(review): the model is never switched to eval mode here - confirm
    # whether SequentialGrounder handles that itself.
    # No gradients are needed for inference, so skip autograd bookkeeping.
    with torch.no_grad():
        data_dict = model(data_dict)

    if not predict_mode:
        # calculate result id
        first_sample_obj_ids = data_dict['obj_ids'][0]
        return [
            first_sample_obj_ids[torch.argmax(step_logits).item()]
            for step_logits in data_dict['ground_logits']
        ]

    # calculate language
    ground_logits = data_dict['ground_logits']
    # ``is None`` instead of ``== None``: ground_logits may be a tensor.
    og_pred = [] if ground_logits is None else torch.argmax(ground_logits, dim=1)
    grd_batch_ind_list = data_dict['grd_batch_ind_list']
    obj_ids = data_dict['obj_ids']

    response_pred = []
    for i in range(1):  # the batch built by form_batch holds a single sample
        # Keep only the grounding predictions that belong to sample ``i``.
        predict_sequence = [
            pred.item()
            for j, pred in enumerate(og_pred)
            if grd_batch_ind_list[j] == i
        ]
        response_pred.append({
            'predict_object_id': [obj_ids[i][o].item() for o in predict_sequence],
            'pred_plan_text': data_dict['output_txt'][i],
        })
    return response_pred
if __name__ == '__main__':
    # Demo run: one scan, one two-step task, in predict mode.
    demo_scan_id = "scene0050_00"
    demo_action_steps = [
        {"target_id": "1", "label": "chair", "action": "Find the chair."},
        {"target_id": "2", "label": "table", "action": "Move the chair to the table."},
    ]
    demo_task = {
        "task_description": "Find the chair and move it to the table.",
        "action_steps": demo_action_steps,
        "scan_id": demo_scan_id,
    }
    inference(demo_scan_id, demo_task, predict_mode=True)