File size: 451 Bytes
d44e389
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np
import onnxruntime


class ONNXModel:
    """Callable wrapper around an ``onnxruntime.InferenceSession``.

    The model is loaded once at construction time. Calling the instance feeds
    a single array to the model's first declared input and returns the
    model's first output.
    """

    def __init__(self, onnx_mode_path, providers=None):
        """Load the ONNX model stored at *onnx_mode_path*.

        Args:
            onnx_mode_path: Path to the ``.onnx`` file; anything whose
                ``str()`` is the path (e.g. ``pathlib.Path``) is accepted.
            providers: Optional ordered list of execution providers handed to
                ``InferenceSession``. When ``None`` (the default), every
                provider available in the installed onnxruntime build is
                used, in onnxruntime's priority order.
        """
        self.path = onnx_mode_path
        # Since onnxruntime 1.9, builds exposing more than one execution
        # provider (e.g. GPU builds) raise if ``providers`` is omitted, so it
        # is always passed explicitly. Resolved here rather than as a default
        # argument so it reflects the runtime build at construction time.
        if providers is None:
            providers = onnxruntime.get_available_providers()
        self.ort_session = onnxruntime.InferenceSession(
            str(self.path), providers=providers
        )
        # Only the first declared graph input is fed by ``__call__``.
        self.input_name = self.ort_session.get_inputs()[0].name

    def __call__(self, img):
        """Run inference on *img* and return the model's first output.

        Args:
            img: Array-like fed to the model's first input. It is converted
                to ``float32`` (without copying if it already is a float32
                ndarray).

        Returns:
            The first element of the list returned by
            ``InferenceSession.run``.
        """
        ort_inputs = {self.input_name: self._to_float32(img)}
        # ``None`` as output_names requests every output; only the first is kept.
        return self.ort_session.run(None, ort_inputs)[0]

    @staticmethod
    def _to_float32(img):
        """Return *img* as a float32 ndarray, copying only if conversion is needed."""
        return np.asarray(img, dtype=np.float32)