---
license: mit
---

# Segment Anything 8-Bit ONNX

How to run:

```python
import onnxruntime as ort
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Side length of the square input fed to the encoder.
TARGET_SIZE = 1024

# Path to the image file
image_path = "example.png"

# Load the image
image = Image.open(image_path).convert("RGB")
orig_width, orig_height = image.size

# Resize so that the longest side is TARGET_SIZE while keeping the aspect
# ratio. This mirrors the reference Segment Anything preprocessing
# (ResizeLongestSide); it also guarantees the padding below is never negative,
# so images larger than 1024 px work too.
# NOTE: assumes these ONNX files were exported from the reference SAM code.
scale = TARGET_SIZE / max(orig_width, orig_height)
new_width = max(1, int(orig_width * scale + 0.5))
new_height = max(1, int(orig_height * scale + 0.5))
resized = image.resize((new_width, new_height), Image.BILINEAR)

# Normalize with the per-channel mean/std and convert HWC -> NCHW
input_tensor = np.array(resized)
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])
input_tensor = (input_tensor - mean) / std
input_tensor = input_tensor.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)

# Zero-pad the bottom/right edges up to TARGET_SIZE x TARGET_SIZE
pad_height = TARGET_SIZE - new_height
pad_width = TARGET_SIZE - new_width
input_tensor = np.pad(
    input_tensor, ((0, 0), (0, 0), (0, pad_height), (0, pad_width))
)

# Load the encoder model and run inference
encoder = ort.InferenceSession("sam_encoder.onnx")
embeddings = encoder.run(None, {"images": input_tensor})[0]

# Choose a point (e.g., x=150, y=100) in the original image
point = [150, 100]

# Build the prompt: the chosen point (label 1 = foreground) followed by a
# padding point at (0, 0) with label -1, so that point_coords and point_labels
# have the same number of entries. The padding point is what the reference
# ONNX example passes when no box prompt is given.
coords = np.array([[point, [0, 0]]], dtype=float)  # shape (1, 2, 2)

# Map the coordinates from original-image pixels to resized-image pixels,
# using the same factors that were applied to the image itself.
coords[..., 0] = coords[..., 0] * (new_width / orig_width)
coords[..., 1] = coords[..., 1] * (new_height / orig_height)
onnx_coord = coords.astype("float32")
onnx_label = np.array([1, -1]).astype(np.float32)[None, :]  # shape (1, 2)

# No previous mask is supplied: pass an all-zero mask and has_mask_input = 0
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

# Load the decoder model and run inference
decoder = ort.InferenceSession("sam_decoder.onnx")
masks_output, _, _ = decoder.run(None, {
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    # (height, width) of the original image
    "orig_im_size": np.array([orig_height, orig_width], dtype=np.float32),
})

# Process the output mask: threshold the logits at 0 and convert to a
# 0/255 uint8 image
mask = masks_output[0][0]
mask = (mask > 0).astype('uint8') * 255
```