Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import TYPE_CHECKING, Union | |
import numpy as np | |
import torch | |
if TYPE_CHECKING: | |
from mmengine.model import BaseModel | |
def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]) -> dict:
    """Run the classifier on one image and return its top-1 prediction.

    Args:
        model (BaseModel): The loaded classifier. It must expose ``cfg``
            (with ``test_dataloader.dataset.pipeline``) and ``val_step``.
        img (str | np.ndarray): The image filename or an already loaded
            image.

    Returns:
        dict: The classification result with the keys

        - ``pred_label`` (int): index of the highest-scoring class.
        - ``pred_score`` (float): score of that class.
        - ``pred_scores`` (list): scores of all classes.
        - ``pred_class``: ``model.CLASSES[pred_label]``; only present when
          the model has a ``CLASSES`` attribute.
    """
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    # Imported only for its side effects (hence the noqa); presumably this
    # registers the mmcls transforms used by the pipeline -- TODO confirm.
    import mmcls.datasets  # noqa: F401

    cfg = model.cfg

    # Build the data pipeline from a shallow copy: the insert/pop below must
    # not leak into ``model.cfg``, which is shared by every later call.
    test_pipeline_cfg = list(cfg.test_dataloader.dataset.pipeline)
    starts_with_loader = (
        bool(test_pipeline_cfg)
        and test_pipeline_cfg[0]['type'] == 'LoadImageFromFile')

    if isinstance(img, str):
        # A filename has to be read from disk first.
        if not starts_with_loader:
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # The image is already in memory, so the file loader is skipped.
        if starts_with_loader:
            test_pipeline_cfg.pop(0)
        data = dict(img=img)

    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # Forward the model without tracking gradients.
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label
        pred_scores = prediction.score.tolist()
        pred_score = torch.max(prediction.score).item()
        pred_label = torch.argmax(prediction.score).item()

    result = {
        'pred_label': pred_label,
        'pred_score': float(pred_score),
        'pred_scores': pred_scores
    }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = model.CLASSES[result['pred_label']]
    return result
def inference_model_topk(model: 'BaseModel',
                         img: Union[str, np.ndarray],
                         topk: int = 5) -> dict:
    """Run the classifier on one image and return its top-k predictions.

    Args:
        model (BaseModel): The loaded classifier. It must expose ``cfg``
            (with ``test_dataloader.dataset.pipeline``) and ``val_step``.
        img (str | np.ndarray): The image filename or an already loaded
            image.
        topk (int): Number of highest-scoring classes to keep.
            Defaults to 5.

    Returns:
        dict: The classification result with the keys

        - ``pred_label`` (np.ndarray): indices of the ``topk`` classes,
          ordered by descending score.
        - ``pred_score`` (np.ndarray): scores of those classes, in the
          same order.
        - ``pred_scores`` (np.ndarray): scores of all classes.
        - ``pred_class`` (list): ``model.CLASSES`` entries for
          ``pred_label``; only present when the model has a ``CLASSES``
          attribute.
    """
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    # Imported only for its side effects (hence the noqa); presumably this
    # registers the mmcls transforms used by the pipeline -- TODO confirm.
    import mmcls.datasets  # noqa: F401

    cfg = model.cfg

    # Build the data pipeline from a shallow copy: the insert/pop below must
    # not leak into ``model.cfg``, which is shared by every later call.
    test_pipeline_cfg = list(cfg.test_dataloader.dataset.pipeline)
    starts_with_loader = (
        bool(test_pipeline_cfg)
        and test_pipeline_cfg[0]['type'] == 'LoadImageFromFile')

    if isinstance(img, str):
        # A filename has to be read from disk first.
        if not starts_with_loader:
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # The image is already in memory, so the file loader is skipped.
        if starts_with_loader:
            test_pipeline_cfg.pop(0)
        data = dict(img=img)

    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # Forward the model without tracking gradients.
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label

    # ``Tensor.numpy()`` only works on CPU tensors, so move the scores off
    # the accelerator once and reuse that copy for every conversion below.
    scores = prediction.score.detach().cpu()
    idxs = torch.argsort(scores, descending=True, dim=-1)[:topk]

    result = {
        'pred_label': idxs.numpy(),
        'pred_score': scores[idxs].numpy(),
        'pred_scores': scores.numpy()
    }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = [model.CLASSES[x] for x in result['pred_label']]
    return result