RoboVLMs model card

Introduction

This repo contains pre-trained models built with RoboVLMs, a unified framework for easily building VLAs (vision-language-action models) from VLMs (vision-language models).

We open-source three pre-trained model checkpoints and their configs:

  • kosmos_ph_calvin_abcd: RoboKosMos (KosMos + Policy Head) trained on the CALVIN dataset (split ABCD).
  • kosmos_ph_calvin_abc: RoboKosMos (KosMos + Policy Head) trained on the CALVIN dataset (split ABC).
  • kosmos_ph_oxe-pretrain: RoboKosMos (KosMos + Policy Head) trained on the OXE-magic-soup dataset.
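
The usage example below loads files from local configs/ and checkpoints/ directories. Only the kosmos_ph_calvin_abcd paths appear in the example; the remaining file names in this sketch are extrapolated from the checkpoint names above, so adjust them to wherever you place the downloaded files:

configs/
  kosmos_ph_calvin_abcd.json
  kosmos_ph_calvin_abc.json
  kosmos_ph_oxe-pretrain.json
checkpoints/
  kosmos_ph_calvin_abcd.pt
  kosmos_ph_calvin_abc.pt
  kosmos_ph_oxe-pretrain.pt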

Usage

The models predict robot actions from vision and language inputs. RoboVLMs supports several VLA structures, multi-view image inputs, and various VLM backbones. Taking kosmos_ph_calvin_abcd as an example:

import torch
import json, functools
from PIL import Image
from robovlms.train.base_trainer import BaseTrainer
from robovlms.data.data_utils import preprocess_image
from robovlms.data.data_utils import get_text_function
from robovlms.data.data_utils import unnoramalize_action

# Load the training config and point it at the pre-trained checkpoint.
configs = json.load(open('configs/kosmos_ph_calvin_abcd.json', 'r'))
pretrained_path = 'checkpoints/kosmos_ph_calvin_abcd.pt'
configs['model_load_path'] = pretrained_path

model = BaseTrainer.from_checkpoint(configs)

image_fn = functools.partial(
    preprocess_image,
    image_processor=model.model.image_processor,
    model_type=configs["model"],
)
text_fn = get_text_function(model.model.tokenizer, configs["model"])
prompt = "Task: pick up the bottle on the table"
text_tensor, attention_mask = text_fn([prompt])

MAX_STEPS = 360  # e.g., the episode length used in CALVIN evaluation
input_dict = {}

for step in range(MAX_STEPS):
    # Static (third-person) camera observation; placeholder for your camera interface.
    image: Image.Image = get_from_side_camera(...)
    image = image_fn([image]).unsqueeze(0)  # add a batch dimension

    input_dict["rgb"] = image
    input_dict["text"] = text_tensor
    input_dict["text_mask"] = attention_mask

    # If a wrist camera is available, pass it as the hand view.
    wrist_image: Image.Image = get_from_wrist_camera(...)  # placeholder for your camera interface
    wrist_image = image_fn([wrist_image]).unsqueeze(0)
    input_dict["hand_rgb"] = wrist_image

    action = model.inference_step(input_dict)["action"]

    # Un-normalize (and, if necessary, reproject) the action before execution.
    if isinstance(action, tuple):
        # Tuple outputs (e.g., arm action plus gripper action): only the
        # first element is normalized.
        action = (
            unnoramalize_action(
                action[0], configs["norm_min"], configs["norm_max"]
            ),
            action[1],
        )
    else:
        action = unnoramalize_action(
            action, configs["norm_min"], configs["norm_max"]
        )
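
How you execute the returned action depends on your robot or simulator. As a minimal sketch, assume a CALVIN-style 7-dimensional action (six relative end-effector DoFs plus a gripper open/close command in {-1, 1}) and an env.step interface of your own; to_env_action below is an illustrative helper, not part of RoboVLMs:

import numpy as np

def to_env_action(action):
    # Flatten the policy output into a single 7-dim vector:
    # six arm DoFs followed by a binarized gripper command in {-1, 1}.
    if isinstance(action, tuple):
        arm, gripper = action
        arm = arm.detach().cpu().numpy().reshape(-1)[:6]
        # Threshold a sigmoid-style gripper output at 0.5 (assumption).
        gripper = 1.0 if float(gripper.reshape(-1)[-1]) > 0.5 else -1.0
        return np.concatenate([arm, [gripper]])
    return action.detach().cpu().numpy().reshape(-1)

# obs, reward, done, info = env.step(to_env_action(action))  # your environment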
