Spaces:
Starting
on
L40S
Starting
on
L40S
# Copyright (c) OpenMMLab. All rights reserved. | |
import os | |
from functools import wraps | |
import onnx | |
import pytest | |
import torch | |
from mmcv.ops import nms | |
from mmcv.tensorrt.preprocess import preprocess_onnx | |
# When torch reports the 'parrots' version string, skip collection of this
# whole test module (allow_module_level=True is what pytest requires for a
# skip issued at import time rather than inside a test).
if torch.__version__ == 'parrots':
    pytest.skip(
        'not supported in parrots now',
        allow_module_level=True,
    )
def remove_tmp_file(func):
    """Decorate ``func`` so it works on a scratch ONNX file that is cleaned up.

    The wrapper always sets ``onnx_file='tmp.onnx'`` in the keyword arguments
    (overriding any keyword value the caller passed), calls ``func``, and then,
    whether ``func`` returned or raised, deletes ``tmp.onnx`` if it exists.
    The path is relative, so it resolves against the current working
    directory.

    Args:
        func: Callable that accepts an ``onnx_file`` keyword argument.

    Returns:
        The wrapped callable. It returns whatever ``func`` returns and
        re-raises whatever ``func`` raises.
    """

    @wraps(func)  # keep func's __name__/__doc__ instead of "wrapper"
    def wrapper(*args, **kwargs):
        onnx_file = 'tmp.onnx'
        kwargs['onnx_file'] = onnx_file
        try:
            return func(*args, **kwargs)
        finally:
            # Runs on both success and failure, so no stale file survives.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

    return wrapper
@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
    """Export an NMS wrapper module to ONNX (opset 11) and load it back.

    ``remove_tmp_file`` supplies ``onnx_file`` as a keyword argument and
    deletes that file once this function returns, so callers only pass
    ``module``.

    Args:
        module: Zero-argument callable returning a ``torch.nn.Module`` whose
            ``forward`` takes ``(boxes, scores)`` (in this file, the
            ``ModuleNMS`` classes defined by the tests).
        onnx_file: Path the model is exported to; injected by the decorator.

    Returns:
        The ONNX model loaded with ``onnx.load`` from ``onnx_file``.
    """
    torch_model = module()
    torch_model.eval()

    # Random dummy inputs for tracing: a (100, 4) float32 tensor and a (100,)
    # float32 tensor, bound to the input names 'boxes' and 'scores' below.
    dummy_inputs = (torch.rand([100, 4], dtype=torch.float32),
                    torch.rand([100], dtype=torch.float32))

    torch.onnx.export(
        torch_model,
        dummy_inputs,
        onnx_file,
        opset_version=11,
        input_names=['boxes', 'scores'],
        output_names=['output'])

    # Load before returning: the decorator removes the file afterwards.
    return onnx.load(onnx_file)
def test_can_handle_nms_with_constant_maxnum():
    """``preprocess_onnx`` yields a 5-attribute NMS node for a fixed max_num."""

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4, max_num=10)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)

    # Select by operator type rather than by the free-form node name.
    nms_nodes = [
        node for node in preprocess_onnx_model.graph.node
        if node.op_type == 'NonMaxSuppression'
    ]
    # Without this guard the checks below would pass vacuously when no NMS
    # node is present in the preprocessed graph.
    assert nms_nodes, 'No NonMaxSuppression node found in the graph.'
    for node in nms_nodes:
        assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'
def test_can_handle_nms_with_undefined_maxnum():
    """``preprocess_onnx`` sets a positive max output when max_num is omitted."""

    class ModuleNMS(torch.nn.Module):

        def forward(self, boxes, scores):
            return nms(boxes, scores, iou_threshold=0.4)

    onnx_model = export_nms_module_to_onnx(ModuleNMS)
    preprocess_onnx_model = preprocess_onnx(onnx_model)

    # Select by operator type rather than by the free-form node name.
    nms_nodes = [
        node for node in preprocess_onnx_model.graph.node
        if node.op_type == 'NonMaxSuppression'
    ]
    # Without this guard the checks below would pass vacuously when no NMS
    # node is present in the preprocessed graph.
    assert nms_nodes, 'No NonMaxSuppression node found in the graph.'
    for node in nms_nodes:
        assert len(node.attribute) == 5, \
            'The NMS must have 5 attributes.'
        # Look the attribute up by name instead of relying on its position
        # in node.attribute.
        attrs = {attr.name: attr for attr in node.attribute}
        max_output = attrs.get('max_output_boxes_per_class')
        assert max_output is not None and max_output.i > 0, \
            'The max_output_boxes_per_class is not defined correctly.'