AiOS / mmcv /tests /test_ops /test_tensorrt_preprocess.py
ttxskk
update
d7e58f0
raw
history blame
2.23 kB
# Copyright (c) OpenMMLab. All rights reserved.
import os
from functools import wraps
import onnx
import pytest
import torch
from mmcv.ops import nms
from mmcv.tensorrt.preprocess import preprocess_onnx
if torch.__version__ == 'parrots':
pytest.skip('not supported in parrots now', allow_module_level=True)
def remove_tmp_file(func):
@wraps(func)
def wrapper(*args, **kwargs):
onnx_file = 'tmp.onnx'
kwargs['onnx_file'] = onnx_file
try:
result = func(*args, **kwargs)
finally:
if os.path.exists(onnx_file):
os.remove(onnx_file)
return result
return wrapper
@remove_tmp_file
def export_nms_module_to_onnx(module, onnx_file):
torch_model = module()
torch_model.eval()
input = (torch.rand([100, 4], dtype=torch.float32),
torch.rand([100], dtype=torch.float32))
torch.onnx.export(
torch_model,
input,
onnx_file,
opset_version=11,
input_names=['boxes', 'scores'],
output_names=['output'])
onnx_model = onnx.load(onnx_file)
return onnx_model
def test_can_handle_nms_with_constant_maxnum():
class ModuleNMS(torch.nn.Module):
def forward(self, boxes, scores):
return nms(boxes, scores, iou_threshold=0.4, max_num=10)
onnx_model = export_nms_module_to_onnx(ModuleNMS)
preprocess_onnx_model = preprocess_onnx(onnx_model)
for node in preprocess_onnx_model.graph.node:
if 'NonMaxSuppression' in node.name:
assert len(node.attribute) == 5, 'The NMS must have 5 attributes.'
def test_can_handle_nms_with_undefined_maxnum():
class ModuleNMS(torch.nn.Module):
def forward(self, boxes, scores):
return nms(boxes, scores, iou_threshold=0.4)
onnx_model = export_nms_module_to_onnx(ModuleNMS)
preprocess_onnx_model = preprocess_onnx(onnx_model)
for node in preprocess_onnx_model.graph.node:
if 'NonMaxSuppression' in node.name:
assert len(node.attribute) == 5, \
'The NMS must have 5 attributes.'
assert node.attribute[2].i > 0, \
'The max_output_boxes_per_class is not defined correctly.'