File size: 1,106 Bytes
93ee614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import ncnn
import torch

def test_inference(
    param_path="/Users/raoulritter/STB-VMM/20x/modelpnnx20x.ncnn.param",
    bin_path="/Users/raoulritter/STB-VMM/20x/modelpnnx20x.ncnn.bin",
):
    """Run the exported ncnn model on fixed random inputs and return its outputs.

    Two inputs of shape (1, 3, 384, 384) are drawn from ``torch.rand`` right
    after ``torch.manual_seed(0)``, so every call feeds the network the same
    data. The leading size-1 dimension is squeezed off before each input is
    handed to ncnn as blobs ``in0`` and ``in1``.

    Args:
        param_path: Path to the ncnn ``.param`` graph file. Defaults to the
            path used by the original export script.
        bin_path: Path to the ncnn ``.bin`` weights file. Defaults to the
            path used by the original export script.

    Returns:
        A tuple of four ``torch.Tensor`` objects for blobs ``out0``..``out3``.
        ``out0``-``out2`` get a leading dimension re-added with
        ``unsqueeze(0)``; ``out3`` is returned exactly as ncnn produced it.

    Raises:
        RuntimeError: if ncnn reports a non-zero status while loading the
            param/model files or while extracting an output blob.
    """
    torch.manual_seed(0)
    in0 = torch.rand(1, 3, 384, 384, dtype=torch.float)
    in1 = torch.rand(1, 3, 384, 384, dtype=torch.float)

    # Each entry is (blob name, whether to re-add the leading dim squeezed
    # off on the way in).
    # NOTE(review): out3 skips unsqueeze(0), as in the generated original.
    # Presumably its PyTorch output has a different rank; confirm against
    # the source model.
    output_specs = (
        ("out0", True),
        ("out1", True),
        ("out2", True),
        ("out3", False),
    )
    out = []

    with ncnn.Net() as net:
        # ncnn returns 0 on success; fail loudly instead of running an
        # empty network.
        if net.load_param(param_path) != 0:
            raise RuntimeError(f"ncnn failed to load param file {param_path!r}")
        if net.load_model(bin_path) != 0:
            raise RuntimeError(f"ncnn failed to load model file {bin_path!r}")

        with net.create_extractor() as ex:
            # clone() gives each Mat its own copy of the data, presumably so
            # it does not alias the temporary numpy buffer.
            ex.input("in0", ncnn.Mat(in0.squeeze(0).numpy()).clone())
            ex.input("in1", ncnn.Mat(in1.squeeze(0).numpy()).clone())

            for name, add_batch_dim in output_specs:
                ret, mat = ex.extract(name)
                if ret != 0:
                    raise RuntimeError(
                        f"ncnn failed to extract blob {name!r} (ret={ret})"
                    )
                # np.array copies the Mat's data, so the tensor stays valid
                # after the extractor and net are released.
                tensor = torch.from_numpy(np.array(mat))
                out.append(tensor.unsqueeze(0) if add_batch_dim else tensor)

    return tuple(out)