# BEN2 / onnx_run.py
# Last commit d66d54e ("Update onnx_run.py", verified) by MaxwellMeyer.
import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
# Single ONNX Runtime session, built once at import time and reused by
# run_inference. Execution providers are listed in order of preference:
# CUDA first, with CPU as the fallback. The model path is relative to the
# current working directory.
session = onnxruntime.InferenceSession(
    "./BEN2_Base.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
def postprocess_image(result_np: np.ndarray, im_size: list) -> np.ndarray:
    """Turn a raw model output map into an 8-bit mask.

    The map is bilinearly resized to ``im_size`` and then min-max normalised
    so its values span the full 0..255 range.

    Args:
        result_np: Raw model output for a single image, of rank 2 (H, W),
            3 (C, H, W) or 4 (1, C, H, W). Lower ranks get leading singleton
            axes so ``F.interpolate`` receives the 4-D (N, C, H, W) layout
            that bilinear mode requires.
        im_size: Target spatial size as ``(height, width)``. This is the
            order ``F.interpolate`` uses, not PIL's ``(width, height)``.
            A list or tuple of two ints.

    Returns:
        A ``uint8`` array of shape (height, width) when the map has a single
        channel, otherwise (height, width, C). A constant input map yields
        all zeros.
    """
    result = torch.from_numpy(result_np)
    # Bilinear interpolation only accepts 4-D (N, C, H, W) input.
    while result.dim() < 4:
        result = result.unsqueeze(0)
    result = F.interpolate(result, size=tuple(im_size), mode='bilinear')
    result = torch.squeeze(result, 0)  # -> (C, H, W)

    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # A constant map has no range to stretch. Dividing by zero would
        # produce NaNs, whose uint8 cast is undefined, so emit zeros instead.
        result = torch.zeros_like(result)

    im_array = (result * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    # Drop only the channel axis. np.squeeze would also collapse a height or
    # width of 1 and return a 1-D array.
    if im_array.shape[-1] == 1:
        im_array = im_array[..., 0]
    return im_array
def preprocess_image(image):
    """Prepare a PIL image as a one-image model input batch.

    Args:
        image: PIL image of any mode. It is not modified.

    Returns:
        A tuple ``(input_array, image, original_size)``:
        ``input_array`` is a float32 array of shape (1, 3, 1024, 1024) with
        values in [0, 1]; ``image`` is the argument, returned unchanged;
        ``original_size`` is ``image.size``, i.e. ``(width, height)``.
    """
    original_size = image.size
    # ToTensor emits one channel per image band. RGBA (common for PNGs),
    # grayscale (L) or palette (P) inputs would otherwise yield 4- or
    # 1-channel tensors. Normalise to 3-channel RGB; presumably the ONNX
    # graph expects exactly 3 channels (TODO confirm against the model).
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
    ])
    batch = transform(rgb).unsqueeze(0)  # add the leading batch axis
    return batch.numpy(), image, original_size
def run_inference(image):
    """Run the model on ``image`` and return it with the predicted alpha mask.

    Args:
        image: PIL image to process. It is not modified.

    Returns:
        A copy of ``image`` whose alpha channel is the model's mask, sized to
        match the image.
    """
    input_data, original_image, (w, h) = preprocess_image(image)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_data})
    # torch's interpolate takes (height, width), while PIL's size is
    # (width, height). Passing [h, w] produces the mask at the image's own
    # resolution in a single resampling step.
    alpha = postprocess_image(outputs[0], im_size=[h, w])
    mask = Image.fromarray(alpha)
    if mask.size != (w, h):
        # Defensive only: with the size order above this should not trigger.
        mask = mask.resize((w, h))
    # putalpha mutates in place; work on a copy so the caller's image is
    # left untouched.
    result = original_image.copy()
    result.putalpha(mask)
    return result
# Example usage: remove the background of image.png and save the result to
# output.png. The guard keeps a plain import of this module (e.g. to reuse
# run_inference) from running inference on a hard-coded path.
if __name__ == "__main__":
    image_path = "image.png"
    output_path = "output.png"
    # The context manager closes the file handle opened by Image.open.
    with Image.open(image_path) as image:
        result_image = run_inference(image)
        result_image.save(output_path)