"""Tests for async interface.""" |
import asyncio |
import os |
import sys |
import asynctest |
import mmcv |
import torch |
from mmdet.apis import async_inference_detector, init_detector |
if sys.version_info >= (3, 7): |
from mmdet.utils.contextmanagers import concurrent |
class AsyncTestCase(asynctest.TestCase):
    """``asynctest.TestCase`` base that bounds coroutine tests by a timeout.

    The timeout (in seconds) is read once, when the class body executes,
    from the ``ASYNCIO_TEST_TIMEOUT`` environment variable, defaulting to
    ``'30'``.
    """

    # asynctest option: do not reuse the default event loop for tests.
    use_default_loop = False
    # asynctest option: flag calls to ``asyncio.get_event_loop()`` in tests.
    forbid_get_event_loop = True

    TEST_TIMEOUT = int(os.environ.get('ASYNCIO_TEST_TIMEOUT', '30'))

    def _run_test_method(self, method):
        """Invoke ``method``; drive a returned coroutine on ``self.loop``.

        Non-coroutine results are ignored. A coroutine is wrapped in
        ``asyncio.wait_for`` so a hung test is cancelled after
        ``TEST_TIMEOUT`` seconds instead of blocking forever.
        """
        outcome = method()
        if not asyncio.iscoroutine(outcome):
            return
        bounded = asyncio.wait_for(outcome, timeout=self.TEST_TIMEOUT)
        self.loop.run_until_complete(bounded)
class MaskRCNNDetector:
    """Async test helper wrapping a detector built with ``init_detector``.

    ``await init()`` fills ``self.streamqueue`` with CUDA streams; on
    Python >= 3.7, ``apredict()`` runs ``async_inference_detector`` inside
    ``concurrent(self.streamqueue)``.
    """

    def __init__(self,
                 model_config,
                 checkpoint=None,
                 streamqueue_size=3,
                 device='cuda:0'):
        """Build the model; the stream queue is created later by ``init()``.

        Args:
            model_config: Config forwarded to ``init_detector`` (a config
                file path in this module's usage).
            checkpoint: Optional checkpoint forwarded to ``init_detector``.
                ``None`` (the default) is passed through unchanged.
            streamqueue_size (int): Number of CUDA streams ``init()`` creates.
            device (str): Device used for the model and the CUDA streams.
        """
        self.streamqueue_size = streamqueue_size
        self.device = device
        # Forward the caller's checkpoint; hard-coding ``None`` here would
        # silently ignore the ``checkpoint`` argument.
        self.model = init_detector(
            model_config, checkpoint=checkpoint, device=self.device)
        # Set by ``init()``.
        self.streamqueue = None

    async def init(self):
        """Create ``streamqueue_size`` CUDA streams and queue them.

        Kept out of ``__init__``, presumably so that the ``asyncio.Queue``
        is created while the test's event loop is running.
        """
        self.streamqueue = asyncio.Queue()
        for _ in range(self.streamqueue_size):
            stream = torch.cuda.Stream(device=self.device)
            self.streamqueue.put_nowait(stream)

    # ``concurrent`` is only imported on Python >= 3.7 (module-level guard).
    if sys.version_info >= (3, 7):

        async def apredict(self, img):
            """Run ``async_inference_detector`` on one image.

            Args:
                img: Image path (loaded with ``mmcv.imread``) or an image
                    object that is passed through as-is.

            Returns:
                The result of ``async_inference_detector(self.model, img)``.
            """
            if isinstance(img, str):
                img = mmcv.imread(img)
            # Presumably borrows a stream from ``self.streamqueue`` for the
            # duration of the call -- see mmdet's ``concurrent`` helper.
            async with concurrent(self.streamqueue):
                result = await async_inference_detector(self.model, img)
            return result
class AsyncInferenceTestCase(AsyncTestCase):
    """End-to-end smoke test of ``MaskRCNNDetector.apredict`` on a GPU."""

    @staticmethod
    def _find_root(marker):
        """Return the nearest ancestor directory of this file containing
        the relative path ``marker``.

        Falls back to the current working directory when no ancestor
        matches. Previously ``os.path.dirname`` was applied to the module
        ``__name__`` (which never contains ``/``), so paths always resolved
        against the cwd; the fallback keeps that behaviour.
        """
        current = os.path.dirname(os.path.abspath(__file__))
        while True:
            if os.path.exists(os.path.join(current, marker)):
                return current
            parent = os.path.dirname(current)
            if parent == current:  # reached the filesystem root
                return os.getcwd()
            current = parent

    if sys.version_info >= (3, 7):

        async def test_simple_inference(self):
            """Run one Mask R-CNN inference on ``demo/demo.jpg``."""
            if not torch.cuda.is_available():
                self.skipTest('test requires GPU and torch+cuda')

            # Inference presumably changes the global autograd mode; record
            # it so it is restored even if inference or the assert fails.
            ori_grad_enabled = torch.is_grad_enabled()
            try:
                config_rel = 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py'
                root_dir = self._find_root(config_rel)
                model_config = os.path.join(root_dir, config_rel)
                detector = MaskRCNNDetector(model_config)
                await detector.init()

                img_path = os.path.join(root_dir, 'demo/demo.jpg')
                # Assumes a (bbox_result, segm_result) pair -- TODO confirm.
                bboxes, _ = await detector.apredict(img_path)
                self.assertTrue(bboxes)
            finally:
                torch.set_grad_enabled(ori_grad_enabled)