phython96 commited on
Commit
da7e628
1 Parent(s): 1cff57f

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -286
app.py DELETED
@@ -1,286 +0,0 @@
1
- '''
2
- author: caishaofei <[email protected]>
3
- date: 2024-09-20 20:10:44
4
- Copyright © Team CraftJarvis All rights reserved
5
- '''
6
- import re
7
- import os
8
- import cv2
9
- import time
10
- from pathlib import Path
11
- import argparse
12
- import requests
13
- import gradio as gr
14
- import torch
15
- import numpy as np
16
- from io import BytesIO
17
- from PIL import Image, ImageDraw
18
- from rocket.arm.sessions import Session, Pointer
19
-
20
- COLORS = [
21
- (255, 0, 0), (0, 255, 0), (0, 0, 255),
22
- (255, 255, 0), (255, 0, 255), (0, 255, 255),
23
- (255, 255, 255), (0, 0, 0), (128, 128, 128),
24
- (128, 0, 0), (128, 128, 0), (0, 128, 0),
25
- (128, 0, 128), (0, 128, 128), (0, 0, 128),
26
- ]
27
-
28
- SEGMENT_MAPPING = {
29
- "Hunt": 0, "Use": 3, "Mine": 2, "Interact": 3, "Craft": 4, "Switch": 5, "Approach": 6
30
- }
31
-
32
- NOOP_ACTION = {
33
- "back": 0,
34
- "drop": 0,
35
- "forward": 0,
36
- "hotbar.1": 0,
37
- "hotbar.2": 0,
38
- "hotbar.3": 0,
39
- "hotbar.4": 0,
40
- "hotbar.5": 0,
41
- "hotbar.6": 0,
42
- "hotbar.7": 0,
43
- "hotbar.8": 0,
44
- "hotbar.9": 0,
45
- "inventory": 0,
46
- "jump": 0,
47
- "left": 0,
48
- "right": 0,
49
- "sneak": 0,
50
- "sprint": 0,
51
- "camera": np.array([0, 0]),
52
- "attack": 0,
53
- "use": 0,
54
- }
55
-
56
- def reset_fn(env_name, session):
57
- image = session.reset(env_name)
58
- return image, session
59
-
60
- def step_fn(act_key, session):
61
- action = NOOP_ACTION.copy()
62
- if act_key != "null":
63
- action[act_key] = 1
64
- image = session.step(action)
65
- return image, session
66
-
67
- def loop_step_fn(steps, session):
68
- for i in range(steps):
69
- image = session.step()
70
- status = f"Running Agent `Rocket` steps: {i+1}/{steps}. "
71
- yield image, session.num_steps, status, session
72
-
73
- def clear_memory_fn(session):
74
- image = session.current_image
75
- session.clear_agent_memory()
76
- return image, "0", session
77
-
78
- def get_points_with_draw(image, label, session, evt: gr.SelectData):
79
- points = session.points
80
- point_label = session.points_label
81
- x, y = evt.index[0], evt.index[1]
82
- point_radius, point_color = 5, (0, 255, 0) if label == 'Add Points' else (255, 0, 0)
83
- points.append([x, y])
84
- point_label.append(1 if label == 'Add Points' else 0)
85
- cv2.circle(image, (x, y), point_radius, point_color, -1)
86
- return image, session
87
-
88
- def clear_points_fn(session):
89
- session.clear_points()
90
- return session.current_image, session
91
-
92
- def segment_fn(session):
93
- if len(session.points) == 0:
94
- return session.current_image, session
95
- session.segment()
96
- image = session.apply_mask()
97
- return image, session
98
-
99
- def clear_segment_fn(session):
100
- session.clear_obj_mask()
101
- session.tracking_flag = False
102
- return session.current_image, False, session
103
-
104
- def set_tracking_mode(tracking_flag, session):
105
- session.tracking_flag = tracking_flag
106
- return session
107
-
108
- def set_segment_type(segment_type, session):
109
- session.segment_type = segment_type
110
- return session
111
-
112
- def play_fn(session):
113
- image = session.step()
114
- return image, session
115
-
116
- memory_length = gr.Textbox(value="0", interactive=False, show_label=False)
117
-
118
- def make_video_fn(session, make_video, save_video, progress=gr.Progress()):
119
- images = session.image_history
120
- if len(images) == 0:
121
- return session, make_video, save_video
122
- filepath = "rocket.mp4"
123
- h, w = images[0].shape[:2]
124
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
125
- video = cv2.VideoWriter(filepath, fourcc, 20.0, (w, h))
126
- for image in progress.tqdm(images):
127
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
128
- video.write(image)
129
- video.release()
130
- session.image_history = []
131
- return session, gr.Button("Make Video", visible=False), gr.DownloadButton("Download!", value=filepath, visible=True)
132
-
133
- def save_video_fn(session, make_video, save_video):
134
- return session, gr.Button("Make Video", visible=True), gr.DownloadButton("Download!", visible=False)
135
-
136
- def choose_sam_fn(sam_choice, session):
137
- session.sam_choice = sam_choice
138
- session.load_sam()
139
- return session
140
-
141
- def molmo_fn(molmo_text, molmo_session, rocket_session, display_image):
142
- image = rocket_session.current_image.copy()
143
- points = molmo_session.gen_point(image=image, prompt=molmo_text)
144
- molmo_result = molmo_session.molmo_result
145
- for x, y in points:
146
- x, y = int(x), int(y)
147
- point_radius, point_color = 5, (0, 255, 0)
148
- rocket_session.points.append([x, y])
149
- rocket_session.points_label.append(1)
150
- cv2.circle(display_image, (x, y), point_radius, point_color, -1)
151
- return molmo_result, display_image
152
-
153
- def extract_points(data):
154
- # 匹配 x 和 y 坐标的值,支持 <points> 和 <point> 标签
155
- pattern = r'x\d?="([-+]?\d*\.\d+|\d+)" y\d?="([-+]?\d*\.\d+|\d+)"'
156
- points = re.findall(pattern, data)
157
- # 将提取到的坐标转换为浮点数
158
- points = [(float(x)/100*640, float(y)/100*360) for x, y in points]
159
- return points
160
-
161
- def draw_gradio_components(args):
162
-
163
- with gr.Blocks() as demo:
164
-
165
- gr.Markdown(
166
- """
167
- # Welcome to Explore ROCKET-1 in Minecraft!!
168
- ## Please follow next steps to interact with the agent:
169
- 1. Reset the environment by selecting an environment name.
170
- 2. Select a SAM2 checkpoint to load.
171
- 3. Use your mouse to add or remove points on the image.
172
- 4. Select the segment type you want to perform.
173
- 5. Enable `tracking` mode if you want to track objects while stepping actions.
174
- 6. Click `New Segment` to segment the image based on the points you added.
175
- 7. Call the agent by clicking `Call Rocket` to run the agent for a certain number of steps.
176
- ## Hints:
177
- 1. You can use the `Make Video` button to generate a video of the agent's actions.
178
- 2. You can use the `Clear Memory` button to clear the ROCKET-1's memory.
179
- 3. You can use the `Clear Segment` button to clear SAM's memory.
180
- 4. You can use the `Manually Step` button to manually step the agent.
181
- """
182
- )
183
-
184
- rocket_session = gr.State(Session(
185
- sam_path=args.sam_path,
186
- ))
187
- molmo_session = gr.State(Pointer(
188
- model_id="molmo-72b-0924",
189
- model_url="http://172.17.30.127:8000/v1",
190
- ))
191
- with gr.Row():
192
-
193
- with gr.Column(scale=2):
194
- # start_image = Image.open("start.png").resize((640, 360))
195
- start_image = np.zeros((360, 640, 3), dtype=np.uint8)
196
-
197
- with gr.Group():
198
- display_image = gr.Image(
199
- value=np.array(start_image),
200
- interactive=False,
201
- show_label=False,
202
- label="Real-time Environment Observation",
203
- streaming=True
204
- )
205
- display_status = gr.Textbox("Status Bar", interactive=False, show_label=False)
206
-
207
- with gr.Column(scale=1):
208
-
209
- sam_choice = gr.Radio(
210
- choices=["large", "base", "small", "tiny"],
211
- value="base",
212
- label="Select SAM2 checkpoint",
213
- )
214
- sam_choice.select(fn=choose_sam_fn, inputs=[sam_choice, rocket_session], outputs=[rocket_session], show_progress=False)
215
-
216
- with gr.Group():
217
- add_or_remove = gr.Radio(
218
- choices=["Add Points", "Remove Areas"],
219
- value="Add Points",
220
- label="Use you mouse to add or remove points",
221
- )
222
- clear_points_btn = gr.Button("Clear Points")
223
- clear_points_btn.click(clear_points_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
224
-
225
- with gr.Group():
226
- segment_type = gr.Radio(
227
- choices=["Approach", "Interact", "Hunt", "Mine", "Craft", "Switch"],
228
- value="Approach",
229
- label="What do you want with this segment?",
230
- )
231
- track_flag = gr.Checkbox(True, label="Enable tracking objects while steping actions")
232
- track_flag.select(fn=set_tracking_mode, inputs=[track_flag, rocket_session], outputs=[rocket_session], show_progress=False)
233
- with gr.Group(), gr.Row():
234
- new_segment_btn = gr.Button("New Segment")
235
- clear_segment_btn = gr.Button("Clear Segment")
236
- new_segment_btn.click(segment_fn, inputs=[rocket_session], outputs=[display_image, rocket_session], show_progress=True)
237
- clear_segment_btn.click(clear_segment_fn, inputs=[rocket_session], outputs=[display_image, track_flag, rocket_session], show_progress=True)
238
-
239
- display_image.select(get_points_with_draw, inputs=[display_image, add_or_remove, rocket_session], outputs=[display_image, rocket_session])
240
- segment_type.select(set_segment_type, inputs=[segment_type, rocket_session], outputs=[rocket_session], show_progress=False)
241
-
242
- with gr.Row():
243
- with gr.Group():
244
- env_list = [f"rocket/{x.stem}" for x in Path("../env_configs/rocket").glob("*.yaml") if 'base' not in x.name != 'base']
245
- env_name = gr.Dropdown(env_list, multiselect=False, min_width=200, show_label=False, label="Env Name")
246
- reset_btn = gr.Button("Reset Environment")
247
- reset_btn.click(fn=reset_fn, inputs=[env_name, rocket_session], outputs=[display_image, rocket_session], show_progress=True)
248
-
249
- with gr.Group():
250
- action_list = [x for x in NOOP_ACTION.keys()]
251
- act_key = gr.Dropdown(action_list, multiselect=False, min_width=200, show_label=False, label="Action")
252
- step_btn = gr.Button("Manually Step")
253
- step_btn.click(fn=step_fn, inputs=[act_key, rocket_session], outputs=[display_image, rocket_session], show_progress=False)
254
-
255
- with gr.Group():
256
- steps = gr.Slider(1, 600, 30, 1, label="Steps", show_label=False)
257
- play_btn = gr.Button("Call Rocket")
258
- play_btn.click(fn=loop_step_fn, inputs=[steps, rocket_session], outputs=[display_image, memory_length, display_status, rocket_session], show_progress=False)
259
-
260
- with gr.Group():
261
- memory_length.render()
262
- clear_states_btn = gr.Button("Clear Memory")
263
- clear_states_btn.click(fn=clear_memory_fn, inputs=rocket_session, outputs=[display_image, memory_length, rocket_session], show_progress=False)
264
-
265
- make_video_btn = gr.Button("Make Video")
266
- save_video_btn = gr.DownloadButton("Download!!", visible=False)
267
- make_video_btn.click(make_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
268
- save_video_btn.click(save_video_fn, inputs=[rocket_session, make_video_btn, save_video_btn], outputs=[rocket_session, make_video_btn, save_video_btn], show_progress=False)
269
- with gr.Row():
270
- with gr.Group():
271
- molmo_text = gr.Textbox("pinpoint the", label="Molmo Text", show_label=True, min_width=200)
272
- molmo_btn = gr.Button("Generate")
273
- output_text = gr.Textbox("", label="Molmo Output", show_label=False, min_width=200)
274
- molmo_btn.click(molmo_fn, inputs=[molmo_text, molmo_session, rocket_session, display_image],outputs=[output_text, display_image],show_progress=False)
275
-
276
- demo.queue()
277
- demo.launch(share=False,server_port=args.port)
278
-
279
- if __name__ == '__main__':
280
- parser = argparse.ArgumentParser()
281
- parser.add_argument("--port", type=int, default=7860)
282
- parser.add_argument("--sam-path", type=str, default="/app/ROCKET-1/rocket/realtime_sam/checkpoints")
283
- parser.add_argument("--molmo-id", type=str, default="molmo-72b-0924")
284
- parser.add_argument("--molmo-url", type=str, default="http://127.0.0.1:8000/v1")
285
- args = parser.parse_args()
286
- draw_gradio_components(args)