ZiyuG commited on
Commit
9ba7304
·
verified ·
1 Parent(s): 8651c71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -116
app.py CHANGED
@@ -5,45 +5,32 @@ import plotly.graph_objects as go
5
  from sam2point import dataset
6
  import sam2point.configs as configs
7
  from demo_utils import run_demo, create_box
8
- # Sample data for dropdowns
 
9
  samples = {
10
  "3D Indoor Scene - S3DIS": ["Conference Room", "Restroom", "Lobby", "Office1", "Office2"],
11
- # "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5"],
12
  "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
13
  "3D Outdoor Driving Scene - KITTI": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
14
  "3D Outdoor Street Scene - Semantic3D": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6", "Scene7"],
15
  "3D Object - Objaverse": ["Plant", "Lego", "Lock", "Eleplant", "Knife Rest", "Skateboard", "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse"],
16
- # "3D Object - Objaverse": ["Plant", "Eleplant", "Knife Rest", "Skateboard", "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse", "Dinner Booth"],
17
  }
18
 
19
-
20
  PATH = {
21
  "S3DIS": ['Area_1_conferenceRoom_1.txt', 'Area_2_WC_1.txt', 'Area_4_lobby_2.txt', 'Area_5_office_3.txt', 'Area_6_office_9.txt'],
22
- # "ScanNet": ['scene0001_01.pth', 'scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth', 'scene0019_01.pth'],
23
  "ScanNet": ['scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth', 'scene0019_01.pth', 'scene0000_00.pth', 'scene0002_00.pth'],
24
  "Objaverse": ["plant.npy", "human.npy", "lock.npy", "elephant.npy", "knife_rest.npy", "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy", "thor_hammer.npy", "horse.npy"],
25
- # "Objaverse": ["plant.npy", "elephant.npy", "knife_rest.npy", "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy", "thor_hammer.npy", "horse.npy", "dinner_booth.npy"],
26
  "KITTI": ["scene1.npy", "scene2.npy", "scene3.npy", "scene4.npy", "scene5.npy", "scene6.npy"],
27
  "Semantic3D": ["scene1.npy", "scene2.npy", "patch19.npy", "patch0.npy", "patch1.npy", "patch50.npy", "patch62.npy"]
28
  }
29
 
30
-
31
  prompt_types = ["Point", "Box", "Mask"]
32
 
33
-
34
- # def select(name, sample_idx):
35
- # DATASET = name.split('-')[1].replace(" ", "")
36
- # gr.Info(f"Visualizing {DATASET} Example {str(sample_idx)}...")
37
-
38
-
39
-
40
-
41
- # Function to load and display 3D scene or object
42
  def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new_color=None):
43
  DATASET = name.split('-')[1].replace(" ", "")
44
  path = 'data/' + DATASET + '/' + PATH[DATASET][sample_idx]
45
  asp, SIZE = 1., 1
46
- # load data
47
  print(path)
48
  if DATASET == 'S3DIS':
49
  point, color = dataset.load_S3DIS_sample(path, sample=True)
@@ -62,25 +49,14 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
62
  elif DATASET == 'Semantic3D':
63
  point, color = dataset.load_Semantic3D_sample(path, sample_idx, sample=True)
64
  alpha = 0.2
65
- print("Loading Dataset:", DATASET, ", Point Cloud Size:", point.shape)
66
 
67
-
68
- ##### Initial Showing #####
69
  if not type_:
70
- if point.shape[0] > 100000:
71
  indices = np.random.choice(point.shape[0], 100000, replace=False)
72
  point = point[indices]
73
  color = color[indices]
74
- # #NOTE KITTI
75
- # mask1 = point[:, 1] <= 0.8
76
- # mask4 = point[:, 1] >= 0.6
77
- # mask2 = point[:, 0] >= 0.3
78
- # mask3 = point[:, 0] <= 0.7
79
- # mask = mask1 & mask2 & mask3 & mask4
80
- # point = point[mask]
81
- # color = color[mask]
82
- # alpha = 1
83
- # ######
84
  fig = go.Figure(
85
  data=[
86
  go.Scatter3d(
@@ -101,7 +77,7 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
101
  )
102
  )
103
  return fig
104
- ##### Final
105
  if final:
106
  color = new_color
107
  green = np.array([[0.1, 0.1, 0.1]])
@@ -116,10 +92,6 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
116
  indices = np.random.choice(point.shape[0], 100000, replace=False)
117
  point = point[indices]
118
  color = color[indices]
119
- # mask = point[:, 1] < 0.8
120
- # point = point[mask]
121
- # color = color[mask]
122
- # alpha = 1
123
  scatter = go.Scatter3d(
124
  x=point[:,0], y=point[:,1], z=point[:,2],
125
  mode='markers',
@@ -128,36 +100,18 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
128
  )
129
  if final: scatter = [scatter, add_green] + create_box(prompt)
130
  else: scatter = [scatter] + create_box(prompt)
131
-
132
  elif type_ == "point":
133
  prompt = np.array([prompt])
134
  new = go.Scatter3d(
135
  x=prompt[:,0], y=prompt[:,1], z=prompt[:,2],
136
  mode='markers',
137
- # marker=dict(size=5, color='red', opacity=1),
138
- # marker=dict(size=5, color='rgb(255, 140, 0)', opacity=1),
139
- marker=dict(size=5, color='rgb(139, 0, 0)', opacity=1),
140
  name="Point Prompt"
141
  )
142
- # print(point.shape, color.shape, new_color.shape)
143
  if point.shape[0] > 100000:
144
  indices = np.random.choice(point.shape[0], 100000, replace=False)
145
  point = point[indices]
146
  color = color[indices]
147
- # #NOTE KITTI
148
- # mask1 = point[:, 1] <= 0.8
149
- # mask = point[:, 1] >= 0.35 #2
150
- # < 0.63 #3
151
- # mask2 = point[:, 0] >= 0.3
152
- # mask3 = point[:, 0] <= 0.7
153
- # mask = mask1 & mask2 & mask3 & mask4
154
- # #NOTE S3DIS
155
- # if DATASET == 'S3DIS':
156
- # mask = point[:, 0] > 0.04
157
- # point = point[mask]
158
- # color = color[mask]
159
- # alpha = 1
160
- # ######
161
  scatter = go.Scatter3d(
162
  x=point[:,0], y=point[:,1], z=point[:,2],
163
  mode='markers',
@@ -191,12 +145,6 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
191
  indices = np.random.choice(point.shape[0], 100000, replace=False)
192
  point = point[indices]
193
  color = color[indices]
194
- # # cut
195
- # mask = point[:, 0] > 0.1
196
- # point = point[mask]
197
- # color = color[mask]
198
- # alpha = 1
199
- # ######
200
  scatter = go.Scatter3d(
201
  x=point[:,0], y=point[:,1], z=point[:,2],
202
  mode='markers',
@@ -204,12 +152,10 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
204
  name="3D Object/Scene"
205
  )
206
  scatter = [scatter, add_green]
207
- print(point.shape, color.shape)
208
  else:
209
  print("Wrong Prompt Type")
210
  exit(1)
211
 
212
-
213
  fig = go.Figure(
214
  data=scatter,
215
  layout=dict(
@@ -224,25 +170,21 @@ def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new
224
  )
225
  return fig
226
 
227
-
228
-
229
-
230
- # Function to display prompt in 3D
231
  def show_prompt_in_3d(name, sample_idx, prompt_type, prompt_idx):
 
 
 
232
  DATASET = name.split('-')[1].replace(" ", "")
233
  TYPE = prompt_type.lower()
234
  theta = 0. if DATASET in "S3DIS ScanNet" else 0.5
235
  mode = "bilinear" if DATASET in "S3DIS ScanNet" else 'nearest'
236
-
237
-
238
  prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, 0.02, theta, mode, ret_prompt=True)
239
  fig = load_3d_scene(name, sample_idx, TYPE, prompt)
240
- return fig
241
-
242
 
243
-
244
-
245
- # Function to start segmentation
246
  def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=None, vx=0.02):
247
  if name == None or sample_idx == None or prompt_type == None or prompt_idx == None:
248
  return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)
@@ -252,26 +194,23 @@ def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=
252
  theta = 0. if DATASET in "S3DIS ScanNet" else 0.5
253
  mode = "bilinear" if DATASET in "S3DIS ScanNet" else 'nearest'
254
 
255
-
256
  new_color, prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, vx, theta, mode, ret_prompt=False)
257
  fig = load_3d_scene(name, sample_idx, TYPE, prompt, final=True, new_color=new_color)
258
  return fig, gr.Textbox(label="Response", value="Segmentation completed successfully!", visible=True)
259
 
260
-
261
-
262
-
263
  def update1(datasets):
264
  if 'Objaverse' in datasets:
265
- return gr.Radio(label="Select 3D Object", choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True) #, gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
266
- return gr.Radio(label="Select 3D Scene", choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True) #, gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
267
-
268
 
 
269
  def update2(name, sample_idx, prompt_type):
270
  if name == None or sample_idx == None or prompt_type == None:
271
- return gr.Radio(label="Select Prompt Example", choices=[]), gr.Textbox(label="Response", value="", visible=True) #, gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
272
  DATASET = name.split('-')[1].replace(" ", "")
273
  TYPE = prompt_type.lower() + '_prompts'
274
- # if DATASET in "ScanNet" and prompt_type == 'Mask': TYPE = 'point_prompts'
275
  if DATASET == 'S3DIS':
276
  info = configs.S3DIS_samples[sample_idx][TYPE]
277
  elif DATASET == 'ScanNet':
@@ -284,14 +223,15 @@ def update2(name, sample_idx, prompt_type):
284
  info = configs.Semantic3D_samples[sample_idx][TYPE]
285
 
286
  cur = ['Example ' + str(i) for i in range(1, len(info) + 1)]
287
- return gr.Radio(label="Select Prompt Example", choices=cur), gr.Textbox(label="Response", value="", visible=True) #, gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
288
-
289
-
290
  def update3(name, sample_idx, prompt_type, prompt_idx):
291
  if name == None or sample_idx == None or prompt_type == None:
292
  return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
293
  DATASET = name.split('-')[1].replace(" ", "")
294
  TYPE = configs.VOXEL[prompt_type.lower()]
 
295
  if DATASET in "S3DIS ScanNet":
296
  vx_ = 0.02
297
  elif DATASET == 'Objaverse':
@@ -303,14 +243,9 @@ def update3(name, sample_idx, prompt_type, prompt_idx):
303
 
304
  return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=vx_)
305
 
306
-
307
  def main():
308
- title = """<h1 style="font-variant: small-caps; font-weight: bold; text-align: center;" align="center">SAM2Point</h1>
309
- <h3 align="center"><b>Segment Any 3D as Videos in Zero-shot and Promptable Manners</h3>
310
- <br>
311
- """
312
- title = """
313
- <h1 style="text-align: center;">
314
  <div style="width: 1.2em; height: 1.2em; display: inline-block;"><img src="https://github.com/ZiyuGuo99/ZiyuGuo99.github.io/blob/main/assets/img/logo.png?raw=true" style='width: 100%; height: 100%; object-fit: contain;' /></div>
315
  <span style="font-variant: small-caps; font-weight: bold;">Sam2Point</span>
316
  </h1>
@@ -351,43 +286,23 @@ def main():
351
  prompt_type_dropdown = gr.Radio(label="Select Prompt Type", choices=prompt_types)
352
  prompt_sample_dropdown = gr.Radio(label="Select Prompt Example", choices=[], type="index")
353
  show_prompt_button = gr.Button("Show Prompt in 3D Scene/Object")
354
- # show_button.input(select, [sample_dropdown, scene_dropdown], [])
355
  with gr.Column():
356
- # vx = gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
357
  start_segment_button = gr.Button("Start Segmentation")
358
  plot1 = gr.Plot()
359
 
360
-
361
-
362
-
363
  response = gr.Textbox(label="Response")
364
 
365
  sample_dropdown.change(update1, sample_dropdown, [scene_dropdown, response])
366
  sample_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
367
  scene_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
368
  prompt_type_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
369
-
370
- # sample_dropdown.change(update1, sample_dropdown, [scene_dropdown, response, vx])
371
- # sample_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response, vx])
372
- # scene_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response, vx])
373
- # prompt_type_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response, vx])
374
- # prompt_sample_dropdown.change(update3, [sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], [response, vx])
375
-
376
- # Logic to handle interactions
377
  show_button.click(load_3d_scene, inputs=[sample_dropdown, scene_dropdown], outputs=plot1)
378
- show_prompt_button.click(show_prompt_in_3d, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=plot1)
379
- # start_segment_button.click(start_segmentation, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown, vx], outputs=[plot1, response])
380
  start_segment_button.click(start_segmentation, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])
381
 
382
  app.queue(status_update_rate="auto")
383
  app.launch(share=True, favicon_path="./logo.png")
384
 
385
-
386
  if __name__ == "__main__":
387
- main()
388
-
389
-
390
-
391
-
392
-
393
-
 
5
  from sam2point import dataset
6
  import sam2point.configs as configs
7
  from demo_utils import run_demo, create_box
8
+ import spaces
9
+
10
  samples = {
11
  "3D Indoor Scene - S3DIS": ["Conference Room", "Restroom", "Lobby", "Office1", "Office2"],
 
12
  "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
13
  "3D Outdoor Driving Scene - KITTI": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
14
  "3D Outdoor Street Scene - Semantic3D": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6", "Scene7"],
15
  "3D Object - Objaverse": ["Plant", "Lego", "Lock", "Eleplant", "Knife Rest", "Skateboard", "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse"],
 
16
  }
17
 
 
18
  PATH = {
19
  "S3DIS": ['Area_1_conferenceRoom_1.txt', 'Area_2_WC_1.txt', 'Area_4_lobby_2.txt', 'Area_5_office_3.txt', 'Area_6_office_9.txt'],
 
20
  "ScanNet": ['scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth', 'scene0019_01.pth', 'scene0000_00.pth', 'scene0002_00.pth'],
21
  "Objaverse": ["plant.npy", "human.npy", "lock.npy", "elephant.npy", "knife_rest.npy", "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy", "thor_hammer.npy", "horse.npy"],
 
22
  "KITTI": ["scene1.npy", "scene2.npy", "scene3.npy", "scene4.npy", "scene5.npy", "scene6.npy"],
23
  "Semantic3D": ["scene1.npy", "scene2.npy", "patch19.npy", "patch0.npy", "patch1.npy", "patch50.npy", "patch62.npy"]
24
  }
25
 
 
26
  prompt_types = ["Point", "Box", "Mask"]
27
 
28
+ @spaces.GPU()
 
 
 
 
 
 
 
 
29
  def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new_color=None):
30
  DATASET = name.split('-')[1].replace(" ", "")
31
  path = 'data/' + DATASET + '/' + PATH[DATASET][sample_idx]
32
  asp, SIZE = 1., 1
33
+
34
  print(path)
35
  if DATASET == 'S3DIS':
36
  point, color = dataset.load_S3DIS_sample(path, sample=True)
 
49
  elif DATASET == 'Semantic3D':
50
  point, color = dataset.load_Semantic3D_sample(path, sample_idx, sample=True)
51
  alpha = 0.2
52
+ print("Loading Dataset:", DATASET, "Point Cloud Size:", point.shape, "Path:", path)
53
 
54
+ ##### Initial Show #####
 
55
  if not type_:
56
+ if point.shape[0] > 100000: # sample points for speeding up
57
  indices = np.random.choice(point.shape[0], 100000, replace=False)
58
  point = point[indices]
59
  color = color[indices]
 
 
 
 
 
 
 
 
 
 
60
  fig = go.Figure(
61
  data=[
62
  go.Scatter3d(
 
77
  )
78
  )
79
  return fig
80
+ ##### Final Results #####
81
  if final:
82
  color = new_color
83
  green = np.array([[0.1, 0.1, 0.1]])
 
92
  indices = np.random.choice(point.shape[0], 100000, replace=False)
93
  point = point[indices]
94
  color = color[indices]
 
 
 
 
95
  scatter = go.Scatter3d(
96
  x=point[:,0], y=point[:,1], z=point[:,2],
97
  mode='markers',
 
100
  )
101
  if final: scatter = [scatter, add_green] + create_box(prompt)
102
  else: scatter = [scatter] + create_box(prompt)
 
103
  elif type_ == "point":
104
  prompt = np.array([prompt])
105
  new = go.Scatter3d(
106
  x=prompt[:,0], y=prompt[:,1], z=prompt[:,2],
107
  mode='markers',
108
+ marker=dict(size=5, color='red', opacity=1),
 
 
109
  name="Point Prompt"
110
  )
 
111
  if point.shape[0] > 100000:
112
  indices = np.random.choice(point.shape[0], 100000, replace=False)
113
  point = point[indices]
114
  color = color[indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  scatter = go.Scatter3d(
116
  x=point[:,0], y=point[:,1], z=point[:,2],
117
  mode='markers',
 
145
  indices = np.random.choice(point.shape[0], 100000, replace=False)
146
  point = point[indices]
147
  color = color[indices]
 
 
 
 
 
 
148
  scatter = go.Scatter3d(
149
  x=point[:,0], y=point[:,1], z=point[:,2],
150
  mode='markers',
 
152
  name="3D Object/Scene"
153
  )
154
  scatter = [scatter, add_green]
 
155
  else:
156
  print("Wrong Prompt Type")
157
  exit(1)
158
 
 
159
  fig = go.Figure(
160
  data=scatter,
161
  layout=dict(
 
170
  )
171
  return fig
172
 
173
+ @spaces.GPU()
 
 
 
174
  def show_prompt_in_3d(name, sample_idx, prompt_type, prompt_idx):
175
+ if name == None or sample_idx == None or prompt_type == None or prompt_idx == None:
176
+ return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)
177
+
178
  DATASET = name.split('-')[1].replace(" ", "")
179
  TYPE = prompt_type.lower()
180
  theta = 0. if DATASET in "S3DIS ScanNet" else 0.5
181
  mode = "bilinear" if DATASET in "S3DIS ScanNet" else 'nearest'
182
+
 
183
  prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, 0.02, theta, mode, ret_prompt=True)
184
  fig = load_3d_scene(name, sample_idx, TYPE, prompt)
185
+ return fig, gr.Textbox(label="Response", value="Prompt has been shown in 3D Object/Scene!", visible=True)
 
186
 
187
+ @spaces.GPU()
 
 
188
  def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=None, vx=0.02):
189
  if name == None or sample_idx == None or prompt_type == None or prompt_idx == None:
190
  return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)
 
194
  theta = 0. if DATASET in "S3DIS ScanNet" else 0.5
195
  mode = "bilinear" if DATASET in "S3DIS ScanNet" else 'nearest'
196
 
 
197
  new_color, prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, vx, theta, mode, ret_prompt=False)
198
  fig = load_3d_scene(name, sample_idx, TYPE, prompt, final=True, new_color=new_color)
199
  return fig, gr.Textbox(label="Response", value="Segmentation completed successfully!", visible=True)
200
 
201
+ @spaces.GPU()
 
 
202
  def update1(datasets):
203
  if 'Objaverse' in datasets:
204
+ return gr.Radio(label="Select 3D Object", choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True)
205
+ return gr.Radio(label="Select 3D Scene", choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True)
 
206
 
207
+ @spaces.GPU()
208
  def update2(name, sample_idx, prompt_type):
209
  if name == None or sample_idx == None or prompt_type == None:
210
+ return gr.Radio(label="Select Prompt Example", choices=[]), gr.Textbox(label="Response", value="", visible=True)
211
  DATASET = name.split('-')[1].replace(" ", "")
212
  TYPE = prompt_type.lower() + '_prompts'
213
+
214
  if DATASET == 'S3DIS':
215
  info = configs.S3DIS_samples[sample_idx][TYPE]
216
  elif DATASET == 'ScanNet':
 
223
  info = configs.Semantic3D_samples[sample_idx][TYPE]
224
 
225
  cur = ['Example ' + str(i) for i in range(1, len(info) + 1)]
226
+ return gr.Radio(label="Select Prompt Example", choices=cur), gr.Textbox(label="Response", value="", visible=True)
227
+
228
+ @spaces.GPU()
229
  def update3(name, sample_idx, prompt_type, prompt_idx):
230
  if name == None or sample_idx == None or prompt_type == None:
231
  return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)
232
  DATASET = name.split('-')[1].replace(" ", "")
233
  TYPE = configs.VOXEL[prompt_type.lower()]
234
+
235
  if DATASET in "S3DIS ScanNet":
236
  vx_ = 0.02
237
  elif DATASET == 'Objaverse':
 
243
 
244
  return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=vx_)
245
 
246
+ @spaces.GPU()
247
  def main():
248
+ title = """<h1 style="text-align: center;">
 
 
 
 
 
249
  <div style="width: 1.2em; height: 1.2em; display: inline-block;"><img src="https://github.com/ZiyuGuo99/ZiyuGuo99.github.io/blob/main/assets/img/logo.png?raw=true" style='width: 100%; height: 100%; object-fit: contain;' /></div>
250
  <span style="font-variant: small-caps; font-weight: bold;">Sam2Point</span>
251
  </h1>
 
286
  prompt_type_dropdown = gr.Radio(label="Select Prompt Type", choices=prompt_types)
287
  prompt_sample_dropdown = gr.Radio(label="Select Prompt Example", choices=[], type="index")
288
  show_prompt_button = gr.Button("Show Prompt in 3D Scene/Object")
 
289
  with gr.Column():
 
290
  start_segment_button = gr.Button("Start Segmentation")
291
  plot1 = gr.Plot()
292
 
 
 
 
293
  response = gr.Textbox(label="Response")
294
 
295
  sample_dropdown.change(update1, sample_dropdown, [scene_dropdown, response])
296
  sample_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
297
  scene_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
298
  prompt_type_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
299
+
 
 
 
 
 
 
 
300
  show_button.click(load_3d_scene, inputs=[sample_dropdown, scene_dropdown], outputs=plot1)
301
+ show_prompt_button.click(show_prompt_in_3d, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])
 
302
  start_segment_button.click(start_segmentation, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])
303
 
304
  app.queue(status_update_rate="auto")
305
  app.launch(share=True, favicon_path="./logo.png")
306
 
 
307
  if __name__ == "__main__":
308
+ main()