File size: 3,652 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) OpenMMLab. All rights reserved.
from typing import TYPE_CHECKING, Union

import numpy as np
import torch

if TYPE_CHECKING:
    from mmengine.model import BaseModel


def inference_model(model: 'BaseModel', img: Union[str, np.ndarray]):
    """Inference a single image with the classifier.

    Args:
        model (BaseModel): The loaded classifier. It must provide ``cfg``
            (whose ``test_dataloader.dataset.pipeline`` lists the test
            transforms) and ``val_step``.
        img (str | np.ndarray): The image filename or an already loaded
            image.

    Returns:
        dict: The classification result, with keys:

            - ``pred_label`` (int): index of the highest-scoring class.
            - ``pred_score`` (float): score of that class.
            - ``pred_scores`` (list[float]): scores of all classes.
            - ``pred_class``: ``model.CLASSES[pred_label]``; only present
              when ``model`` has a ``CLASSES`` attribute.
    """
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    import mmcls.datasets  # noqa: F401

    cfg = model.cfg
    # Build the data pipeline from a shallow copy, so that inserting or
    # removing the loading step below does not modify ``model.cfg``.
    test_pipeline_cfg = list(cfg.test_dataloader.dataset.pipeline)
    loads_from_file = (
        len(test_pipeline_cfg) > 0
        and test_pipeline_cfg[0]['type'] == 'LoadImageFromFile')
    if isinstance(img, str):
        # A filename has to be read from disk before the other transforms.
        if not loads_from_file:
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # An in-memory image must bypass the file-loading step.
        if loads_from_file:
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    # Collate the single processed sample as a one-element batch.
    data = default_collate([data])

    # forward the model
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label
        pred_scores = prediction.score.tolist()
        pred_score = torch.max(prediction.score).item()
        pred_label = torch.argmax(prediction.score).item()
        result = {
            'pred_label': pred_label,
            'pred_score': float(pred_score),
            'pred_scores': pred_scores
        }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = model.CLASSES[result['pred_label']]
    return result


def inference_model_topk(model: 'BaseModel',
                         img: Union[str, np.ndarray],
                         topk=5):
    """Inference a single image with the classifier, keeping top-k classes.

    Args:
        model (BaseModel): The loaded classifier. It must provide ``cfg``
            (whose ``test_dataloader.dataset.pipeline`` lists the test
            transforms) and ``val_step``.
        img (str | np.ndarray): The image filename or an already loaded
            image.
        topk (int | None): Number of highest-scoring classes to keep. If it
            exceeds the number of classes, all classes are kept; ``None``
            also keeps all classes. Defaults to 5.

    Returns:
        dict: The classification result, with keys:

            - ``pred_label`` (np.ndarray): indices of the ``topk``
              highest-scoring classes, in descending score order.
            - ``pred_score`` (np.ndarray): scores of those classes.
            - ``pred_scores`` (np.ndarray): scores of all classes.
            - ``pred_class`` (list): ``model.CLASSES`` entries for
              ``pred_label``; only present when ``model`` has a
              ``CLASSES`` attribute.

    Raises:
        ValueError: If ``topk`` is negative.
    """
    from mmengine.dataset import Compose, default_collate
    from mmengine.registry import DefaultScope

    import mmcls.datasets  # noqa: F401

    # A negative slice bound would silently drop the *lowest* classes
    # instead of selecting the highest ones.
    if topk is not None and topk < 0:
        raise ValueError(f'topk must be non-negative, got {topk}')

    cfg = model.cfg
    # Build the data pipeline from a shallow copy, so that inserting or
    # removing the loading step below does not modify ``model.cfg``.
    test_pipeline_cfg = list(cfg.test_dataloader.dataset.pipeline)
    loads_from_file = (
        len(test_pipeline_cfg) > 0
        and test_pipeline_cfg[0]['type'] == 'LoadImageFromFile')
    if isinstance(img, str):
        # A filename has to be read from disk before the other transforms.
        if not loads_from_file:
            test_pipeline_cfg.insert(0, dict(type='LoadImageFromFile'))
        data = dict(img_path=img)
    else:
        # An in-memory image must bypass the file-loading step.
        if loads_from_file:
            test_pipeline_cfg.pop(0)
        data = dict(img=img)
    with DefaultScope.overwrite_default_scope('mmcls'):
        test_pipeline = Compose(test_pipeline_cfg)
    data = test_pipeline(data)
    # Collate the single processed sample as a one-element batch.
    data = default_collate([data])

    # forward the model
    with torch.no_grad():
        prediction = model.val_step(data)[0].pred_label
        # ``Tensor.numpy()`` only accepts CPU tensors; move the scores off
        # the model's device (e.g. CUDA) before converting them.
        scores = prediction.score.detach().cpu()

        idxs = torch.argsort(scores, descending=True, dim=-1)[:topk]
        result = {
            'pred_label': idxs.numpy(),
            'pred_score': scores[idxs].numpy(),
            'pred_scores': scores.numpy()
        }
    if hasattr(model, 'CLASSES'):
        result['pred_class'] = [
            model.CLASSES[x] for x in result['pred_label']
        ]
    return result