|
import sys |
|
sys.path.append("..") |
|
sys.path.append("./sam") |
|
from sam.segment_anything import sam_model_registry, SamAutomaticMaskGenerator |
|
from aot_tracker import get_aot |
|
import numpy as np |
|
from tool.segmentor import Segmentor |
|
from tool.detector import Detector |
|
from tool.transfer_tools import draw_outline, draw_points |
|
import cv2 |
|
from seg_track_anything import draw_mask |
|
|
|
|
|
class SegTracker():
    """
    Couples a SAM-based segmentor, a text-prompted detector and an AOT tracker.

    The instance keeps a merged label mask (``origin_merged_mask``, one integer
    id per object, 0 for background) together with the id that the next
    interactively added object will receive (``curr_idx``).
    """

    def __init__(self, segtracker_args, sam_args, aot_args) -> None:
        """
        Initialize SAM and AOT.
        Arguments:
            segtracker_args: dict with the keys 'sam_gap', 'min_area',
                'max_obj_num' and 'min_new_obj_iou'
            sam_args: forwarded unchanged to Segmentor
            aot_args: forwarded unchanged to get_aot
        """
        self.sam = Segmentor(sam_args)
        self.tracker = get_aot(aot_args)
        self.detector = Detector(self.sam.device)
        self.sam_gap = segtracker_args['sam_gap']
        self.min_area = segtracker_args['min_area']
        self.max_obj_num = segtracker_args['max_obj_num']
        self.min_new_obj_iou = segtracker_args['min_new_obj_iou']
        # One entry per reference frame: the unique ids found in its mask.
        self.reference_objs_list = []
        self.object_idx = 1
        # Id given to the next object added through add_mask().
        self.curr_idx = 1
        self.origin_merged_mask = None
        self.first_frame_mask = None

        # Prompt point / label of every proposal kept by seg().
        self.everything_points = []
        self.everything_labels = []
        print("SegTracker has been initialized")

    def seg(self, frame):
        """
        Segment everything in a frame and merge the proposals into one mask.
        Arguments:
            frame: numpy array (h,w,3)
        Return:
            origin_merged_mask: numpy array (h,w), dtype uint8, or None when
                the generator returns no proposal (state is then untouched)
        """
        # The generator is fed the frame with its channel order reversed.
        frame = frame[:, :, ::-1]
        anns = self.sam.everything_generator.generate(frame)

        if len(anns) == 0:
            return

        # Merge in a wide integer dtype: the number of raw proposals is not
        # bounded by max_obj_num, and writing an id above 255 into a uint8
        # array either wraps around (merging unrelated objects) or raises,
        # depending on the NumPy version.
        # Later proposals overwrite earlier ones where they overlap.
        merged = np.zeros(anns[0]['segmentation'].shape, dtype=np.int32)
        idx = 1
        for ann in anns:
            if ann['area'] > self.min_area:
                m = ann['segmentation']
                merged[m == 1] = idx
                idx += 1
                self.everything_points.append(ann["point_coords"][0])
                self.everything_labels.append(1)

        obj_ids = np.unique(merged)
        obj_ids = obj_ids[obj_ids != 0]

        # Re-number the surviving objects 1..k. Overlaps may have shrunk an
        # object below min_area, so the area is checked again here; objects
        # beyond max_obj_num are dropped to background.
        self.origin_merged_mask = np.zeros(merged.shape, dtype=np.uint8)
        self.object_idx = 1
        for obj_id in obj_ids:
            region = merged == obj_id
            if np.sum(region) < self.min_area or self.object_idx > self.max_obj_num:
                continue
            self.origin_merged_mask[region] = self.object_idx
            self.object_idx += 1

        self.first_frame_mask = self.origin_merged_mask
        return self.origin_merged_mask

    def update_origin_merged_mask(self, updated_merged_mask):
        """
        Replace the stored merged mask.
        Arguments:
            updated_merged_mask: numpy array (h,w)
        """
        self.origin_merged_mask = updated_merged_mask

    def reset_origin_merged_mask(self, mask, id):
        """
        Restore the stored merged mask together with the next object id.
        Arguments:
            mask: numpy array (h,w) or None
            id: value assigned to self.curr_idx
        """
        self.origin_merged_mask = mask
        self.curr_idx = id

    def add_reference(self, frame, mask, frame_step=0):
        """
        Add objects in a mask for tracking.
        Arguments:
            frame: numpy array (h,w,3)
            mask: numpy array (h,w)
            frame_step: forwarded to the tracker
        """
        self.reference_objs_list.append(np.unique(mask))
        # The next free id is one past the largest id seen so far; the
        # tracker receives that largest id (curr_idx - 1).
        self.curr_idx = self.get_obj_num() + 1
        self.tracker.add_reference_frame(frame, mask, self.curr_idx - 1, frame_step)

    def track(self, frame, update_memory=False):
        """
        Track all known objects.
        Arguments:
            frame: numpy array (h,w,3)
            update_memory: when True, feed the prediction back to the tracker
        Return:
            origin_merged_mask: numpy array (h,w), dtype uint8
        """
        pred_mask = self.tracker.track(frame)
        if update_memory:
            self.tracker.update_memory(pred_mask)
        # NOTE(review): presumably a (1,1,h,w) torch tensor - the two leading
        # axes are squeezed away before conversion; confirm against the tracker.
        return pred_mask.squeeze(0).squeeze(0).detach().cpu().numpy().astype(np.uint8)

    def get_tracking_objs(self):
        """
        Return:
            sorted list of every non-zero id found in the reference masks
        """
        objs = set()
        for ref in self.reference_objs_list:
            objs.update(set(ref))
        return sorted(obj for obj in objs if obj != 0)

    def get_obj_num(self):
        """
        Return:
            largest id among the reference masks, 0 when there is none
        """
        objs = self.get_tracking_objs()
        if len(objs) == 0:
            return 0
        return int(max(objs))

    def find_new_objs(self, track_mask, seg_mask):
        """
        Compare tracked results from AOT with segmented results from SAM. Select objects from background if they are not tracked.
        Arguments:
            track_mask: numpy array (h,w)
            seg_mask: numpy array (h,w)
        Return:
            new_obj_mask: numpy array (h,w), accepted objects are numbered
                consecutively starting at self.curr_idx
        """
        # Keep only the part of every SAM object lying on untracked pixels.
        candidates = (track_mask == 0) * seg_mask
        new_obj_ids = np.unique(candidates)
        new_obj_ids = new_obj_ids[new_obj_ids != 0]

        # Labels are written to a separate array. Relabelling in place would
        # let a freshly assigned id collide with a SAM id that has not been
        # processed yet, merging two objects and corrupting the area ratio.
        new_obj_mask = np.zeros_like(candidates)
        obj_num = self.curr_idx
        for idx in new_obj_ids:
            region = candidates == idx
            new_obj_area = np.sum(region)
            obj_area = np.sum(seg_mask == idx)
            rejected = (
                new_obj_area / obj_area < self.min_new_obj_iou
                or new_obj_area < self.min_area
                or obj_num > self.max_obj_num
            )
            if rejected:
                continue
            new_obj_mask[region] = obj_num
            obj_num += 1
        return new_obj_mask

    def restart_tracker(self):
        """
        Restart the underlying tracker.
        """
        self.tracker.restart()

    def seg_acc_bbox(self, origin_frame: np.ndarray, bbox: np.ndarray,):
        """
        Use bbox-prompt to get mask
        Parameters:
            origin_frame: H, W, C
            bbox: [[x0, y0], [x1, y1]]
        Return:
            refined_merged_mask: numpy array (h, w)
            masked_frame: numpy array (h, w, c)
        """
        interactive_mask = self.sam.segment_with_box(origin_frame, bbox)[0]
        refined_merged_mask = self.add_mask(interactive_mask)

        # Draw on a copy of the frame, then outline the prompt box.
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
        masked_frame = cv2.rectangle(masked_frame, bbox[0], bbox[1], (0, 0, 255))

        return refined_merged_mask, masked_frame

    def seg_acc_click(self, origin_frame: np.ndarray, coords: np.ndarray, modes: np.ndarray, multimask=True):
        """
        Use point-prompt to get mask
        Parameters:
            origin_frame: H, W, C
            coords: nd.array [[x, y]]
            modes: nd.array [[1]]
            multimask: forwarded to the segmentor
        Return:
            refined_merged_mask: numpy array (h, w)
            masked_frame: numpy array (h, w, c)
        """
        interactive_mask = self.sam.segment_with_click(origin_frame, coords, modes, multimask)
        refined_merged_mask = self.add_mask(interactive_mask)

        # Draw on a copy of the frame, then add the clicked points and the
        # outline of the newly segmented object.
        masked_frame = draw_mask(origin_frame.copy(), refined_merged_mask)
        masked_frame = draw_points(coords, modes, masked_frame)
        masked_frame = draw_outline(interactive_mask, masked_frame)

        return refined_merged_mask, masked_frame

    def add_mask(self, interactive_mask: np.ndarray):
        """
        Merge interactive mask with self.origin_merged_mask
        Parameters:
            interactive_mask: numpy array (h, w)
        Return:
            refined_merged_mask: numpy array (h, w), a copy of the stored mask
                whose pixels under interactive_mask carry self.curr_idx
        """
        if self.origin_merged_mask is None:
            self.origin_merged_mask = np.zeros(interactive_mask.shape, dtype=np.uint8)

        refined_merged_mask = self.origin_merged_mask.copy()
        refined_merged_mask[interactive_mask > 0] = self.curr_idx

        return refined_merged_mask

    def detect_and_seg(self, origin_frame: np.ndarray, grounding_caption, box_threshold, text_threshold, box_size_threshold=1, reset_image=False):
        """
        Using Grounding-DINO to detect object acc Text-prompts
        Parameters:
            origin_frame: numpy array (h, w, 3)
            grounding_caption: text prompt forwarded to the detector
            box_threshold: forwarded to the detector
            text_threshold: forwarded to the detector
            box_size_threshold: boxes covering more than this fraction of the
                frame area are skipped
            reset_image: forwarded to the segmentor
        Return:
            refined_merged_mask: numpy array (h, w); when no box is kept this
                is a copy of the stored mask (all zeros if none is stored)
            annotated_frame: numpy array (h, w, 3)
        """
        # Back up the state so the detections only affect the returned mask.
        bc_id = self.curr_idx
        bc_mask = self.origin_merged_mask

        annotated_frame, boxes = self.detector.run_grounding(origin_frame, grounding_caption, box_threshold, text_threshold)
        max_box_area = annotated_frame.shape[0] * annotated_frame.shape[1] * box_size_threshold

        # Defined up front so that "nothing detected" returns a valid mask
        # instead of failing on an unbound local.
        if bc_mask is None:
            refined_merged_mask = np.zeros(origin_frame.shape[:2], dtype=np.uint8)
        else:
            refined_merged_mask = bc_mask.copy()

        try:
            for bbox in boxes:
                box_area = (bbox[1][0] - bbox[0][0]) * (bbox[1][1] - bbox[0][1])
                if box_area > max_box_area:
                    continue
                interactive_mask = self.sam.segment_with_box(origin_frame, bbox, reset_image)[0]
                refined_merged_mask = self.add_mask(interactive_mask)
                # Accumulate: every kept box becomes its own object id.
                self.update_origin_merged_mask(refined_merged_mask)
                self.curr_idx += 1
        finally:
            # Restore the backup even if segmentation fails midway.
            self.reset_origin_merged_mask(bc_mask, bc_id)

        return refined_merged_mask, annotated_frame
|
|
|
if __name__ == '__main__':
    # Debug entry point: run text-prompted detection + segmentation on one
    # image and dump the visualisations to ./debug.
    from model_args import segtracker_args, sam_args, aot_args

    Seg_Tracker = SegTracker(segtracker_args, sam_args, aot_args)

    image_path = '/data2/cym/Seg_Tra_any/Segment-and-Track-Anything/debug/point.png'
    origin_frame = cv2.imread(image_path)
    # cv2.imread signals an unreadable file by returning None rather than
    # raising; fail here with a clear message instead of inside cvtColor.
    if origin_frame is None:
        raise FileNotFoundError("could not read image: " + image_path)
    origin_frame = cv2.cvtColor(origin_frame, cv2.COLOR_BGR2RGB)
    grounding_caption = "swan.water"
    box_threshold = 0.25
    text_threshold = 0.25

    predicted_mask, annotated_frame = Seg_Tracker.detect_and_seg(origin_frame, grounding_caption, box_threshold, text_threshold)
    masked_frame = draw_mask(annotated_frame, predicted_mask)

    cv2.imwrite('./debug/masked_frame.png', masked_frame)
    cv2.imwrite('./debug/x.png', annotated_frame)