MaxwellMeyer commited on
Commit
9305c9a
·
verified ·
1 Parent(s): 19d0d22

Upload 2 files

Browse files
Files changed (2) hide show
  1. BEN2_Base.onnx +3 -0
  2. onnx_run.py +69 -0
BEN2_Base.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22cea62108ff53b7ccc20f7a008bf30494228d84b1687f29ecbe76936a998101
3
+ size 222932053
onnx_run.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torchvision.transforms as transforms
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
+ session = onnxruntime.InferenceSession("./onnx/BEN2_Base.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
10
+
11
+ def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
12
+
13
+ result = torch.from_numpy(result_np)
14
+
15
+
16
+ if len(result.shape) == 3:
17
+ result = result.unsqueeze(0)
18
+
19
+
20
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
21
+
22
+
23
+ ma = torch.max(result)
24
+ mi = torch.min(result)
25
+ result = (result - mi) / (ma - mi)
26
+
27
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
28
+ im_array = np.squeeze(im_array)
29
+ return im_array
30
+
31
+ def preprocess_image(image):
32
+ original_size = image.size
33
+ transform = transforms.Compose([
34
+ transforms.Resize((1024, 1024)),
35
+ transforms.ToTensor(),
36
+ ])
37
+ img_tensor = transform(image)
38
+
39
+ img_tensor = img_tensor.unsqueeze(0)
40
+ return img_tensor.numpy(), image, original_size
41
+
42
+ def run_inference(image):
43
+
44
+ input_data, original_image, (w, h) = preprocess_image(image)
45
+
46
+ input_name = session.get_inputs()[0].name
47
+
48
+ outputs = session.run(None, {input_name: input_data})
49
+
50
+
51
+ alpha = postprocess_image(outputs[0], im_size=[w, h])
52
+
53
+
54
+ mask = Image.fromarray(alpha)
55
+ mask = mask.resize((w, h))
56
+
57
+
58
+ original_image.putalpha(mask)
59
+ return original_image
60
+
61
+ # Example usage
62
+ image_path = "image.png"
63
+ output_path = "output.png"
64
+
65
+
66
+ image = Image.open(image_path)
67
+
68
+ result_image = run_inference(image)
69
+ result_image.save(output_path)