# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING, Union

import numpy as np
import torch

if TYPE_CHECKING:
    from mmengine.model import BaseModel
def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
    """Inference an image with the classifier.

    Args:
        model (BaseModel): The loaded classifier.
        img (str | np.ndarray): The image filename or loaded image.

    Returns:
        result (dict): The classification result containing ``pred_label``,
            ``pred_score``, ``pred_scores`` and, if the model defines
            ``CLASSES``, ``pred_class``.
    """
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    import mmcls.datasets  # noqa: F401

    cfg = model.cfg
    # Build the data pipeline from the test dataloader config.
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    if isinstance(img, str):
        # The input is a file path, so the pipeline must load it from disk.
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # The input is already a decoded image array; skip the loading step.
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # Forward the model without gradient tracking.
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label
        pred_scores = prediction.score.tolist()
        pred_score = torch.max(prediction.score).item()
        pred_label = torch.argmax(prediction.score).item()
        result = {
            'pred_label': pred_label,
            'pred_score': float(pred_score),
            'pred_scores': pred_scores
        }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = model.CLASSES[result['pred_label']]
    return result
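
# Usage sketch (not part of the original file): the classifier is assumed to
# be built with ``mmcls.apis.init_model``; the config, checkpoint and image
# paths below are placeholders.
#
#   from mmcls.apis import init_model
#   model = init_model('configs/resnet/resnet50_8xb32_in1k.py',
#                      'resnet50.pth', device='cpu')
#   result = inference_model(model, 'demo/demo.JPEG')
#   print(result['pred_class'], result['pred_score'])
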
def inference_model_topk(model: 'BaseModel',
                         img: Union[str, np.ndarray],
                         topk: int = 5):
    """Inference an image with the classifier and return the top-k results.

    Args:
        model (BaseModel): The loaded classifier.
        img (str | np.ndarray): The image filename or loaded image.
        topk (int): The number of top predictions to return. Defaults to 5.

    Returns:
        result (dict): The classification result containing ``pred_label``,
            ``pred_score``, ``pred_scores`` and, if the model defines
            ``CLASSES``, ``pred_class``, where the label, score and class
            entries hold the top-k predictions in descending score order.
    """
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    import mmcls.datasets  # noqa: F401

    cfg = model.cfg
    # Build the data pipeline from the test dataloader config.
    test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
    if isinstance(img, str):
        # The input is a file path, so the pipeline must load it from disk.
        if test_pipeline_cfg[0]['type'] != 'LoadImageFromFile':
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # The input is already a decoded image array; skip the loading step.
        if test_pipeline_cfg[0]['type'] == 'LoadImageFromFile':
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    data = default_collate([data])

    # Forward the model without gradient tracking.
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label
        pred_scores = prediction.score.numpy()
        # Indices of the top-k scores, in descending order.
        idxs = torch.argsort(prediction.score, descending=True, dim=-1)[:topk]
        pred_score = prediction.score[idxs].numpy()
        pred_label = idxs.numpy()
        result = {
            'pred_label': pred_label,
            'pred_score': pred_score,
            'pred_scores': pred_scores
        }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = [model.CLASSES[x] for x in result['pred_label']]
    return result
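

# A minimal runnable sketch (not part of the original file). It assumes the
# classifier is built with ``mmcls.apis.init_model``; the config, checkpoint
# and image paths are placeholders to be replaced with real files.
if __name__ == '__main__':
    from mmcls.apis import init_model

    model = init_model(
        'configs/resnet/resnet50_8xb32_in1k.py',  # placeholder config path
        'resnet50.pth',  # placeholder checkpoint path
        device='cpu')
    top5 = inference_model_topk(model, 'demo/demo.JPEG', topk=5)
    for label, score in zip(top5['pred_label'], top5['pred_score']):
        print(int(label), float(score))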