Spaces:
Runtime error
Runtime error
import torch | |
import cv2 | |
import numpy as np | |
from sam.segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator | |
class Segmentor: | |
def __init__(self, sam_args): | |
""" | |
sam_args: | |
sam_checkpoint: path of SAM checkpoint | |
generator_args: args for everything_generator | |
gpu_id: device | |
""" | |
self.device = sam_args["gpu_id"] | |
self.sam = sam_model_registry[sam_args["model_type"]](checkpoint=sam_args["sam_checkpoint"]) | |
self.sam.to(device=self.device) | |
self.everything_generator = SamAutomaticMaskGenerator(model=self.sam, **sam_args['generator_args']) | |
self.interactive_predictor = self.everything_generator.predictor | |
self.have_embedded = False | |
def set_image(self, image): | |
# calculate the embedding only once per frame. | |
if not self.have_embedded: | |
self.interactive_predictor.set_image(image) | |
self.have_embedded = True | |
def interactive_predict(self, prompts, mode, multimask=True): | |
assert self.have_embedded, 'image embedding for sam need be set before predict.' | |
if mode == 'point': | |
masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'], | |
point_labels=prompts['point_modes'], | |
multimask_output=multimask) | |
elif mode == 'mask': | |
masks, scores, logits = self.interactive_predictor.predict(mask_input=prompts['mask_prompt'], | |
multimask_output=multimask) | |
elif mode == 'point_mask': | |
masks, scores, logits = self.interactive_predictor.predict(point_coords=prompts['point_coords'], | |
point_labels=prompts['point_modes'], | |
mask_input=prompts['mask_prompt'], | |
multimask_output=multimask) | |
return masks, scores, logits | |
def segment_with_click(self, origin_frame, coords, modes, multimask=True): | |
''' | |
return: | |
mask: one-hot | |
''' | |
self.set_image(origin_frame) | |
prompts = { | |
'point_coords': coords, | |
'point_modes': modes, | |
} | |
masks, scores, logits = self.interactive_predict(prompts, 'point', multimask) | |
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] | |
prompts = { | |
'point_coords': coords, | |
'point_modes': modes, | |
'mask_prompt': logit[None, :, :] | |
} | |
masks, scores, logits = self.interactive_predict(prompts, 'point_mask', multimask) | |
mask = masks[np.argmax(scores)] | |
return mask.astype(np.uint8) | |
def segment_with_box(self, origin_frame, bbox, reset_image=False): | |
if reset_image: | |
self.interactive_predictor.set_image(origin_frame) | |
else: | |
self.set_image(origin_frame) | |
# coord = np.array([[int((bbox[1][0] - bbox[0][0]) / 2.), int((bbox[1][1] - bbox[0][1]) / 2)]]) | |
# point_label = np.array([1]) | |
masks, scores, logits = self.interactive_predictor.predict( | |
point_coords=None, | |
point_labels=None, | |
box=np.array([bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]), | |
multimask_output=True | |
) | |
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] | |
masks, scores, logits = self.interactive_predictor.predict( | |
point_coords=None, | |
point_labels=None, | |
box=np.array([[bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]]]), | |
mask_input=logit[None, :, :], | |
multimask_output=True | |
) | |
mask = masks[np.argmax(scores)] | |
return [mask] | |