|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
from segment_anything import sam_model_registry |
|
from segment_anything.utils.onnx import SamOnnxModel |
|
|
|
import argparse |
|
import warnings |
|
|
|
# onnxruntime is an optional dependency. Record whether it could be imported
# so later code can decide whether to use it.
try:
    import onnxruntime
except ImportError:
    onnxruntime_exists = False
else:
    onnxruntime_exists = True
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the ONNX export script.

    Options are registered in the order they should appear in ``--help``:
    the three required string options, then ``--return-single-mask``,
    ``--opset`` and ``--quantize-out``, then the remaining boolean switches.
    """
    cli = argparse.ArgumentParser(
        description="Export the SAM prompt encoder and mask decoder to an ONNX model."
    )

    # Required string options; they differ only in flag name and help text.
    required_strings = {
        "--checkpoint": "The path to the SAM model checkpoint.",
        "--output": "The filename to save the ONNX model to.",
        "--model-type": (
            "In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export."
        ),
    }
    for flag, help_text in required_strings.items():
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--return-single-mask",
        action="store_true",
        help=(
            "If true, the exported ONNX model will only return the best mask, "
            "instead of returning multiple masks. For high resolution images "
            "this can improve runtime when upscaling masks is expensive."
        ),
    )
    cli.add_argument(
        "--opset",
        type=int,
        default=17,
        help="The ONNX opset version to use. Must be >=11",
    )
    cli.add_argument(
        "--quantize-out",
        type=str,
        default=None,
        help=(
            "If set, will quantize the model and save it with this name. "
            "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
        ),
    )

    # Remaining on/off switches, all defaulting to False.
    switches = {
        "--gelu-approximate": (
            "Replace GELU operations with approximations using tanh. Useful "
            "for some runtimes that have slow or unimplemented erf ops, used in GELU."
        ),
        "--use-stability-score": (
            "Replaces the model's predicted mask quality score with the stability "
            "score calculated on the low resolution masks using an offset of 1.0. "
        ),
        "--return-extra-metrics": (
            "The model will return five results: (masks, scores, stability_scores, "
            "areas, low_res_logits) instead of the usual three. This can be "
            "significantly slower for high resolution outputs."
        ),
    }
    for flag, help_text in switches.items():
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli


parser = _build_parser()
|
|
|
|
|
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Load a SAM checkpoint, wrap it in ``SamOnnxModel`` and export it to ONNX.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the model builder.
        checkpoint: Path passed to the builder as ``checkpoint``.
        output: Path of the ONNX file to write.
        opset: ONNX opset version handed to ``torch.onnx.export``.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If true, every ``torch.nn.GELU`` module in the
            wrapped model has ``approximate`` set to ``"tanh"`` before export.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``. Also selects the
            five-name output list instead of the three-name one.

    If ``onnxruntime`` was importable, the written file is loaded and run once
    on the same dummy inputs as a smoke test.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Axis 1 of both point inputs is exported as a dynamic "num_points" axis;
    # every other dimension is fixed to the dummy input's shape.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # Insertion order matters: the values are passed positionally to the
    # exporter and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Eager-mode forward pass before export; the result is discarded.
    # NOTE(review): presumably a sanity check that the inputs are accepted.
    _ = onnx_model(**dummy_inputs)

    # Output names must line up one-to-one with the tensors the model returns.
    # With extra metrics the order follows the --return-extra-metrics help text:
    # (masks, scores, stability_scores, areas, low-res masks). Reusing the
    # three-name list there would label the stability scores "low_res_masks".
    if return_extra_metrics:
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        # Silence tracer and user warnings for the duration of the export only.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {name: to_numpy(value) for name, value in dummy_inputs.items()}
        # Name the CPU provider explicitly: recent ONNX Runtime builds that
        # ship additional providers refuse to create a session without one.
        ort_session = onnxruntime.InferenceSession(
            output, providers=["CPUExecutionProvider"]
        )
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
|
|
|
|
|
def to_numpy(tensor: torch.Tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    The tensor is detached first so that tensors which require grad can be
    converted too; calling ``.numpy()`` on them directly raises.
    """
    return tensor.detach().cpu().numpy()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: export the model, then optionally quantize the file.
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        # Raise explicitly instead of asserting so the check still runs
        # under ``python -O``.
        if not onnxruntime_exists:
            raise RuntimeError("onnxruntime is required to quantize the model.")
        # Imported here so the quantization tooling is only needed when
        # --quantize-out is given.
        from onnxruntime.quantization import QuantType
        from onnxruntime.quantization.quantize import quantize_dynamic

        print(f"Quantizing model and writing to {args.quantize_out}...")
        # ``optimize_model`` is deliberately not passed: newer onnxruntime
        # releases removed the keyword, and older ones default it to True.
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
|
|