File size: 1,675 Bytes
9305c9a
 
 
 
 
 
 
 
d66d54e
9305c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F


# ONNX Runtime session built once at import time from a model file resolved
# relative to the current working directory. Execution providers are listed
# in order of decreasing precedence: CUDA first, then CPU.
session = onnxruntime.InferenceSession(
    "./BEN2_Base.onnx",
    providers=[
        "CUDAExecutionProvider",
        "CPUExecutionProvider",
    ],
)

def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a raw model output map to ``im_size`` and rescale it to 8-bit.

    Args:
        result_np: Raw model output. A 4-D ``(N, C, H, W)`` array is used as-is.
            3-D arrays get one leading axis and 2-D arrays get two, so that
            ``F.interpolate`` always sees a batched 4-D tensor. Integer arrays
            are cast to float32 because bilinear interpolation needs floats.
            Only a batch axis of size 1 is squeezed away, so this assumes
            N == 1 (TODO confirm for batched use).
        im_size: Target spatial size as ``(height, width)``. This is the order
            ``F.interpolate`` expects, not PIL's ``(width, height)``.

    Returns:
        ``uint8`` array whose values are min-max scaled to [0, 255]. Singleton
        axes are squeezed, so a single-channel map comes back as
        ``(height, width)``. If the map is constant (max == min) or contains
        NaN, it cannot be min-max scaled and all zeros are returned.
    """
    result = torch.from_numpy(result_np)

    # F.interpolate(mode='bilinear') requires a 4-D (N, C, H, W) tensor.
    if result.dim() == 2:
        result = result.unsqueeze(0).unsqueeze(0)
    elif result.dim() == 3:
        result = result.unsqueeze(0)

    # Bilinear interpolation is only implemented for floating-point tensors.
    if not result.is_floating_point():
        result = result.float()

    # Drop the batch axis: (1, C, H, W) -> (C, H, W).
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)

    lo = torch.min(result)
    hi = torch.max(result)
    span = hi - lo
    if span > 0:
        result = (result - lo) / span
    else:
        # Dividing by a zero (or NaN) span would fill the map with NaNs,
        # whose cast to uint8 is undefined; emit a well-defined zero map.
        result = torch.zeros_like(result)

    # (C, H, W) -> (H, W, C), then drop singleton axes (e.g. C == 1).
    im_array = (result * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return np.squeeze(im_array)

def preprocess_image(image):
    """Turn a PIL image into the array fed to the ONNX session.

    Args:
        image: A PIL ``Image`` of any mode. The resize is applied to an RGB
            copy, so RGBA, LA, P (palette) and L (grayscale) inputs all yield
            three channels instead of four or one. This assumes the model
            expects 3-channel input (TODO confirm against
            ``session.get_inputs()[0].shape``).

    Returns:
        A tuple ``(input_data, image, original_size)``:

        * ``input_data`` -- array of shape ``(1, 3, 1024, 1024)``, produced by
          ``transforms.ToTensor`` (float values scaled to [0, 1]) with a batch
          axis added in front.
        * ``image`` -- the input image object itself (not the RGB copy), with
          its original mode and any alpha channel intact.
        * ``original_size`` -- ``image.size``, i.e. PIL's ``(width, height)``.
    """
    original_size = image.size
    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ])
    # convert() returns a new image, so the caller's `image` is left as-is.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)
    return img_tensor.numpy(), image, original_size

def run_inference(image):
    """Predict a mask for ``image`` with the ONNX session and apply it as alpha.

    Args:
        image: A PIL ``Image``. It is modified in place: ``putalpha`` writes
            the predicted mask into its alpha channel, converting the image to
            a mode with alpha if it has none.

    Returns:
        The same image object, now carrying the mask as its alpha channel.
    """
    input_data, original_image, (w, h) = preprocess_image(image)

    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_data})

    # PIL reports sizes as (width, height), but F.interpolate inside
    # postprocess_image takes (height, width). Passing [w, h] would produce a
    # transposed mask that PIL then has to stretch back, distorting it for
    # non-square images. Assumes outputs[0] is the mask -- TODO confirm.
    alpha = postprocess_image(outputs[0], im_size=[h, w])

    mask = Image.fromarray(alpha)
    # With (h, w) above the mask already matches the image; resize only as a
    # safeguard against unexpected output shapes.
    if mask.size != (w, h):
        mask = mask.resize((w, h))

    original_image.putalpha(mask)
    return original_image

# Example usage
image_path = "image.png"
output_path = "output.png"


if __name__ == "__main__":
    # Guarded so that importing this module does not run inference.
    # Image.open reads lazily and keeps the file open, so the context manager
    # releases the handle once the result has been written.
    with Image.open(image_path) as image:
        result_image = run_inference(image)
        result_image.save(output_path)