SkalskiP commited on
Commit
aabd771
1 Parent(s): 16d828f

working on video inference

Browse files
Files changed (4) hide show
  1. app.py +133 -17
  2. requirements.txt +1 -0
  3. utils/models.py +1 -1
  4. utils/video.py +14 -0
app.py CHANGED
@@ -1,14 +1,19 @@
 
1
  from typing import Optional
2
 
 
3
  import gradio as gr
4
  import numpy as np
5
  import supervision as sv
6
  import torch
7
  from PIL import Image
 
8
  from gradio_image_prompter import ImagePrompter
9
 
10
  from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
11
- MASK_GENERATION_MODE, BOX_PROMPT_MODE
 
 
12
 
13
  MARKDOWN = """
14
  # Segment Anything Model 2 🔥
@@ -31,6 +36,7 @@ Segment Anything Model 2 (SAM 2) is a foundation model designed to address promp
31
  visual segmentation in both images and videos. **Video segmentation will be available
32
  soon.**
33
  """
 
34
  EXAMPLES = [
35
  ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
36
  ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
@@ -41,8 +47,37 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
42
  IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- def process(
 
 
 
 
 
 
 
 
 
 
 
46
  checkpoint_dropdown,
47
  mode_dropdown,
48
  image_input,
@@ -79,6 +114,64 @@ def process(
79
  return MASK_ANNOTATOR.annotate(image_input, detections)
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  with gr.Blocks() as demo:
83
  gr.Markdown(MARKDOWN)
84
  with gr.Row():
@@ -94,7 +187,8 @@ with gr.Blocks() as demo:
94
  label="Mode",
95
  info="Select a mode to use. `box prompt` if you want to generate masks for "
96
  "selected objects, `mask generation` if you want to generate masks "
97
- "for the whole image.",
 
98
  interactive=True
99
  )
100
  with gr.Row():
@@ -102,14 +196,22 @@ with gr.Blocks() as demo:
102
  image_input_component = gr.Image(
103
  type='pil', label='Upload image', visible=False)
104
  image_prompter_input_component = ImagePrompter(
105
- type='pil', label='Image prompt')
106
- submit_button_component = gr.Button(
 
 
 
 
107
  value='Submit', variant='primary')
 
 
108
  with gr.Column():
109
- image_output_component = gr.Image(type='pil', label='Image Output')
 
 
110
  with gr.Row():
111
  gr.Examples(
112
- fn=process,
113
  examples=EXAMPLES,
114
  inputs=[
115
  checkpoint_dropdown_component,
@@ -121,23 +223,27 @@ with gr.Blocks() as demo:
121
  run_on_click=True
122
  )
123
 
124
-
125
- def on_mode_dropdown_change(text):
126
- return [
127
- gr.Image(visible=text == MASK_GENERATION_MODE),
128
- ImagePrompter(visible=text == BOX_PROMPT_MODE)
129
- ]
130
-
131
  mode_dropdown_component.change(
132
  on_mode_dropdown_change,
133
  inputs=[mode_dropdown_component],
134
  outputs=[
135
  image_input_component,
136
- image_prompter_input_component
 
 
 
 
 
 
137
  ]
138
  )
139
- submit_button_component.click(
140
- fn=process,
 
 
 
 
 
141
  inputs=[
142
  checkpoint_dropdown_component,
143
  mode_dropdown_component,
@@ -146,5 +252,15 @@ with gr.Blocks() as demo:
146
  ],
147
  outputs=[image_output_component]
148
  )
 
 
 
 
 
 
 
 
 
 
149
 
150
  demo.launch(debug=False, show_error=True, max_threads=1)
 
1
+ import os
2
  from typing import Optional
3
 
4
+ import cv2
5
  import gradio as gr
6
  import numpy as np
7
  import supervision as sv
8
  import torch
9
  from PIL import Image
10
+ from tqdm import tqdm
11
  from gradio_image_prompter import ImagePrompter
12
 
13
  from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
14
+ MASK_GENERATION_MODE, BOX_PROMPT_MODE, VIDEO_SEGMENTATION_MODE
15
+ from utils.video import create_directory, generate_unique_name
16
+ from sam2.build_sam import build_sam2_video_predictor
17
 
18
  MARKDOWN = """
19
  # Segment Anything Model 2 🔥
 
36
  visual segmentation in both images and videos. **Video segmentation will be available
37
  soon.**
38
  """
39
+
40
  EXAMPLES = [
41
  ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
42
  ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
 
47
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
48
  IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
49
 
50
+ SCALE_FACTOR = 0.5
51
+ TARGET_DIRECTORY = "tmp"
52
+ # creating video results directory
53
+ create_directory(directory_path=TARGET_DIRECTORY)
54
+
55
+
56
+ def on_mode_dropdown_change(text):
57
+ return [
58
+ gr.Image(visible=text == MASK_GENERATION_MODE),
59
+ ImagePrompter(visible=text == BOX_PROMPT_MODE),
60
+ gr.Video(visible=text == VIDEO_SEGMENTATION_MODE),
61
+ ImagePrompter(visible=text == VIDEO_SEGMENTATION_MODE),
62
+ gr.Button(visible=text != VIDEO_SEGMENTATION_MODE),
63
+ gr.Button(visible=text == VIDEO_SEGMENTATION_MODE),
64
+ gr.Image(visible=text != VIDEO_SEGMENTATION_MODE),
65
+ gr.Video(visible=text == VIDEO_SEGMENTATION_MODE)
66
+ ]
67
+
68
 
69
+ def on_video_input_change(video_input):
70
+ if not video_input:
71
+ return None
72
+ frames_generator = sv.get_video_frames_generator(video_input)
73
+ frame = next(frames_generator)
74
+ frame = sv.scale_image(frame, SCALE_FACTOR)
75
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
76
+ frame = Image.fromarray(frame)
77
+ return {'image': frame, 'points': []}
78
+
79
+
80
+ def process_image(
81
  checkpoint_dropdown,
82
  mode_dropdown,
83
  image_input,
 
114
  return MASK_ANNOTATOR.annotate(image_input, detections)
115
 
116
 
117
+ def process_video(
118
+ checkpoint_dropdown,
119
+ mode_dropdown,
120
+ video_input,
121
+ video_prompter_input,
122
+ progress=gr.Progress(track_tqdm=True)
123
+ ) -> str:
124
+ if mode_dropdown != VIDEO_SEGMENTATION_MODE:
125
+ return str(video_input)
126
+
127
+ name = generate_unique_name()
128
+ frame_directory_path = os.path.join(TARGET_DIRECTORY, name)
129
+ frames_sink = sv.ImageSink(
130
+ target_dir_path=frame_directory_path,
131
+ image_name_pattern="{:05d}.jpeg"
132
+ )
133
+
134
+ video_info = sv.VideoInfo.from_video_path(video_input)
135
+ frames_generator = sv.get_video_frames_generator(video_input)
136
+ with frames_sink:
137
+ for frame in tqdm(
138
+ frames_generator,
139
+ total=video_info.total_frames,
140
+ desc="splitting video into frames"
141
+ ):
142
+ frame = sv.scale_image(frame, SCALE_FACTOR)
143
+ frames_sink.save_image(frame)
144
+
145
+ model = build_sam2_video_predictor(
146
+ "sam2_hiera_t.yaml",
147
+ "checkpoints/sam2_hiera_tiny.pt",
148
+ device=DEVICE
149
+ )
150
+ inference_state = model.init_state(
151
+ video_path=frame_directory_path,
152
+ offload_video_to_cpu=DEVICE == torch.device('cpu'),
153
+ offload_state_to_cpu=DEVICE == torch.device('cpu'),
154
+ )
155
+
156
+ prompt = video_prompter_input["points"]
157
+ points = np.array([[x1, y1] for x1, y1, _, _, _, _ in prompt])
158
+ labels = np.ones(len(points))
159
+
160
+ _, object_ids, mask_logits = model.add_new_points(
161
+ inference_state=inference_state,
162
+ frame_idx=0,
163
+ obj_id=1,
164
+ points=points,
165
+ labels=labels,
166
+ )
167
+
168
+ del inference_state
169
+ del model
170
+
171
+ video_path = os.path.join(TARGET_DIRECTORY, f"{name}.mp4")
172
+ return str(video_input)
173
+
174
+
175
  with gr.Blocks() as demo:
176
  gr.Markdown(MARKDOWN)
177
  with gr.Row():
 
187
  label="Mode",
188
  info="Select a mode to use. `box prompt` if you want to generate masks for "
189
  "selected objects, `mask generation` if you want to generate masks "
190
+ "for the whole image, and `video segmentation` if you want to track "
191
+ "object on video.",
192
  interactive=True
193
  )
194
  with gr.Row():
 
196
  image_input_component = gr.Image(
197
  type='pil', label='Upload image', visible=False)
198
  image_prompter_input_component = ImagePrompter(
199
+ type='pil', label='Prompt image')
200
+ video_input_component = gr.Video(
201
+ label='Step 1: Upload video', visible=False)
202
+ video_prompter_input_component = ImagePrompter(
203
+ type='pil', label='Step 2: Prompt frame', visible=False)
204
+ submit_image_button_component = gr.Button(
205
  value='Submit', variant='primary')
206
+ submit_video_button_component = gr.Button(
207
+ value='Submit', variant='primary', visible=False)
208
  with gr.Column():
209
+ image_output_component = gr.Image(type='pil', label='Image output')
210
+ video_output_component = gr.Video(
211
+ label='Step 2: Video output', visible=False)
212
  with gr.Row():
213
  gr.Examples(
214
+ fn=process_image,
215
  examples=EXAMPLES,
216
  inputs=[
217
  checkpoint_dropdown_component,
 
223
  run_on_click=True
224
  )
225
 
 
 
 
 
 
 
 
226
  mode_dropdown_component.change(
227
  on_mode_dropdown_change,
228
  inputs=[mode_dropdown_component],
229
  outputs=[
230
  image_input_component,
231
+ image_prompter_input_component,
232
+ video_input_component,
233
+ video_prompter_input_component,
234
+ submit_image_button_component,
235
+ submit_video_button_component,
236
+ image_output_component,
237
+ video_output_component
238
  ]
239
  )
240
+ video_input_component.change(
241
+ fn=on_video_input_change,
242
+ inputs=[video_input_component],
243
+ outputs=[video_prompter_input_component]
244
+ )
245
+ submit_image_button_component.click(
246
+ fn=process_image,
247
  inputs=[
248
  checkpoint_dropdown_component,
249
  mode_dropdown_component,
 
252
  ],
253
  outputs=[image_output_component]
254
  )
255
+ submit_video_button_component.click(
256
+ fn=process_video,
257
+ inputs=[
258
+ checkpoint_dropdown_component,
259
+ mode_dropdown_component,
260
+ video_input_component,
261
+ video_prompter_input_component,
262
+ ],
263
+ outputs=[video_output_component]
264
+ )
265
 
266
  demo.launch(debug=False, show_error=True, max_threads=1)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  samv2
2
  gradio
3
  supervision
 
1
+ tqdm
2
  samv2
3
  gradio
4
  supervision
utils/models.py CHANGED
@@ -8,7 +8,7 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
8
  BOX_PROMPT_MODE = "box prompt"
9
  MASK_GENERATION_MODE = "mask generation"
10
  VIDEO_SEGMENTATION_MODE = "video segmentation"
11
- MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]
12
 
13
  CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
14
  CHECKPOINTS = {
 
8
  BOX_PROMPT_MODE = "box prompt"
9
  MASK_GENERATION_MODE = "mask generation"
10
  VIDEO_SEGMENTATION_MODE = "video segmentation"
11
+ MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE, VIDEO_SEGMENTATION_MODE]
12
 
13
  CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
14
  CHECKPOINTS = {
utils/video.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import datetime
4
+
5
+
6
+ def create_directory(directory_path: str) -> None:
7
+ if not os.path.exists(directory_path):
8
+ os.makedirs(directory_path)
9
+
10
+
11
+ def generate_unique_name():
12
+ current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
13
+ unique_id = uuid.uuid4()
14
+ return f"{current_datetime}_{unique_id}"