File size: 1,675 Bytes
9305c9a d66d54e 9305c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
# Model file, resolved relative to the current working directory ("./").
_MODEL_PATH = "./BEN2_Base.onnx"
# ONNX Runtime treats this list as a priority order: CUDA first, CPU as fallback.
_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
# One shared inference session, created once when the module is loaded.
session = onnxruntime.InferenceSession(_MODEL_PATH, providers=_PROVIDERS)
def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
    """Turn a raw model output into an 8-bit mask of the requested size.

    The output is resized with bilinear interpolation, min-max normalized to
    [0, 1], scaled to [0, 255] and truncated to ``uint8``.

    Args:
        result_np: Model output as a 3-D ``(C, H, W)`` or 4-D ``(1, C, H, W)``
            float array; a 3-D array gets a batch axis prepended.
        im_size: Target spatial size as ``[height, width]``. This is the order
            ``torch.nn.functional.interpolate`` expects. It is the reverse of
            PIL's ``(width, height)``.

    Returns:
        A ``uint8`` array with size-1 axes squeezed away, so a single-channel
        output comes back as ``(height, width)``. If the resized output has no
        positive value range (constant, or containing NaN), an all-zero array
        is returned instead of dividing by zero.
    """
    result = torch.from_numpy(result_np)
    if result.dim() == 3:
        result = result.unsqueeze(0)
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # 0/0 would yield NaN, and casting NaN to uint8 is undefined; emit an empty mask.
        result = torch.zeros_like(result)
    # (C, H, W) -> (H, W, C) so the array is laid out the way PIL/NumPy images are.
    im_array = (result * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return np.squeeze(im_array)
# Built once: resize to the model's fixed 1024x1024 input, then convert to a
# CHW float tensor scaled to [0, 1].
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
])


def preprocess_image(image):
    """Convert a PIL image into the array fed to the ONNX model.

    Args:
        image: A PIL image of any mode. It is not modified.

    Returns:
        A tuple ``(input_array, image, original_size)``:
        ``input_array`` is a NumPy array of shape ``(1, 3, 1024, 1024)`` with
        values in [0, 1]; ``image`` is the unmodified input image; and
        ``original_size`` is PIL's ``(width, height)`` of the input.
    """
    original_size = image.size
    # ToTensor keeps every band, so RGBA/LA/L/P images would give 4 or 1
    # channels (or raw palette indices). Feed the model RGB instead. The model
    # presumably expects 3 channels; TODO confirm against the ONNX input shape.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    img_tensor = _PREPROCESS_TRANSFORM(rgb_image).unsqueeze(0)
    return img_tensor.numpy(), image, original_size
def run_inference(image):
    """Predict a mask for *image* and attach it as the image's alpha channel.

    Uses the module-level ``session`` together with ``preprocess_image`` and
    ``postprocess_image``. ``putalpha`` modifies the passed-in PIL image in
    place, and that same object is returned.

    Args:
        image: A PIL image.

    Returns:
        The input image object, now carrying the predicted mask as alpha.
    """
    input_data, original_image, (w, h) = preprocess_image(image)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_data})
    # PIL sizes are (width, height) but F.interpolate takes (height, width);
    # passing [w, h] produced a mask with swapped dimensions that then had to
    # be stretched back.
    alpha = postprocess_image(outputs[0], im_size=[h, w])
    mask = Image.fromarray(alpha)
    # Should already match; kept as a safety net so putalpha never sees a size mismatch.
    if mask.size != (w, h):
        mask = mask.resize((w, h))
    original_image.putalpha(mask)
    return original_image
# Example usage: run the model on image.png and write the result to output.png.
# Guarded so that importing this module does not trigger file I/O.
if __name__ == "__main__":
    image_path = "image.png"
    output_path = "output.png"
    # The context manager closes the source file handle when done.
    with Image.open(image_path) as image:
        result_image = run_inference(image)
        result_image.save(output_path)
|