# Source: STB-VMM-Simplification / 20x / modelpnnx20x_ncnn.py
# Last commit: "Upload 14 files" (93ee614), by raoulritter
import numpy as np
import ncnn
import torch
def test_inference(
    *,
    param_path="/Users/raoulritter/STB-VMM/20x/modelpnnx20x.ncnn.param",
    model_path="/Users/raoulritter/STB-VMM/20x/modelpnnx20x.ncnn.bin",
):
    """Run the exported ncnn model once on two seeded random inputs.

    Two ``(1, 3, 384, 384)`` float tensors are drawn from ``torch.rand``
    after seeding the generator with 0, fed to the network blobs ``in0``
    and ``in1`` (with the leading dimension squeezed away), and the blobs
    ``out0`` .. ``out3`` are extracted and converted to torch tensors.

    Args:
        param_path: Path of the ``.ncnn.param`` file to load. Defaults to
            the location used by the original script.
        model_path: Path of the ``.ncnn.bin`` file to load. Defaults to
            the location used by the original script.

    Returns:
        A 4-tuple of torch tensors ``(out0, out1, out2, out3)``. The first
        three get a leading dimension added via ``unsqueeze(0)``; ``out3``
        is returned exactly as extracted.

    Raises:
        RuntimeError: If ncnn reports a non-zero status while loading the
            param/model files, binding an input, or extracting an output.
    """
    # Both arguments are keyword-only with defaults, so existing
    # zero-argument calls keep working unchanged.
    torch.manual_seed(0)
    in0 = torch.rand(1, 3, 384, 384, dtype=torch.float)
    in1 = torch.rand(1, 3, 384, 384, dtype=torch.float)

    input_blobs = (("in0", in0), ("in1", in1))
    # (blob name, add a leading dimension to the result?)
    # out3 is deliberately left without the extra dimension, as in the
    # original script.
    output_blobs = (
        ("out0", True),
        ("out1", True),
        ("out2", True),
        ("out3", False),
    )

    out = []
    with ncnn.Net() as net:
        # ncnn signals failure through integer return codes rather than
        # exceptions; surface them instead of continuing with an empty net.
        if net.load_param(param_path) != 0:
            raise RuntimeError(f"failed to load ncnn param file: {param_path}")
        if net.load_model(model_path) != 0:
            raise RuntimeError(f"failed to load ncnn model file: {model_path}")

        with net.create_extractor() as ex:
            for blob_name, tensor in input_blobs:
                # NOTE(review): clone() presumably gives the Mat its own
                # copy of the data instead of aliasing the temporary NumPy
                # array -- kept exactly as in the original.
                mat_in = ncnn.Mat(tensor.squeeze(0).numpy()).clone()
                if ex.input(blob_name, mat_in) != 0:
                    raise RuntimeError(
                        f"failed to set ncnn input blob: {blob_name}"
                    )

            for blob_name, add_leading_dim in output_blobs:
                status, mat_out = ex.extract(blob_name)
                if status != 0:
                    raise RuntimeError(
                        f"failed to extract ncnn output blob: {blob_name}"
                    )
                result = torch.from_numpy(np.array(mat_out))
                out.append(result.unsqueeze(0) if add_leading_dim else result)

    # Four outputs are always collected, so the result is always a tuple.
    return tuple(out)