|
import onnxruntime |
|
import numpy as np |
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
# Shared ONNX Runtime session, created once when the module is loaded and
# reused by run_inference(). The CUDA provider is listed first and the CPU
# provider second.
session = onnxruntime.InferenceSession(
    "./BEN2_Base.onnx",
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
)
|
|
|
def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
    """Resize a raw model output map and convert it to an 8-bit array.

    Args:
        result_np: Model output as a NumPy array, either ``(C, H, W)`` or
            ``(1, C, H, W)``; a 3-D input gets a leading batch axis added.
        im_size: Target spatial size, passed unchanged to
            ``F.interpolate`` (which reads it as ``[height, width]``).

    Returns:
        ``uint8`` array min-max scaled to 0-255. A single-channel map is
        returned as ``(height, width)``; otherwise the layout is
        ``(height, width, channels)``. A map with no value range (every
        element identical) is returned as all zeros.
    """
    result = torch.from_numpy(result_np)

    # F.interpolate in bilinear mode needs a batched 4-D tensor.
    if result.dim() == 3:
        result = result.unsqueeze(0)

    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)

    ma = torch.max(result)
    mi = torch.min(result)
    span = ma - mi
    if span > 0:
        result = (result - mi) / span
    else:
        # A constant map would divide by zero and yield NaNs, whose cast to
        # uint8 is undefined; emit a well-defined all-zero map instead.
        result = torch.zeros_like(result)

    im_array = (result * 255).permute(1, 2, 0).numpy().astype(np.uint8)

    # Drop only the channel axis. A bare np.squeeze would also remove a
    # height or width of 1 and hand back an array of the wrong rank.
    if im_array.shape[-1] == 1:
        im_array = im_array[..., 0]
    return im_array
|
|
|
def preprocess_image(image):
    """Build the batched model input array from a PIL image.

    Args:
        image: PIL image to process. It is not modified here.

    Returns:
        Tuple ``(input_array, image, original_size)`` where ``input_array``
        is the NumPy form of the image after resizing to 1024x1024,
        tensor conversion and the addition of a leading batch axis,
        ``image`` is the untouched input object, and ``original_size`` is
        ``image.size`` (PIL order: width, height).
    """
    original_size = image.size

    # RGBA / LA / P / L inputs would otherwise reach the model with a channel
    # count other than 3. Only the model input is converted; the caller still
    # gets the original image back.
    # NOTE(review): assumes the model takes 3-channel RGB input - confirm
    # against the ONNX model's input spec.
    model_input = image if image.mode == "RGB" else image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(model_input).unsqueeze(0)

    return img_tensor.numpy(), image, original_size
|
|
|
def run_inference(image):
    """Predict a mask for ``image`` and attach it as the alpha channel.

    Args:
        image: PIL image. It is modified in place by ``putalpha``.

    Returns:
        The same image object that was passed in, now carrying the
        predicted mask as its alpha channel.
    """
    input_data, original_image, (w, h) = preprocess_image(image)

    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_data})

    # PIL sizes are (width, height) but the interpolation inside
    # postprocess_image takes (height, width). Passing [w, h] built the mask
    # with swapped dimensions and forced a second, lossy resize afterwards.
    alpha = postprocess_image(outputs[0], im_size=[h, w])

    # The array is already (h, w), so the resulting PIL mask is (w, h) and
    # matches the original image without further resampling.
    mask = Image.fromarray(alpha)

    original_image.putalpha(mask)
    return original_image
|
|
|
|
|
# Script driver: read the input image, run background removal, write the result.
image_path = "image.png"
output_path = "output.png"

# Open through a context manager so the underlying file handle is released
# even when inference or saving raises.
with Image.open(image_path) as image:
    result_image = run_inference(image)
    result_image.save(output_path)
|
|