Spaces:
Building
on
A10G
Building
on
A10G
Delete app.py
Browse files
app.py
DELETED
@@ -1,286 +0,0 @@
|
|
1 |
-
'''
|
2 |
-
author: caishaofei <[email protected]>
|
3 |
-
date: 2024-09-20 20:10:44
|
4 |
-
Copyright © Team CraftJarvis All rights reserved
|
5 |
-
'''
|
6 |
-
import re
|
7 |
-
import os
|
8 |
-
import cv2
|
9 |
-
import time
|
10 |
-
from pathlib import Path
|
11 |
-
import argparse
|
12 |
-
import requests
|
13 |
-
import gradio as gr
|
14 |
-
import torch
|
15 |
-
import numpy as np
|
16 |
-
from io import BytesIO
|
17 |
-
from PIL import Image, ImageDraw
|
18 |
-
from rocket.arm.sessions import Session, Pointer
|
19 |
-
|
20 |
-
COLORS = [
|
21 |
-
(255, 0, 0), (0, 255, 0), (0, 0, 255),
|
22 |
-
(255, 255, 0), (255, 0, 255), (0, 255, 255),
|
23 |
-
(255, 255, 255), (0, 0, 0), (128, 128, 128),
|
24 |
-
(128, 0, 0), (128, 128, 0), (0, 128, 0),
|
25 |
-
(128, 0, 128), (0, 128, 128), (0, 0, 128),
|
26 |
-
]
|
27 |
-
|
28 |
-
SEGMENT_MAPPING = {
|
29 |
-
"Hunt": 0, "Use": 3, "Mine": 2, "Interact": 3, "Craft": 4, "Switch": 5, "Approach": 6
|
30 |
-
}
|
31 |
-
|
32 |
-
NOOP_ACTION = {
|
33 |
-
"back": 0,
|
34 |
-
"drop": 0,
|
35 |
-
"forward": 0,
|
36 |
-
"hotbar.1": 0,
|
37 |
-
"hotbar.2": 0,
|
38 |
-
"hotbar.3": 0,
|
39 |
-
"hotbar.4": 0,
|
40 |
-
"hotbar.5": 0,
|
41 |
-
"hotbar.6": 0,
|
42 |
-
"hotbar.7": 0,
|
43 |
-
"hotbar.8": 0,
|
44 |
-
"hotbar.9": 0,
|
45 |
-
"inventory": 0,
|
46 |
-
"jump": 0,
|
47 |
-
"left": 0,
|
48 |
-
"right": 0,
|
49 |
-
"sneak": 0,
|
50 |
-
"sprint": 0,
|
51 |
-
"camera": np.array([0, 0]),
|
52 |
-
"attack": 0,
|
53 |
-
"use": 0,
|
54 |
-
}
|
55 |
-
|
56 |
-
def reset_fn(env_name, session):
|
57 |
-
image = session.reset(env_name)
|
58 |
-
return image, session
|
59 |
-
|
60 |
-
def step_fn(act_key, session):
|
61 |
-
action = NOOP_ACTION.copy()
|
62 |
-
if act_key != "null":
|
63 |
-
action[act_key] = 1
|
64 |
-
image = session.step(action)
|
65 |
-
return image, session
|
66 |
-
|
67 |
-
def loop_step_fn(steps, session):
|
68 |
-
for i in range(steps):
|
69 |
-
image = session.step()
|
70 |
-
status = f"Running Agent `Rocket` steps: {i+1}/{steps}. "
|
71 |
-
yield image, session.num_steps, status, session
|
72 |
-
|
73 |
-
def clear_memory_fn(session):
|
74 |
-
image = session.current_image
|
75 |
-
session.clear_agent_memory()
|
76 |
-
return image, "0", session
|
77 |
-
|
78 |
-
def get_points_with_draw(image, label, session, evt: gr.SelectData):
|
79 |
-
points = session.points
|
80 |
-
point_label = session.points_label
|
81 |
-
x, y = evt.index[0], evt.index[1]
|
82 |
-
point_radius, point_color = 5, (0, 255, 0) if label == 'Add Points' else (255, 0, 0)
|
83 |
-
points.append([x, y])
|
84 |
-
point_label.append(1 if label == 'Add Points' else 0)
|
85 |
-
cv2.circle(image, (x, y), point_radius, point_color, -1)
|
86 |
-
return image, session
|
87 |
-
|
88 |
-
def clear_points_fn(session):
|
89 |
-
session.clear_points()
|
90 |
-
return session.current_image, session
|
91 |
-
|
92 |
-
def segment_fn(session):
|
93 |
-
if len(session.points) == 0:
|
94 |
-
return session.current_image, session
|
95 |
-
session.segment()
|
96 |
-
image = session.apply_mask()
|
97 |
-
return image, session
|
98 |
-
|
99 |
-
def clear_segment_fn(session):
|
100 |
-
session.clear_obj_mask()
|
101 |
-
session.tracking_flag = False
|
102 |
-
return session.current_image, False, session
|
103 |
-
|
104 |
-
def set_tracking_mode(tracking_flag, session):
|
105 |
-
session.tracking_flag = tracking_flag
|
106 |
-
return session
|
107 |
-
|
108 |
-
def set_segment_type(segment_type, session):
|
109 |
-
session.segment_type = segment_type
|
110 |
-
return session
|
111 |
-
|
112 |
-
def play_fn(session):
|
113 |
-
image = session.step()
|
114 |
-
return image, session
|
115 |
-
|
116 |
-
memory_length = gr.Textbox(value="0", interactive=False, show_label=False)
|
117 |
-
|
118 |
-
def make_video_fn(session, make_video, save_video, progress=gr.Progress()):
|
119 |
-
images = session.image_history
|
120 |
-
if len(images) == 0:
|
121 |
-
return session, make_video, save_video
|
122 |
-
filepath = "rocket.mp4"
|
123 |
-
h, w = images[0].shape[:2]
|
124 |
-
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
125 |
-
video = cv2.VideoWriter(filepath, fourcc, 20.0, (w, h))
|
126 |
-
for image in progress.tqdm(images):
|
127 |
-
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
128 |
-
video.write(image)
|
129 |
-
video.release()
|
130 |
-
session.image_history = []
|
131 |
-
return session, gr.Button("Make Video", visible=False), gr.DownloadButton("Download!", value=filepath, visible=True)
|
132 |
-
|
133 |
-
def save_video_fn(session, make_video, save_video):
|
134 |
-
return session, gr.Button("Make Video", visible=True), gr.DownloadButton("Download!", visible=False)
|
135 |
-
|
136 |
-
def choose_sam_fn(sam_choice, session):
|
137 |
-
session.sam_choice = sam_choice
|
138 |
-
session.load_sam()
|
139 |
-
return session
|
140 |
-
|
141 |
-
def molmo_fn(molmo_text, molmo_session, rocket_session, display_image):
|
142 |
-
image = rocket_session.current_image.copy()
|
143 |
-
points = molmo_session.gen_point(image=image, prompt=molmo_text)
|
144 |
-
molmo_result = molmo_session.molmo_result
|
145 |
-
for x, y in points:
|
146 |
-
x, y = int(x), int(y)
|
147 |
-
point_radius, point_color = 5, (0, 255, 0)
|
148 |
-
rocket_session.points.append([x, y])
|
149 |
-
rocket_session.points_label.append(1)
|
150 |
-
cv2.circle(display_image, (x, y), point_radius, point_color, -1)
|
151 |
-
return molmo_result, display_image
|
152 |
-
|
153 |
-
def extract_points(data):
|
154 |
-
# 匹配 x 和 y 坐标的值,支持 <points> 和 <point> 标签
|
155 |
-
pattern = r'x\d?="([-+]?\d*\.\d+|\d+)" y\d?="([-+]?\d*\.\d+|\d+)"'
|
156 |
-
points = re.findall(pattern, data)
|
157 |
-
# 将提取到的坐标转换为浮点数
|
158 |
-
points = [(float(x)/100*640, float(y)/100*360) for x, y in points]
|
159 |
-
return points
|
160 |
-
|
161 |
-
def draw_gradio_components(args):
|
162 |
-
|
163 |
-
with gr.Blocks() as demo:
|
164 |
-
|
165 |
-
gr.Markdown(
|
166 |
-
"""
|
167 |
-
# Welcome to Explore ROCKET-1 in Minecraft!!
|
168 |
-
## Please follow next steps to interact with the agent:
|
169 |
-
1. Reset the environment by selecting an environment name.
|
170 |
-
2. Select a SAM2 checkpoint to load.
|
171 |
-
3. Use your mouse to add or remove points on the image.
|
172 |
-
4. Select the segment type you want to perform.
|
173 |
-
5. Enable `tracking` mode if you want to track objects while stepping actions.
|
174 |
-
6. Click `New Segment` to segment the image based on the points you added.
|
175 |
-
7. Call the agent by clicking `Call Rocket` to run the agent for a certain number of steps.
|
176 |
-
## Hints:
|
177 |
-
1. You can use the `Make Video` button to generate a video of the agent's actions.
|
178 |
-
2. You can use the `Clear Memory` button to clear the ROCKET-1's memory.
|
179 |
-
3. You can use the `Clear Segment` button to clear SAM's memory.
|
180 |
-
4. You can use the `Manually Step` button to manually step the agent.
|
181 |
-
"""
|
182 |
-
)
|
183 |
-
|
184 |
-
rocket_session = gr.State(Session(
|
185 |
-
sam_path=args.sam_path,
|
186 |
-
))
|
187 |
-
molmo_session = gr.State(Pointer(
|
188 |
-
model_id="molmo-72b-0924",
|
189 |
-
model_url="http://172.17.30.127:8000/v1",
|
190 |
-
))
|
191 |
-
with gr.Row():
|
192 |
-
|
193 |
-
with gr.Column(scale=2):
|
194 |
-
# start_image = Image.open("start.png").resize((640, 360))
|
195 |
-
start_image = np.zeros((360, 640, 3), dtype=np.uint8)
|
196 |
-
|
197 |
-
with gr.Group():
|
198 |
-
display_image = gr.Image(
|
199 |
-
value=np.array(start_image),
|
200 |
-
interactive=False,
|
201 |
-
show_label=False,
|
202 |
-
label="Real-time Environment Observation",
|
203 |
-
streaming=True
|
204 |
-
)
|
205 |
-
display_status = gr.Textbox("Status Bar", interactive=False, show_label=False)
|
206 |
-
|
207 |
-
with gr.Column(scale=1):
|
208 |
-
|
209 |
-
sam_choice = gr.Radio(
|
210 |
-
choices=["large", "base", "small", "tiny"],
|
211 |
-
value="base",
|
212 |
-
label="Select SAM2 checkpoint",
|
213 |
-
)
|
214 |
-
sam_choice.select(fn=choose_sam_fn, inputs=[sam_choice, rocket_session], outputs=[rocket_session], show_progress=False)
|
215 |
-
|
216 |
-
with gr.Group():
|
217 |
-
add_or_remove = gr.Radio(
|
218 |
-
choices=["Add Points", "Remove Areas"],
|
219 |
-
value="Add Points",
|
220 |
-
label="Use you mouse to add or remove points",
|
221 |
-
)
|
222 |
-
clear_points_btn = gr.Button("Clear Points")
|
223 |
-
clear_points_btn.click(clear_points_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
|
224 |
-
|
225 |
-
with gr.Group():
|
226 |
-
segment_type = gr.Radio(
|
227 |
-
choices=["Approach", "Interact", "Hunt", "Mine", "Craft", "Switch"],
|
228 |
-
value="Approach",
|
229 |
-
label="What do you want with this segment?",
|
230 |
-
)
|
231 |
-
track_flag = gr.Checkbox(True, label="Enable tracking objects while steping actions")
|
232 |
-
track_flag.select(fn=set_tracking_mode, inputs=[track_flag, rocket_session], outputs=[rocket_session], show_progress=False)
|
233 |
-
with gr.Group(), gr.Row():
|
234 |
-
new_segment_btn = gr.Button("New Segment")
|
235 |
-
clear_segment_btn = gr.Button("Clear Segment")
|
236 |
-
new_segment_btn.click(segment_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
|
237 |
-
clear_segment_btn.click(clear_segment_fn, inputs=[rocket_session], outputs=[display_image, track_flag, rocket_session], show_progress=True)
|
238 |
-
|
239 |
-
display_image.select(get_points_with_draw, inputs=[display_image, add_or_remove, rocket_session], outputs=[display_image, rocket_session])
|
240 |
-
segment_type.select(set_segment_type, inputs=[segment_type, rocket_session], outputs=[rocket_session], show_progress=False)
|
241 |
-
|
242 |
-
with gr.Row():
|
243 |
-
with gr.Group():
|
244 |
-
env_list = [f"rocket/{x.stem}" for x in Path("../env_configs/rocket").glob("*.yaml") if 'base' not in x.name != 'base']
|
245 |
-
env_name = gr.Dropdown(env_list, multiselect=False, min_width=200, show_label=False, label="Env Name")
|
246 |
-
reset_btn = gr.Button("Reset Environment")
|
247 |
-
reset_btn.click(fn=reset_fn, inputs=[env_name, rocket_session], outputs=[display_image, rocket_session], show_progress=True)
|
248 |
-
|
249 |
-
with gr.Group():
|
250 |
-
action_list = [x for x in NOOP_ACTION.keys()]
|
251 |
-
act_key = gr.Dropdown(action_list, multiselect=False, min_width=200, show_label=False, label="Action")
|
252 |
-
step_btn = gr.Button("Manually Step")
|
253 |
-
step_btn.click(fn=step_fn, inputs=[act_key, rocket_session], outputs=[display_image, rocket_session], show_progress=False)
|
254 |
-
|
255 |
-
with gr.Group():
|
256 |
-
steps = gr.Slider(1, 600, 30, 1, label="Steps", show_label=False)
|
257 |
-
play_btn = gr.Button("Call Rocket")
|
258 |
-
play_btn.click(fn=loop_step_fn, inputs=[steps, rocket_session], outputs=[display_image, memory_length, display_status, rocket_session], show_progress=False)
|
259 |
-
|
260 |
-
with gr.Group():
|
261 |
-
memory_length.render()
|
262 |
-
clear_states_btn = gr.Button("Clear Memory")
|
263 |
-
clear_states_btn.click(fn=clear_memory_fn, inputs=rocket_session, outputs=[display_image, memory_length, rocket_session], show_progress=False)
|
264 |
-
|
265 |
-
make_video_btn = gr.Button("Make Video")
|
266 |
-
save_video_btn = gr.DownloadButton("Download!!", visible=False)
|
267 |
-
make_video_btn.click(make_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
|
268 |
-
save_video_btn.click(save_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
|
269 |
-
with gr.Row():
|
270 |
-
with gr.Group():
|
271 |
-
molmo_text = gr.Textbox("pinpoint the", label="Molmo Text", show_label=True, min_width=200)
|
272 |
-
molmo_btn = gr.Button("Generate")
|
273 |
-
output_text = gr.Textbox("", label="Molmo Output", show_label=False, min_width=200)
|
274 |
-
molmo_btn.click(molmo_fn, inputs=[molmo_text, molmo_session, rocket_session, display_image],outputs=[output_text, display_image],show_progress=False)
|
275 |
-
|
276 |
-
demo.queue()
|
277 |
-
demo.launch(share=False,server_port=args.port)
|
278 |
-
|
279 |
-
if __name__ == '__main__':
|
280 |
-
parser = argparse.ArgumentParser()
|
281 |
-
parser.add_argument("--port", type=int, default=7860)
|
282 |
-
parser.add_argument("--sam-path", type=str, default="/app/ROCKET-1/rocket/realtime_sam/checkpoints")
|
283 |
-
parser.add_argument("--molmo-id", type=str, default="molmo-72b-0924")
|
284 |
-
parser.add_argument("--molmo-url", type=str, default="http://127.0.0.1:8000/v1")
|
285 |
-
args = parser.parse_args()
|
286 |
-
draw_gradio_components(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|