File size: 1,106 Bytes
93ee614 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import numpy as np
import ncnn
import torch
def test_inference(
    param_path="/Users/raoulritter/STB-VMM/20x/modelpnnx20x.ncnn.param",
    bin_path="/Users/raoulritter/STB-VMM/20x/modelpnnx20x.ncnn.bin",
):
    """Run the ncnn-converted model on two seeded random inputs.

    Seeds torch's RNG with 0 and builds two float tensors of shape
    (1, 3, 384, 384). These are fed to the ncnn blobs "in0" and "in1".
    The blobs "out0".."out3" are then extracted and converted back to
    torch tensors.

    Args:
        param_path: Path to the ncnn ``.param`` network description.
            Defaults to the original hard-coded location.
        bin_path: Path to the ncnn ``.bin`` weights file.
            Defaults to the original hard-coded location.

    Returns:
        A 4-tuple of torch tensors in the order (out0, out1, out2, out3).
        out0..out2 get a leading dimension re-added with ``unsqueeze(0)``.
        out3 is returned as-is.

    Raises:
        RuntimeError: if ncnn reports a non-zero status while loading the
            network, feeding an input, or extracting an output.
    """

    def _check(ret, what):
        # ncnn calls return 0 on success; a non-zero code means the step
        # failed and any Mat it produced is not meaningful.
        if ret != 0:
            raise RuntimeError(f"ncnn {what} failed with status {ret}")

    torch.manual_seed(0)
    in0 = torch.rand(1, 3, 384, 384, dtype=torch.float)
    in1 = torch.rand(1, 3, 384, 384, dtype=torch.float)

    # (blob name, whether to re-add a leading dim after extraction).
    # NOTE(review): out3 is not unsqueezed in the generated script, presumably
    # because that output has a different rank than out0..out2 -- confirm
    # against the original PyTorch model before changing.
    output_specs = (
        ("out0", True),
        ("out1", True),
        ("out2", True),
        ("out3", False),
    )

    out = []
    with ncnn.Net() as net:
        _check(net.load_param(param_path), f"load_param({param_path!r})")
        _check(net.load_model(bin_path), f"load_model({bin_path!r})")

        with net.create_extractor() as ex:
            # Drop the leading dim before wrapping in ncnn.Mat. clone() makes
            # the Mat hold its own copy rather than (presumably) referencing
            # the numpy buffer.
            _check(ex.input("in0", ncnn.Mat(in0.squeeze(0).numpy()).clone()), "input in0")
            _check(ex.input("in1", ncnn.Mat(in1.squeeze(0).numpy()).clone()), "input in1")

            for name, add_leading_dim in output_specs:
                ret, mat = ex.extract(name)
                _check(ret, f"extract {name}")
                # np.array copies the Mat data, so the tensor stays valid
                # after the extractor and net are released.
                tensor = torch.from_numpy(np.array(mat))
                out.append(tensor.unsqueeze(0) if add_leading_dim else tensor)

    # Four outputs are always produced, so the result is always a tuple.
    return tuple(out)
|