import gradio as gr import numpy as np import cv2 import torch import os import logging import contextlib from sam2.build_sam import build_sam2_video_predictor # Add current directory to path import sys sys.path.append(os.getcwd()) sys.path.append(os.path.join(os.getcwd(), "sam2")) # Add sam2 directory to path print(f"current dir is {os.getcwd()}") # Ensure device setup matches the official code force_cpu_device = os.environ.get("SAM2_DEMO_FORCE_CPU_DEVICE", "0") == "1" if force_cpu_device: logging.info("forcing CPU device for SAM 2 demo") if torch.cuda.is_available() and not force_cpu_device: DEVICE = torch.device("cuda") elif torch.backends.mps.is_available() and not force_cpu_device: DEVICE = torch.device("mps") else: DEVICE = torch.device("cpu") logging.info(f"using device: {DEVICE}") if DEVICE.type == "cuda": if torch.cuda.get_device_properties(0).major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True elif DEVICE.type == "mps": logging.warning( "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might " "give numerically different outputs and sometimes degraded performance on MPS. " "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion." ) def load_model_paths(checkpoint_name): """Get model checkpoint and config paths""" if checkpoint_name == "SAM2-T": sam2_checkpoint = "models/sam2.1_hiera_tiny.pt" model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml" elif checkpoint_name == "SAM2-S": sam2_checkpoint = "models/sam2.1_hiera_small.pt" model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml" elif checkpoint_name == "SAM2-B_PLUS": sam2_checkpoint = "models/sam2.1_hiera_base_plus.pt" model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml" else: raise ValueError(f"Invalid checkpoint name: {checkpoint_name}") return sam2_checkpoint, model_cfg # Available checkpoints CHECKPOINTS = { "SAM2-B_PLUS": "Base Plus Model", "SAM2-S": "Small Model", "SAM2-T": "Tiny Model", } class GolfTracker: def __init__(self, checkpoint="SAM2-T"): """Initialize with specified checkpoint model""" self.current_checkpoint = checkpoint self.predictor = None self.points = [] self.frames = [] self.current_frame_idx = 0 self.video_info = None self.state = None self.obj_id = 1 # Track single object (golf ball) self.device = DEVICE self.out_mask_logits = None self.load_model(checkpoint) def load_model(self, checkpoint_name): """Load specified checkpoint model""" if checkpoint_name not in CHECKPOINTS: raise ValueError(f"Invalid checkpoint: {checkpoint_name}") print(f"Loading checkpoint: {checkpoint_name}") sam2_checkpoint, model_cfg = load_model_paths(checkpoint_name) # Build predictor with model config and checkpoint self.predictor = build_sam2_video_predictor( model_cfg, sam2_checkpoint, self.device ) print(f"Model loaded successfully: {CHECKPOINTS[checkpoint_name]}") self.current_checkpoint = checkpoint_name def process_video(self, video_path): """Process the video and initialize tracking""" if not os.path.exists(video_path): return None, None, None, "Video file not found" # Reset state self.points = [] self.frames = [] self.current_frame_idx = 0 self.state = None # Read video frames cap = cv2.VideoCapture(video_path) while True: ret, frame = cap.read() if not ret: break self.frames.append(frame) if not self.frames: return None, None, None, "Failed to read video" # Store video info self.video_info = { "path": video_path, "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), "fps": cap.get(cv2.CAP_PROP_FPS), "total_frames": len(self.frames), } cap.release() # Initialize SAM2 state with self.autocast_context(), torch.inference_mode(): self.state = self.predictor.init_state(video_path) return ( self.frames[0], # First frame self.current_checkpoint, gr.Slider(minimum=0, maximum=len(self.frames) - 1, step=1, value=0), "Navigate through frames and click on the golf ball to track", ) def update_frame(self, frame_idx): """Update displayed frame""" if not self.frames or frame_idx >= len(self.frames): return None self.current_frame_idx = int(frame_idx) frame = self.frames[self.current_frame_idx].copy() # Draw existing points and trajectory self._draw_tracking(frame) return frame def add_point(self, frame, evt: gr.SelectData): """Add a point and get ball prediction with enhanced mask visualization""" if self.state is None: return frame x, y = evt.index[0], evt.index[1] self.points.append((self.current_frame_idx, x, y)) frame_with_points = frame.copy() # Get ball prediction using SAM2.1 with self.autocast_context(), torch.inference_mode(): # Convert points and labels to numpy arrays points = np.array([(x, y)], dtype=np.float32) labels = np.array([1], dtype=np.int32) # 1 for positive click # Add point and get mask _, out_obj_ids, out_mask_logits = self.predictor.add_new_points( inference_state=self.state, frame_idx=self.current_frame_idx, obj_id=self.obj_id, points=points, labels=labels, ) if out_mask_logits is not None and len(out_mask_logits) > 0: self.out_mask_logits = out_mask_logits # Draw tracking visualization self._draw_tracking(frame_with_points) return frame_with_points def propagate_masks(self): """Propagate masks to the entire video after user selection""" if self.state is None: return "No state initialized" logging.info(f"Propagating masks in video with state: {self.state}") # Propagate the masks across the video with self.autocast_context(), torch.inference_mode(): frame_idx, obj_ids, video_res_masks = self.predictor.propagate_in_video( inference_state=self.state, start_frame_idx=0, reverse=False, ) self.out_mask_logits = video_res_masks return "Propagation complete" def autocast_context(self): if self.device.type == "cuda": return torch.autocast("cuda", dtype=torch.bfloat16) else: return contextlib.nullcontext() def _draw_tracking(self, frame): """Draw object mask on frame with enhanced visualization""" # Assuming out_mask_logits is available from propagate_masks if self.current_frame_idx < len(self.frames): mask_np = (self.out_mask_logits[self.current_frame_idx] > 0.0).cpu().numpy() if mask_np.shape[:2] == frame.shape[:2]: overlay = frame.copy() overlay[mask_np > 0] = [0, 0, 255] # Red color for mask alpha = 0.5 # Transparency factor frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0) return frame def clear_points(self): """Clear all tracked points""" self.points = [] if self.frames: return self.frames[self.current_frame_idx].copy() return None def change_model(self, checkpoint_name): """Change the current model checkpoint""" if checkpoint_name != self.current_checkpoint: self.load_model(checkpoint_name) return f"Loaded {CHECKPOINTS[checkpoint_name]}" def save_output_video(self): """Save the processed video with tracking visualization""" if not self.frames or not self.video_info: return None, "No video loaded" output_path = "output_tracked.mp4" # Initialize video writer fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter( output_path, fourcc, self.video_info["fps"], (self.video_info["width"], self.video_info["height"]), ) # Process each frame for frame_idx in range(len(self.frames)): frame = self.frames[frame_idx].copy() # Draw tracking for this frame frame_points = [(x, y) for f, x, y in self.points if f == frame_idx] if frame_points: # Draw points for x, y in frame_points: cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1) # Fit and draw trajectory if enough points if len(frame_points) >= 3: points_arr = np.array(frame_points) # fit_results = self.trajectory_fitter.fit_trajectory(points_arr) # if fit_results is not None: # trajectory = fit_results["trajectory"] # points = trajectory.astype(np.int32) # for i in range(len(points) - 1): # cv2.line( # frame, # tuple(points[i]), # tuple(points[i + 1]), # (0, 255, 0), # 2, # ) # # Calculate and display metrics # metrics = self.trajectory_fitter.calculate_metrics(fit_results) # cv2.putText( # frame, # f"Speed: {metrics['initial_velocity_mph']:.1f} mph", # (10, 30), # cv2.FONT_HERSHEY_SIMPLEX, # 1, # (255, 255, 255), # 2, # ) # cv2.putText( # frame, # f"Height: {metrics['max_height']:.1f} m", # (10, 70), # cv2.FONT_HERSHEY_SIMPLEX, # 1, # (255, 255, 255), # 2, # ) out.write(frame) out.release() return output_path, "Video saved successfully!" def create_ui(): tracker = GolfTracker() with gr.Blocks() as app: gr.Markdown("# Golf Ball Trajectory Tracker") gr.Markdown( "Upload a video and click on the golf ball positions to track its trajectory" ) with gr.Row(): with gr.Column(): video_input = gr.Video(label="Input Video") model_dropdown = gr.Dropdown( choices=list(CHECKPOINTS.keys()), value="SAM2-T", label="Select Model", ) upload_button = gr.Button("Process Video") clear_button = gr.Button("Clear Points") save_button = gr.Button("Save Output Video") propagate_button = gr.Button("Propagate Masks") with gr.Column(): image_output = gr.Image(label="Click on golf ball positions") frame_slider = gr.Slider( minimum=0, maximum=0, step=1, value=0, label="Frame", interactive=True, ) current_model = gr.Textbox(label="Current Model", interactive=False) status_text = gr.Textbox(label="Status", interactive=False) output_video = gr.Video(label="Output Video") # Event handlers model_dropdown.change( fn=tracker.change_model, inputs=[model_dropdown], outputs=[status_text] ) video_input.change( fn=tracker.process_video, inputs=[video_input], outputs=[image_output, current_model, frame_slider, status_text], ) upload_button.click( fn=tracker.process_video, inputs=[video_input], outputs=[image_output, current_model, frame_slider, status_text], ) clear_button.click(fn=tracker.clear_points, inputs=[], outputs=[image_output]) frame_slider.change( fn=tracker.update_frame, inputs=[frame_slider], outputs=[image_output] ) image_output.select( fn=tracker.add_point, inputs=[image_output], outputs=[image_output] ) save_button.click( fn=tracker.save_output_video, inputs=[], outputs=[output_video, status_text] ) propagate_button.click( fn=tracker.propagate_masks, inputs=[], outputs=[status_text] ) return app if __name__ == "__main__": app = create_ui() app.launch()