Spaces:
Sleeping
Sleeping
File size: 2,712 Bytes
aa36c04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
import numpy as np
from torchvision.transforms import ToTensor
# TorchScript checkpoint that load() reads when the requested device type is
# "cuda". The path is relative, so it is resolved against the current working
# directory of the process.
GPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_gpu.jit"
# TorchScript checkpoint that load() reads for every other device type.
CPU_EFFICIENT_SAM_CHECKPOINT = "efficient_sam_s_cpu.jit"
def load(device: torch.device) -> torch.jit.ScriptModule:
    """Load the EfficientSAM TorchScript checkpoint that matches ``device``.

    The GPU checkpoint is used when ``device.type`` is ``"cuda"``; the CPU
    checkpoint is used for every other device type.

    Args:
        device: Device the model will run on.

    Returns:
        The loaded TorchScript module, switched to evaluation mode.
    """
    checkpoint = (
        GPU_EFFICIENT_SAM_CHECKPOINT
        if device.type == "cuda"
        else CPU_EFFICIENT_SAM_CHECKPOINT
    )
    # Without map_location the weights land on whatever device they were
    # saved from, which may differ from the one requested (e.g. "cuda:1").
    # The inference helpers move their inputs to `device`, so the model has
    # to live there as well.
    model = torch.jit.load(checkpoint, map_location=device)
    model.eval()
    return model
def inference_with_box(
    image: np.ndarray,
    box: np.ndarray,
    model: torch.jit.ScriptModule,
    device: torch.device
) -> np.ndarray:
    """Segment ``image`` from a bounding-box prompt and return the best mask.

    The box is passed to the model as two prompt points. Of the candidate
    masks the model returns, the one with the highest predicted IoU is kept.

    Args:
        image: Input image, converted to a tensor with torchvision's
            ``ToTensor`` (assumes an HxWxC array -- TODO confirm with callers).
        box: Exactly four values, reshaped to two 2-D points (presumably
            ``x_min, y_min, x_max, y_max`` -- verify against the caller).
        model: EfficientSAM TorchScript module; it must already live on
            ``device``.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A 2-D boolean mask (sigmoid of the logits thresholded at 0.5), or
        ``None`` if the model produced no candidate masks.
    """
    box_points = torch.reshape(torch.tensor(box), [1, 1, 2, 2])
    # Labels 2 and 3 presumably tag the two points as the box's corners
    # (SAM-style prompt convention) -- confirm against the checkpoint's docs.
    box_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
    img_tensor = ToTensor()(image)

    # Pure inference: disable autograd so no graph is recorded and activations
    # are freed as soon as possible.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            img_tensor[None, ...].to(device),
            box_points.to(device),
            box_labels.to(device),
        )

    # Only the first batch item / first query is used; move just that slice
    # to the CPU.
    mask_logits = predicted_logits[0, 0, :, :, :].detach().cpu()
    all_masks = torch.ge(torch.sigmoid(mask_logits), 0.5).numpy()
    iou_scores = predicted_iou[0, 0, ...].detach().cpu().numpy()

    if all_masks.shape[0] == 0:
        return None
    # np.argmax returns the first maximum, so ties resolve to the lowest
    # candidate index.
    return all_masks[int(np.argmax(iou_scores))]
def inference_with_point(
    image: np.ndarray,
    point: np.ndarray,
    model: torch.jit.ScriptModule,
    device: torch.device
) -> np.ndarray:
    """Segment ``image`` from point prompts and return the best mask.

    Every supplied point is given the label ``1``. Of the candidate masks the
    model returns, the one with the highest predicted IoU is kept.

    Args:
        image: Input image, converted to a tensor with torchvision's
            ``ToTensor`` (assumes an HxWxC array -- TODO confirm with callers).
        point: One or more 2-D points; reshaped to ``[1, 1, num_points, 2]``
            (presumably ``(x, y)`` pixel coordinates -- verify against the
            caller).
        model: EfficientSAM TorchScript module; it must already live on
            ``device``.
        device: Device the input tensors are moved to before the forward pass.

    Returns:
        A 2-D boolean mask (sigmoid of the logits thresholded at 0.5), or
        ``None`` if the model produced no candidate masks.
    """
    prompt_points = torch.reshape(torch.tensor(point), [1, 1, -1, 2])
    num_points = prompt_points.shape[2]
    # Label 1 for every point (presumably "foreground" in the SAM-style
    # prompt convention -- confirm against the checkpoint's docs).
    prompt_labels = torch.ones(1, 1, num_points)
    img_tensor = ToTensor()(image)

    # Pure inference: disable autograd so no graph is recorded and activations
    # are freed as soon as possible.
    with torch.no_grad():
        predicted_logits, predicted_iou = model(
            img_tensor[None, ...].to(device),
            prompt_points.to(device),
            prompt_labels.to(device),
        )

    # Only the first batch item / first query is used; move just that slice
    # to the CPU.
    mask_logits = predicted_logits[0, 0, :, :, :].detach().cpu()
    all_masks = torch.ge(torch.sigmoid(mask_logits), 0.5).numpy()
    iou_scores = predicted_iou[0, 0, ...].detach().cpu().numpy()

    if all_masks.shape[0] == 0:
        return None
    # np.argmax returns the first maximum, so ties resolve to the lowest
    # candidate index.
    return all_masks[int(np.argmax(iou_scores))]
|