import argparse |
import os |
from typing import Dict, List, Tuple |
import torch |
from torch import Tensor, nn |
import detectron2.data.transforms as T |
from detectron2.checkpoint import DetectionCheckpointer |
from detectron2.config import get_cfg |
from detectron2.data import build_detection_test_loader, detection_utils |
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format |
from detectron2.export import ( |
TracingAdapter, |
dump_torchscript_IR, |
scripting_with_instances, |
) |
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model |
from detectron2.modeling.postprocessing import detector_postprocess |
from detectron2.projects.point_rend import add_pointrend_config |
from detectron2.structures import Boxes |
from detectron2.utils.env import TORCH_VERSION |
from detectron2.utils.file_io import PathManager |
from detectron2.utils.logger import setup_logger |
def setup_cfg(args):
    """Build a frozen detectron2 config from parsed command-line arguments.

    PointRend config keys are registered first so that config files using them
    can be merged. Then ``args.config_file`` and the ``args.opts`` overrides are
    applied, in that order.
    """
    config = get_cfg()
    add_pointrend_config(config)
    config.merge_from_file(args.config_file)
    # Command-line overrides win over values from the config file.
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def export_caffe2_tracing(cfg, torch_model, inputs):
    """Export ``torch_model`` with ``Caffe2Tracer`` into ``args.format``.

    Reads the module-level ``args`` (``format`` and ``output``) created in the
    ``__main__`` block.

    Args:
        cfg: the detectron2 config the model was built from.
        torch_model: the model to export.
        inputs: sample inputs used to trace the model.

    Returns:
        The Caffe2 model for ``--format caffe2``, so it can be evaluated from
        Python. None for the "onnx" and "torchscript" formats.
    """
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    fmt = args.format
    out_dir = args.output

    if fmt == "caffe2":
        c2_model = tracer.export_caffe2()
        c2_model.save_protobuf(out_dir)
        c2_model.save_graph(os.path.join(out_dir, "model.svg"), inputs=inputs)
        return c2_model

    if fmt == "onnx":
        import onnx

        onnx.save(tracer.export_onnx(), os.path.join(out_dir, "model.onnx"))
    elif fmt == "torchscript":
        scripted = tracer.export_torchscript()
        with PathManager.open(os.path.join(out_dir, "model.ts"), "wb") as handle:
            torch.jit.save(scripted, handle)
        dump_torchscript_IR(scripted, out_dir)
    return None
def export_scripting(torch_model: nn.Module) -> None:
    """Export ``torch_model`` to TorchScript by scripting (``--export-method scripting``).

    The model is wrapped in a small ``nn.Module`` adapter whose ``forward`` returns,
    for each output ``Instances``, its fields as a plain dict. The adapter is
    scripted with ``scripting_with_instances`` using the field types declared below.
    The result is written to ``<args.output>/model.ts``, and TorchScript IR dumps
    go to ``args.output``.

    Reads the module-level ``args`` created in the ``__main__`` block.

    Args:
        torch_model: the model to export. A ``GeneralizedRCNN`` is run through
            ``inference(..., do_postprocess=False)``; any other model is called directly.

    Returns:
        None. No Python-side evaluation callable is provided for scripted models.

    Raises:
        AssertionError: if ``TORCH_VERSION < (1, 8)`` or ``args.format`` is not
            "torchscript".
    """
    assert TORCH_VERSION >= (1, 8)
    # Declared type of every Instances field the scripted model may produce;
    # passed to scripting_with_instances below.
    fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    # Shared base: holds the wrapped model and puts the adapter in eval mode.
    class ScriptableAdapterBase(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = torch_model
            self.eval()

    # The adapter's forward is compiled by TorchScript, so the two variants are
    # defined per model type instead of branching at runtime inside forward.
    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)
    # TODO: Python-side evaluation (--run-eval) is not supported for scripted models.
    return None
def export_tracing(torch_model, inputs):
    """Export ``torch_model`` by tracing it on a single sample image.

    Reads the module-level ``args`` (``format``, ``output``) and ``logger`` created
    in the ``__main__`` block.
    - ``--format torchscript`` writes ``model.ts`` plus TorchScript IR dumps.
    - ``--format onnx`` writes ``model.onnx``.

    Args:
        torch_model: the model to export, in eval mode.
        inputs: a list of input dicts. Only ``inputs[0]["image"]`` is used for tracing.

    Returns:
        A callable compatible with ``inference_on_dataset`` when the format is
        "torchscript" and the model is a ``GeneralizedRCNN`` or ``RetinaNet``.
        Otherwise None.

    Raises:
        AssertionError: if ``TORCH_VERSION < (1, 8)`` or the format is not
            "torchscript"/"onnx". Previously "caffe2" silently exported nothing.
    """
    assert TORCH_VERSION >= (1, 8)
    assert args.format in ("torchscript", "onnx"), (
        "Tracing only supports torchscript and onnx formats."
    )
    image = inputs[0]["image"]
    # Trace with the image only; height/width are re-applied in eval_wrapper below.
    inputs = [{"image": image}]

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # Skip the final postprocess/resize; eval_wrapper applies it manually.
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        # NOTE(review): presumably makes TracingAdapter call the model directly;
        # confirm against the TracingAdapter docs.
        inference = None

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        # Was an unresolved name (NameError); imported locally like the file's
        # other optional export dependencies.
        from detectron2.export import STABLE_ONNX_OPSET_VERSION

        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(
                traceable_model, (image,), f, opset_version=STABLE_ONNX_OPSET_VERSION
            )
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    # Python-side evaluation is only provided for traced TorchScript detectors.
    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        sample = inputs[0]
        instances = traceable_model.outputs_schema(ts_model(sample["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, sample["height"], sample["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper
def get_sample_inputs(args):
    """Return a list of input dicts used to trace/export the model.

    Reads the module-level ``cfg`` created in the ``__main__`` block.

    If ``args.sample_image`` is None, returns the first batch of the first test
    dataset (``cfg.DATASETS.TEST[0]``).

    Otherwise:
    - reads that image in ``cfg.INPUT.FORMAT``;
    - resizes it with the test-time sizes from ``cfg``;
    - returns ``[{"image": CHW float32 tensor, "height": H, "width": W}]``, where
      H and W are the original (pre-resize) image dimensions.
    """
    if args.sample_image is None:
        # Use the first batch of the test dataset as the sample.
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        first_batch = next(iter(data_loader))
        return first_batch

    original_image = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
    # Test-time resize: short edge to MIN_SIZE_TEST, long edge capped at
    # MAX_SIZE_TEST. ResizeShortestEdge requires the short-edge length, so calling
    # it with no arguments (as before) raised a TypeError.
    aug = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    height, width = original_image.shape[:2]
    image = aug.get_transform(original_image).apply_image(original_image)
    # HWC array -> CHW float32 tensor.
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    return [{"image": image, "height": height, "width": width}]
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    # Required: every export path writes into this directory, and
    # PathManager.mkdirs(None) below would otherwise fail with an obscure error.
    parser.add_argument(
        "--output", required=True, help="output directory for the converted model"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)
    # NOTE(review): presumably limits TorchScript re-specialization on new input
    # shapes so that --run-eval over differently sized images is not slow.
    torch._C._jit_set_bailout_depth(1)

    # `args`, `cfg` and `logger` are module globals: the helper functions above
    # read them directly.
    cfg = setup_cfg(args)

    # Build the model and load its weights.
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    # argparse `choices` guarantees exactly one branch runs.
    if args.export_method == "caffe2_tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_tracing(torch_model, sample_inputs)

    # Optionally evaluate the exported model on the first test dataset.
    if args.run_eval:
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
    logger.info("Success.")