File size: 6,148 Bytes
c985ba4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel
import argparse
import warnings
# onnxruntime is an optional dependency: record whether it could be imported
# so later code can skip (or refuse) the steps that need it.
try:
    import onnxruntime  # type: ignore
except ImportError:
    onnxruntime_exists = False
else:
    onnxruntime_exists = True
# Command-line interface for the export script.
parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

# One (flag, add_argument keyword options) pair per option, kept in the order
# the options are registered so the generated --help listing is unchanged.
_CLI_OPTIONS = [
    (
        "--checkpoint",
        {"type": str, "required": True, "help": "The path to the SAM model checkpoint."},
    ),
    (
        "--output",
        {"type": str, "required": True, "help": "The filename to save the ONNX model to."},
    ),
    (
        "--model-type",
        {
            "type": str,
            "required": True,
            "help": "In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
        },
    ),
    (
        "--return-single-mask",
        {
            "action": "store_true",
            "help": (
                "If true, the exported ONNX model will only return the best mask, "
                "instead of returning multiple masks. For high resolution images "
                "this can improve runtime when upscaling masks is expensive."
            ),
        },
    ),
    (
        "--opset",
        {
            "type": int,
            "default": 17,
            "help": "The ONNX opset version to use. Must be >=11",
        },
    ),
    (
        "--quantize-out",
        {
            "type": str,
            "default": None,
            "help": (
                "If set, will quantize the model and save it with this name. "
                "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
            ),
        },
    ),
    (
        "--gelu-approximate",
        {
            "action": "store_true",
            "help": (
                "Replace GELU operations with approximations using tanh. Useful "
                "for some runtimes that have slow or unimplemented erf ops, used in GELU."
            ),
        },
    ),
    (
        "--use-stability-score",
        {
            "action": "store_true",
            "help": (
                "Replaces the model's predicted mask quality score with the stability "
                "score calculated on the low resolution masks using an offset of 1.0. "
            ),
        },
    ),
    (
        "--return-extra-metrics",
        {
            "action": "store_true",
            "help": (
                "The model will return five results: (masks, scores, stability_scores, "
                "areas, low_res_logits) instead of the usual three. This can be "
                "significantly slower for high resolution outputs."
            ),
        },
    ),
]

for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
) -> None:
    """Export the SAM prompt encoder and mask decoder to an ONNX file.

    Loads the checkpoint through ``sam_model_registry``, wraps the model in
    ``SamOnnxModel``, runs it once on dummy inputs, and writes the traced
    graph to ``output`` with ``torch.onnx.export``. If onnxruntime could be
    imported, the exported file is then loaded and run once as a smoke test.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the SAM variant.
        checkpoint: Path to the SAM model checkpoint.
        output: Filename the ONNX model is written to.
        opset: ONNX opset version passed to the exporter.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If true, every ``torch.nn.GELU`` module in the
            wrapped model is switched to its ``"tanh"`` approximation.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``; also selects the
            five-entry output name list instead of the three-entry one.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Only the number of prompt points is left dynamic in the exported graph.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # Insertion order matters: the values are passed positionally to the
    # exporter and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Eager forward pass so an incompatible input fails before export starts.
    _ = onnx_model(**dummy_inputs)

    # Output names are matched to model outputs by position. With extra
    # metrics the model returns five results in the order given by the
    # --return-extra-metrics help text; a three-name list would attach
    # "low_res_masks" to the stability scores.
    if return_extra_metrics:
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    # Runs after the file handle is closed so the session reads a complete file.
    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # Naming the provider explicitly is required by ONNX Runtime builds
        # that ship more than the CPU provider; CPU is enough for a smoke test.
        ort_session = onnxruntime.InferenceSession(
            output, providers=["CPUExecutionProvider"]
        )
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
def to_numpy(tensor: "torch.Tensor"):
    """Return the contents of ``tensor`` as a NumPy array on the CPU.

    The tensor is detached first so that tensors which require grad can be
    converted too; for tensors that do not, ``detach()`` changes nothing.
    """
    return tensor.detach().cpu().numpy()
# Script entry point: export the model, then optionally quantize the result.
if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        # Explicit raise rather than assert: asserts are stripped under -O.
        if not onnxruntime_exists:
            raise ImportError("onnxruntime is required to quantize the model.")
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        # optimize_model is intentionally not passed: newer onnxruntime
        # releases no longer accept that keyword.
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
|