Unique3D / gradio_app /custom_models /mvimg_prediction.py
Wuvin's picture
try
8bfc447
raw
history blame
1.91 kB
import sys
import torch
import gradio as gr
from PIL import Image
import numpy as np
from rembg import remove
from gradio_app.utils import change_rgba_bg, rgba_to_rgb
from gradio_app.custom_models.utils import load_pipeline
from scripts.all_typing import *
from scripts.utils import session, simple_preprocess
# Path to the YAML configuration handed to load_pipeline
# (presumably the image-to-multiview model config, judging by the file name).
training_config = "gradio_app/custom_models/image2mvimage.yaml"
# Path to the checkpoint handed to load_pipeline
# (named as a UNet state dict; exact contents not visible from this file).
checkpoint_path = "ckpt/img2mvimg/unet_state_dict.pth"
# Executed at import time: the trainer and pipeline are created once and then
# used as module-level globals by `predict`.
trainer, pipeline = load_pipeline(training_config, checkpoint_path)
def predict(img_list: List[Image.Image], guidance_scale=2., **kwargs):
    """Run the module-level pipeline on one or more images.

    Args:
        img_list: A list of PIL images. A single PIL image is also accepted
            and handled as a one-element list.
        guidance_scale: Forwarded unchanged to ``trainer.pipeline_forward``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``trainer.pipeline_forward``.

    Returns:
        A flat list gathering the ``.images`` of every forward call, in the
        order of the inputs.
    """
    pipeline.enable_model_cpu_offload()

    sources = [img_list] if isinstance(img_list, Image.Image) else img_list

    # Convert every RGBA input up front; images in any other mode are passed
    # through untouched.
    prepared = []
    for source in sources:
        if source.mode == 'RGBA':
            source = rgba_to_rgb(source)
        prepared.append(source)

    # One forward call per input image; the results are concatenated.
    outputs = []
    for frame in prepared:
        result = trainer.pipeline_forward(
            pipeline=pipeline,
            image=frame,
            guidance_scale=guidance_scale,
            **kwargs
        )
        outputs.extend(result.images)
    return outputs
def run_mvprediction(input_image: Image.Image, remove_bg=True, guidance_scale=1.5, seed=1145):
    """Preprocess a single image and run `predict` on it.

    Args:
        input_image: The PIL image to process.
        remove_bg: Whether to run ``rembg.remove`` on the input. Forced to
            ``True`` when the image is in RGB mode or its last channel is 255
            everywhere.
        guidance_scale: Forwarded to `predict`.
        seed: Seed for the CUDA random generator. ``None`` or a negative
            value means no generator is passed (unseeded run).

    Returns:
        A tuple ``(rgb_pils, single_image)``: the list returned by `predict`
        and the preprocessed image that was fed to it.
    """
    if input_image.mode == 'RGB' or np.array(input_image)[..., -1].mean() == 255.:
        # still do remove using rembg, since simple_preprocess requires RGBA image
        print("RGB image not RGBA! still remove bg!")
        remove_bg = True

    if remove_bg:
        input_image = remove(input_image, session=session)

    # make front_pil RGBA with white bg
    input_image = change_rgba_bg(input_image, "white")
    single_image = simple_preprocess(input_image)

    # Guard against seed=None before comparing: `None >= 0` raises TypeError.
    # A missing seed is treated like a negative one (no fixed generator).
    if seed is not None and seed >= 0:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
    else:
        generator = None

    rgb_pils = predict(
        single_image,
        generator=generator,
        guidance_scale=guidance_scale,
        width=256,
        height=256,
        num_inference_steps=30,
    )
    return rgb_pils, single_image