# Spaces: Sleeping
import torch | |
import numpy as np | |
from torchvision.transforms import ToTensor | |
# TorchScript checkpoint file loaded by load() when the device type is "cuda".
GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit"
# TorchScript checkpoint file loaded by load() for every other device type.
CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit"
def load(device: torch.device) -> torch.jit.ScriptModule:
    """Load the TorchScript checkpoint that matches ``device``.

    Args:
        device: Target device. A device whose ``type`` is ``"cuda"`` selects
            ``GPU_EFFICIENT_SAM_CHECKPOINT``; any other type selects
            ``CPU_EFFICIENT_SAM_CHECKPOINT``.

    Returns:
        The loaded module, placed on ``device`` and switched to eval mode.
    """
    checkpoint = (
        GPU_EFFICIENT_SAM_CHECKPOINT
        if device.type == "cuda"
        else CPU_EFFICIENT_SAM_CHECKPOINT
    )
    # Without map_location the weights are restored onto the device they were
    # saved from, which may differ from the device the inputs are moved to
    # (for example a non-default GPU index).
    model = torch.jit.load(checkpoint, map_location=device)
    model.eval()
    return model
def inference_with_box(
    image: np.ndarray,
    box: np.ndarray,
    model: torch.jit.ScriptModule,
    device: torch.device
) -> np.ndarray:
    """Run the model on a box prompt and return the best-scoring mask.

    Args:
        image: Input image; converted with ``ToTensor`` and given a leading
            batch axis before being passed to the model.
        box: Box prompt holding four values; it is reshaped to
            ``[1, 1, 2, 2]`` (presumably two (x, y) corner points -- confirm
            against the caller).
        model: Callable returning ``(predicted_logits, predicted_iou)``.
        device: Device the model inputs are moved to.

    Returns:
        The boolean mask (``sigmoid(logit) >= 0.5``) of the candidate with the
        largest predicted IoU, or ``None`` if the model yields no candidates.
    """
    bbox = torch.reshape(torch.tensor(box), [1, 1, 2, 2])
    # One label per box point: 2 and 3 (presumably the model's markers for
    # the two box corners -- confirm against the model's documentation).
    bbox_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
    img_tensor = ToTensor()(image)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            img_tensor[None, ...].to(device),
            bbox.to(device),
            bbox_labels.to(device),
        )

    predicted_logits = predicted_logits.cpu()
    # Drop the batch and query axes, then threshold the probabilities.
    all_masks = torch.ge(
        torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5
    ).numpy()
    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()

    num_masks = all_masks.shape[0]
    if num_masks == 0:
        return None
    # np.argmax returns the first index of the maximum, so ties resolve to
    # the earliest candidate.
    best_index = int(np.argmax(predicted_iou[:num_masks]))
    return all_masks[best_index]
def inference_with_point(
    image: np.ndarray,
    point: np.ndarray,
    model: torch.jit.ScriptModule,
    device: torch.device
) -> np.ndarray:
    """Run the model on point prompts and return the best-scoring mask.

    Args:
        image: Input image; converted with ``ToTensor`` and given a leading
            batch axis before being passed to the model.
        point: One or more 2-value points; reshaped to ``[1, 1, -1, 2]``
            (presumably (x, y) pixel coordinates -- confirm against the
            caller).
        model: Callable returning ``(predicted_logits, predicted_iou)``.
        device: Device the model inputs are moved to.

    Returns:
        The boolean mask (``sigmoid(logit) >= 0.5``) of the candidate with the
        largest predicted IoU, or ``None`` if the model yields no candidates.
    """
    pts_sampled = torch.reshape(torch.tensor(point), [1, 1, -1, 2])
    num_points = pts_sampled.shape[2]
    # Every point gets label 1 (presumably "foreground" in the model's
    # convention -- confirm against the model's documentation).
    pts_labels = torch.ones(1, 1, num_points)
    img_tensor = ToTensor()(image)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            img_tensor[None, ...].to(device),
            pts_sampled.to(device),
            pts_labels.to(device),
        )

    predicted_logits = predicted_logits.cpu()
    # Drop the batch and query axes, then threshold the probabilities.
    all_masks = torch.ge(
        torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5
    ).numpy()
    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()

    num_masks = all_masks.shape[0]
    if num_masks == 0:
        return None
    # np.argmax returns the first index of the maximum, so ties resolve to
    # the earliest candidate.
    best_index = int(np.argmax(predicted_iou[:num_masks]))
    return all_masks[best_index]