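"""Utility functions for video object segmentation and tracking with SegTracker:
mask saving and colorization, mask overlay drawing, and the tracking loops for
video and image-sequence inputs."""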
import os
import cv2
from model_args import segtracker_args, sam_args, aot_args
from PIL import Image
from aot_tracker import _palette
import numpy as np
import torch
import gc
import imageio
from scipy.ndimage import binary_dilation


def save_prediction(pred_mask, output_dir, file_name):
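    """Save a predicted object-ID mask as an indexed-color PNG using the AOT palette."""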
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask.save(os.path.join(output_dir, file_name))


def colorize_mask(pred_mask):
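    """Convert an integer object-ID mask into an RGB image via the AOT palette."""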
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask = save_mask.convert(mode='RGB')
    return np.array(save_mask)


def draw_mask(img, mask, alpha=0.5, id_countour=False):
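    """Overlay a segmentation mask on an image.

    With id_countour=True each object ID is blended with its own palette color
    and outlined individually; otherwise the whole colorized mask is
    alpha-blended at once. Returns the blended image; the input image is left
    untouched.
    """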
    # Work on a copy so the caller's image is not modified in place.
    img_mask = img.copy()
    if id_countour:
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        for obj_id in obj_ids:
            # Object IDs beyond the 256-entry palette fall back to black.
            if obj_id <= 255:
                color = _palette[obj_id * 3:obj_id * 3 + 3]
            else:
                color = [0, 0, 0]
            foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
            binary_mask = (mask == obj_id)
            img_mask[binary_mask] = foreground[binary_mask]

            # Outline the object by zeroing a one-pixel dilated contour.
            countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
            img_mask[countours, :] = 0
    else:
        binary_mask = (mask != 0)
        countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
        foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[countours, :] = 0

    return img_mask.astype(img.dtype)


def create_dir(dir_path):
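    """Create an empty directory, removing any existing one at the same path."""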
    if os.path.isdir(dir_path):
        os.system(f"rm -r {dir_path}")
    os.makedirs(dir_path)


# Checkpoint paths for the supported DeAOT variants, expected under ./ckpt.
aot_model2ckpt = {
    "deaotb": "./ckpt/DeAOTB_PRE_YTB_DAV.pth",
    "deaotl": "./ckpt/DeAOTL_PRE_YTB_DAV.pth",
    "r50_deaotl": "./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth",
}


def tracking_objects_in_video(SegTracker, input_video, input_img_seq, fps):
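    """Dispatch tracking for either a video file or an uploaded image sequence.

    Returns (output_video_path, predicted_mask_zip_path), or (None, None) if
    neither input is provided.
    """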
    if input_video is not None:
        video_name = os.path.basename(input_video).split('.')[0]
    elif input_img_seq is not None:
        # Frames for the uploaded image sequence are read from ./assets/<file_name>.
        file_name = input_img_seq.name.split('/')[-1].split('.')[0]
        file_path = f'./assets/{file_name}'
        imgs_path = sorted([os.path.join(file_path, img_name) for img_name in os.listdir(file_path)])
        video_name = file_name
    else:
        return None, None

    # All outputs for this run are grouped under tracking_results/<video_name>.
    tracking_result_dir = os.path.join(os.path.dirname(__file__), "tracking_results", video_name)
    create_dir(tracking_result_dir)

    io_args = {
        'tracking_result_dir': tracking_result_dir,
        'output_mask_dir': f'{tracking_result_dir}/{video_name}_masks',
        'output_masked_frame_dir': f'{tracking_result_dir}/{video_name}_masked_frames',
        'output_video': f'{tracking_result_dir}/{video_name}_seg.mp4',
        'output_gif': f'{tracking_result_dir}/{video_name}_seg.gif',
    }

    if input_video is not None:
        return video_type_input_tracking(SegTracker, input_video, io_args, video_name)
    elif input_img_seq is not None:
        return img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps)


def video_type_input_tracking(SegTracker, input_video, io_args, video_name):
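    """Track objects through a video file.

    Runs SegTracker frame by frame, saves per-frame masks and overlaid frames,
    and writes an MP4, a GIF, and a zip of the predicted masks.
    Returns (output_video_path, predicted_mask_zip_path).
    """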
    # First pass: run segmentation + tracking over the video and cache per-frame masks.
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)

    output_mask_dir = io_args['output_mask_dir']
    create_dir(io_args['output_mask_dir'])
    create_dir(io_args['output_masked_frame_dir'])

    pred_list = []
    masked_pred_list = []

    torch.cuda.empty_cache()
    gc.collect()
    sam_gap = SegTracker.sam_gap
    frame_idx = 0

    with torch.cuda.amp.autocast():
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if frame_idx == 0:
                # The interactively annotated first-frame mask seeds the tracker.
                pred_mask = SegTracker.first_frame_mask
                torch.cuda.empty_cache()
                gc.collect()
            elif (frame_idx % sam_gap) == 0:
                # Every sam_gap frames, re-segment to pick up newly appearing objects.
                seg_mask = SegTracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                track_mask = SegTracker.track(frame)

                # Objects present in the segmentation but not yet tracked are added as new.
                new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
                save_prediction(new_obj_mask, output_mask_dir, str(frame_idx).zfill(5) + '_new.png')
                pred_mask = track_mask + new_obj_mask

                SegTracker.add_reference(frame, pred_mask)
            else:
                pred_mask = SegTracker.track(frame, update_memory=True)
                torch.cuda.empty_cache()
                gc.collect()

            save_prediction(pred_mask, output_mask_dir, str(frame_idx).zfill(5) + '.png')
            pred_list.append(pred_mask)

            print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
            frame_idx += 1
        cap.release()
        print('\nfinished')

    # Second pass: overlay the cached masks on the original frames and write the video.
    cap = cv2.VideoCapture(input_video)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

    frame_idx = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pred_mask = pred_list[frame_idx]
        masked_frame = draw_mask(frame, pred_mask)
        cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{str(frame_idx).zfill(5)}.png", masked_frame[:, :, ::-1])

        masked_pred_list.append(masked_frame)
        masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
        out.write(masked_frame)
        print('frame {} written'.format(frame_idx), end='\r')
        frame_idx += 1
    out.release()
    cap.release()
    print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

    # Also export an animated GIF of the overlaid frames.
    imageio.mimsave(io_args['output_gif'], masked_pred_list, fps=fps)
    print("{} saved".format(io_args['output_gif']))

    # Zip the raw per-frame masks for download.
    os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")

    # Release GPU memory held by the tracker.
    del SegTracker
    torch.cuda.empty_cache()
    gc.collect()

    return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"


def img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps):
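    """Track objects through an image sequence (same outputs as the video path).

    Returns (output_video_path, predicted_mask_zip_path).
    """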
    # First pass: run segmentation + tracking over the image sequence and cache per-frame masks.
    output_mask_dir = io_args['output_mask_dir']
    create_dir(io_args['output_mask_dir'])
    create_dir(io_args['output_masked_frame_dir'])

    pred_list = []
    masked_pred_list = []

    torch.cuda.empty_cache()
    gc.collect()
    sam_gap = SegTracker.sam_gap
    frame_idx = 0

    with torch.cuda.amp.autocast():
        for img_path in imgs_path:
            frame_name = os.path.basename(img_path).split('.')[0]
            frame = cv2.imread(img_path)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if frame_idx == 0:
                # The interactively annotated first-frame mask seeds the tracker.
                pred_mask = SegTracker.first_frame_mask
                torch.cuda.empty_cache()
                gc.collect()
            elif (frame_idx % sam_gap) == 0:
                # Every sam_gap frames, re-segment to pick up newly appearing objects.
                seg_mask = SegTracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                track_mask = SegTracker.track(frame)

                # Objects present in the segmentation but not yet tracked are added as new.
                new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
                save_prediction(new_obj_mask, output_mask_dir, f'{frame_name}_new.png')
                pred_mask = track_mask + new_obj_mask

                SegTracker.add_reference(frame, pred_mask)
            else:
                pred_mask = SegTracker.track(frame, update_memory=True)
                torch.cuda.empty_cache()
                gc.collect()

            save_prediction(pred_mask, output_mask_dir, f'{frame_name}.png')
            pred_list.append(pred_mask)

            print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
            frame_idx += 1
        print('\nfinished')

    # Second pass: overlay the cached masks on the original frames and write the video.
    height, width = pred_list[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

    frame_idx = 0
    for img_path in imgs_path:
        frame_name = os.path.basename(img_path).split('.')[0]
        frame = cv2.imread(img_path)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        pred_mask = pred_list[frame_idx]
        masked_frame = draw_mask(frame, pred_mask)
        masked_pred_list.append(masked_frame)
        cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_name}.png", masked_frame[:, :, ::-1])

        masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
        out.write(masked_frame)
        print('frame {} written'.format(frame_name), end='\r')
        frame_idx += 1
    out.release()
    print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

    # Also export an animated GIF of the overlaid frames.
    imageio.mimsave(io_args['output_gif'], masked_pred_list, fps=fps)
    print("{} saved".format(io_args['output_gif']))

    # Zip the raw per-frame masks for download.
    os.system(f"zip -r {io_args['tracking_result_dir']}/{video_name}_pred_mask.zip {io_args['output_mask_dir']}")

    # Release GPU memory held by the tracker.
    del SegTracker
    torch.cuda.empty_cache()
    gc.collect()

    return io_args['output_video'], f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"