---
license: mit
license_link: https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/LICENSE

language:
- multilingual
pipeline_tag: text-generation
tags:
- nlp
- code
- vision
widget:
  - messages:
      - role: user
        content: <|image_1|>\nWhat action should the robot take to {lang}?
---

## TraceVLA-Phi3V
The `TraceVLA-Phi3V` model is a vision-language-action (VLA) model obtained by fine-tuning the base OpenVLA-Phi3V model on the Open X-Embodiment robot mixture dataset using the [visual trace prompting](https://arxiv.org/pdf/2412.10345) technique.
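
Visual trace prompting gives the model a second copy of the observation with the recent motion of tracked points (the robot end effector and moving objects) drawn on top. The sketch below illustrates only that drawing step and is not the repository's implementation: it assumes the point tracks are already available as arrays of (x, y) pixel coordinates, whereas in the actual pipeline `TraceProcessor` obtains them with Co-Tracker. `overlay_traces` is a hypothetical helper name.

```python
import numpy as np


def overlay_traces(image, tracks, color=(255, 0, 0)):
    """Return a copy of an HxWx3 uint8 image with each point track drawn as a polyline.

    Hypothetical helper for illustration only. Each track is a (T, 2) array of
    (x, y) pixel coordinates ordered from oldest to newest.
    """
    out = image.copy()
    h, w = out.shape[:2]
    for track in tracks:
        track = np.asarray(track, dtype=float)
        for (x0, y0), (x1, y1) in zip(track[:-1], track[1:]):
            # Sample at least one point per pixel of travel so the line has no gaps.
            n = int(np.ceil(max(abs(x1 - x0), abs(y1 - y0)))) + 1
            t = np.linspace(0.0, 1.0, n)
            xs = np.rint(x0 + t * (x1 - x0)).astype(int)
            ys = np.rint(y0 + t * (y1 - y0)).astype(int)
            # Drop samples that fall outside the frame.
            keep = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
            out[ys[keep], xs[keep]] = color
    return out


if __name__ == "__main__":
    frame = np.zeros((256, 256, 3), dtype=np.uint8)  # blank 256x256 frame
    track = np.array([[10, 10], [50, 10], [50, 60]])  # one L-shaped trajectory
    overlaid = overlay_traces(frame, [track])
    print(overlaid[10, 30], overlaid[60, 50], overlaid[100, 100])
```

The original frame is left untouched, so the model can receive both the raw observation and the overlaid copy, mirroring the two-image prompt used at inference time.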

### Results on SimplerEnv (Fractal and Bridge):

#### Fractal:
| Policy/Settings | Pick up Coke | Move near | Open/Close Drawer | Put in Drawer | Average |
|:------:|:------------:|:---------:|:------------:|:-----------:|:-------:|
| (Visual Matching) OpenVLA-Phi3V | 56.7% | 53.3% | **38.4%** | **15.7%** | 41.0% |
| (Visual Matching) TraceVLA-Phi3V | **69.7%** | **70.8%** | 35.4% | 0.0% | **44.0%** |
| (Variant Aggregation) OpenVLA-Phi3V | 55.4% | 57.7% | 19.3% | **10.6%** | 35.8% |
| (Variant Aggregation) TraceVLA-Phi3V | **75.4%** | **67.8%** | **37.5%** | 0.0% | **45.1%** |

#### Bridge:
| Policy/Settings | Put Spoon | Put Carrot | Stack Block | Put Eggplant | Average |
|:------:|:------------:|:---------:|:------------:|:-----------:|:-------:|
| OpenVLA-Phi3V | **12.5%** | 0% | 0% | 8.3% | 5.2% |
| TraceVLA-Phi3V | 8.3% | 0% | **12.5%** | **66.7%** | **21.9%** |
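
The Average column appears to be the unweighted mean of the four task columns. The short check below recomputes it from the rounded per-task values in the tables above; every row agrees with the reported Average to within 0.1 percentage points, with the small gaps presumably due to rounding of the per-task numbers. The row labels and helper names are illustrative only.

```python
# Per-task success rates (%) copied from the tables; order matches the task columns.
RESULTS = {
    "Fractal VM / OpenVLA-Phi3V": ([56.7, 53.3, 38.4, 15.7], 41.0),
    "Fractal VM / TraceVLA-Phi3V": ([69.7, 70.8, 35.4, 0.0], 44.0),
    "Fractal VA / OpenVLA-Phi3V": ([55.4, 57.7, 19.3, 10.6], 35.8),
    "Fractal VA / TraceVLA-Phi3V": ([75.4, 67.8, 37.5, 0.0], 45.1),
    "Bridge / OpenVLA-Phi3V": ([12.5, 0.0, 0.0, 8.3], 5.2),
    "Bridge / TraceVLA-Phi3V": ([8.3, 0.0, 12.5, 66.7], 21.9),
}


def unweighted_mean(values):
    """Plain arithmetic mean over the task columns."""
    return sum(values) / len(values)


def max_average_gap(results):
    """Largest |recomputed mean - reported Average| across all rows."""
    return max(abs(unweighted_mean(tasks) - reported) for tasks, reported in results.values())


if __name__ == "__main__":
    for name, (tasks, reported) in RESULTS.items():
        print(f"{name}: recomputed {unweighted_mean(tasks):.3f} vs reported {reported}")
    print(f"max gap: {max_average_gap(RESULTS):.3f}")
```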


### Sample Inference Code
Below is sample inference code for TraceVLA-Phi3V.
```python
# Load processor & VLA
import json
import os

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_path = "..."            # Path or Hub ID of the TraceVLA-Phi3V checkpoint
cotracker_model_path = "..."  # Path to the Co-Tracker checkpoint used for visual traces
dataset_name = "..."          # Key in dataset_statistics.json used to un-normalize actions

processor = AutoProcessor.from_pretrained(
    model_path, trust_remote_code=True, num_crops=1
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    _attn_implementation='flash_attention_2',
    use_cache=False,
).cuda()

# Load visual trace processor
from prismatic.eval import TraceProcessor

trace_processor = TraceProcessor(cotracker_model_path)

# Load dataset statistics
dataset_stats_path = os.path.join(model_path, 'dataset_statistics.json')
with open(dataset_stats_path, 'r') as file:
    action_norm_stats = json.load(file)[dataset_name]['action']
model.prepare_action_inference(action_norm_stats, processor.tokenizer.vocab_size)


def resize_image(img: Image.Image, size: tuple) -> Image.Image:
    # Minimal resize helper (an assumption; substitute the repository's own utility if it differs)
    return img.resize(size, Image.Resampling.BILINEAR)


lang: str = None  # Task language instruction
# IMPORTANT: make sure the image is of size (336, 336)
image: Image.Image = None  # Image observation

# Get the visual-trace-overlaid image observation
image = resize_image(image, (256, 256))  # 256x256 is the Co-Tracker input resolution
image_overlaid, has_trace = trace_processor.process_image(image)
image_overlaid = resize_image(image_overlaid, (336, 336))  # 336x336 is the Phi3V image encoder resolution

# Prepare the TraceVLA prompt format.
# Keep the prompt strings verbatim (including spelling); they presumably match the fine-tuning prompts.
if not has_trace:
    prompt_message = {
        'role': 'user',
        'content': f'<|image_1|><|image_2|>\nWhat action should the robot take to {lang}?',
    }
else:
    prompt_message = {
        'role': 'user',
        'content': f'You are given two images: one with the original robot observation <|image_1|>, and another one marked with historial traces of the robot end effector and moving objects <|image_2|>.\nWhat action should the robot take to {lang}?',
    }
prompt = processor.tokenizer.apply_chat_template(
    [prompt_message], tokenize=False, add_generation_prompt=True
)
inputs = processor(prompt, [image, image_overlaid]).to("cuda:0", dtype=torch.bfloat16)

# Get the action output from the model
action = model.predict_action(**inputs)
```
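
`prepare_action_inference` and `predict_action` are provided by the model's remote code, and this card does not document their internals. Assuming they follow the original OpenVLA action tokenizer (an assumption, not confirmed here), decoding works roughly as sketched below: each generated action token near the end of the vocabulary maps to one of 256 uniform bins on [-1, 1], and the resulting normalized action is mapped back to the dataset's range using the `q01`/`q99` (and optional `mask`) entries of the action statistics loaded from `dataset_statistics.json`. The function names are hypothetical.

```python
import numpy as np

N_BINS = 256  # assumption: OpenVLA-style 256 uniform bins over [-1, 1]


def decode_action_tokens(token_ids, vocab_size, n_bins=N_BINS):
    """Map action token ids (taken from the end of the vocabulary) to normalized actions."""
    bins = np.linspace(-1.0, 1.0, n_bins)
    bin_centers = (bins[:-1] + bins[1:]) / 2.0
    # Token vocab_size - 1 maps to the first bin, vocab_size - 2 to the second, and so on.
    discretized = vocab_size - np.asarray(token_ids)
    idx = np.clip(discretized - 1, 0, bin_centers.shape[0] - 1)
    return bin_centers[idx]


def unnormalize(normalized, stats):
    """Undo [-1, 1] normalization with q01/q99; dimensions with mask=False pass through."""
    low = np.asarray(stats["q01"], dtype=float)
    high = np.asarray(stats["q99"], dtype=float)
    mask = np.asarray(stats.get("mask", np.ones_like(low, dtype=bool)), dtype=bool)
    return np.where(mask, 0.5 * (normalized + 1.0) * (high - low) + low, normalized)


if __name__ == "__main__":
    stats = {"q01": [0.5, -2.0], "q99": [1.5, 2.0], "mask": [True, False]}  # toy statistics
    normalized = decode_action_tokens([999, 745], vocab_size=1000)
    print(normalized, unnormalize(normalized, stats))
```

Using the 1st and 99th percentiles rather than the raw min/max presumably keeps rare outliers in the training data from stretching the bin range.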

For more examples, including scripts for finetuning OpenVLA-Phi3V models on your own robot demonstration datasets, check out our [repository](https://github.com/FrankZheng2022/tracevla/tree/phi3).




### Citation

If you find our code or models useful in your work, please cite [our paper](https://arxiv.org/abs/2412.10345):

```bibtex
@misc{zheng2024tracevlavisualtraceprompting,
      title={TraceVLA: Visual Trace Prompting Enhances Spatial-Temporal Awareness for Generalist Robotic Policies}, 
      author={Ruijie Zheng and Yongyuan Liang and Shuaiyi Huang and Jianfeng Gao and Hal Daumé III and Andrey Kolobov and Furong Huang and Jianwei Yang},
      year={2024},
      eprint={2412.10345},
      archivePrefix={arXiv},
      primaryClass={cs.RO},
      url={https://arxiv.org/abs/2412.10345}, 
}
```