gpu-utils / app.py
not-lain's picture
update app
991bda2
raw
history blame
3.11 kB
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
# Permit faster reduced-precision float32 matmul kernels ("high") instead of
# strict full-precision float32 ("highest").
torch.set_float32_matmul_precision("high")

# Background-removal model used by `rmbg`. trust_remote_code=True lets
# transformers execute the model code shipped alongside the checkpoint
# (NOTE(review): this runs code fetched from the Hub at startup).
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# BiRefNet preprocessing: fixed 1024x1024 resize, PIL -> tensor, then
# per-channel normalization with the standard ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# FLUX.1 Fill pipeline used by `inpaint` for outpainting, loaded in bfloat16
# and placed on the GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad an image for outpainting and build the matching fill mask.

    Args:
        image: Anything accepted by ``load_img``; it is converted to RGB.
        padding_top: Pixels of new canvas to add above the image.
        padding_bottom: Pixels of new canvas to add below the image.
        padding_left: Pixels of new canvas to add left of the image.
        padding_right: Pixels of new canvas to add right of the image.
            All paddings are cast with ``int()`` because PIL requires integer
            sizes and UI sliders may deliver floats.

    Returns:
        ``(background, mask)``: two RGB images of identical size
        ``(w + left + right, h + top + bottom)``. ``background`` is the input
        surrounded by a white border. ``mask`` is black over the original
        pixels (preserved) and white over the added border (the region the
        fill pipeline should generate).
    """
    image = load_img(image).convert("RGB")
    # ImageOps.expand takes the border as (left, top, right, bottom).
    border = (
        int(padding_left),
        int(padding_top),
        int(padding_right),
        int(padding_bottom),
    )
    background = ImageOps.expand(image, border=border, fill="white")
    # Expand the mask with exactly the same border as the background so that
    # both images have the same size and the white area is precisely the
    # requested padding.
    mask = ImageOps.expand(
        Image.new("RGB", image.size, "black"),
        border=border,
        fill="white",
    )
    return background, mask
def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    progress=gr.Progress(track_tqdm=True),
    num_inference_steps=28,
    guidance_scale=30,
    seed=None,
):
    """Outpaint ``image`` by padding it and letting FLUX.1 Fill generate the border.

    Args:
        image: Input image; anything ``load_img`` accepts.
        padding_top: Pixels of new content to generate above the image.
        padding_bottom: Pixels of new content to generate below the image.
        padding_left: Pixels of new content to generate left of the image.
        padding_right: Pixels of new content to generate right of the image.
        prompt: Text prompt guiding the generated content.
        progress: Gradio progress tracker; ``track_tqdm=True`` is meant to
            mirror the pipeline's tqdm bar in the UI. NOTE(review): this
            function is invoked through ``main(*args, **kwargs)``, so Gradio
            may not detect this parameter; confirm the bar actually shows.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Optional integer seed for reproducible output. ``None`` (the
            default) leaves sampling to the pipeline's own random state.

    Returns:
        The generated PIL image converted to RGBA (the same mode ``rmbg``
        returns).
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )
    # Only build a generator when a seed is requested; generator=None keeps
    # the pipeline's default behavior.
    generator = (
        None
        if seed is None
        else torch.Generator(device="cuda").manual_seed(int(seed))
    )
    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return result.convert("RGBA")
def rmbg(image, url):
    """Cut the foreground out of an image using BiRefNet.

    Args:
        image: Uploaded image from the UI, or ``None`` when nothing was uploaded.
        url: Fallback source passed to ``load_img`` (presumably a URL or path),
            used only when ``image`` is ``None``.

    Returns:
        The input converted to RGB with the predicted mask attached as its
        alpha channel (an RGBA PIL image at the input's original size).
    """
    # An uploaded image takes precedence; otherwise fall back to the text box.
    source = url if image is None else image
    picture = load_img(source).convert("RGB")
    original_size = picture.size

    # Preprocess into a batch of one on the GPU.
    batch = transform_image(picture).unsqueeze(0).to("cuda")
    with torch.no_grad():
        # NOTE(review): the model returns a sequence of outputs; the last
        # element is assumed to be the final prediction map.
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    # Float [0, 1] map -> 8-bit grayscale PIL image, resized back to the input.
    alpha = transforms.ToPILImage()(probabilities[0].squeeze())
    picture.putalpha(alpha.resize(original_size))
    return picture
@spaces.GPU
def main(*args, **kwargs):
    """Single GPU-decorated entry point shared by both Gradio tabs.

    Dispatches on the number of positional arguments:
    - exactly two (``image, url`` from the "remove background" tab) go to
      ``rmbg``; keyword arguments are not forwarded on that path;
    - anything else is forwarded, with keywords, to ``inpaint``.

    NOTE(review): arity-based dispatch is fragile. Also, because this
    signature is ``(*args, **kwargs)``, Gradio cannot see ``inpaint``'s
    ``gr.Progress`` parameter. Consider separate decorated functions per tab.
    """
    if len(args) != 2:
        return inpaint(*args, **kwargs)
    return rmbg(*args)
# "remove background" tab: an image upload plus a fallback text box, which
# `main` receives as two positional arguments and routes to `rmbg`.
rmbg_tab = gr.Interface(
    api_name="rmbg",
    fn=main,
    inputs=["image", "text"],
    outputs=["image"],
)

# "outpainting" tab: image, four per-side padding sliders and a prompt, passed
# positionally in the order `inpaint` expects.
outpaint_tab = gr.Interface(
    api_name="outpainting",
    fn=main,
    outputs=["image"],
    inputs=[
        "image",
        gr.Slider(label="padding top"),
        gr.Slider(label="padding bottom"),
        gr.Slider(label="padding left"),
        gr.Slider(label="padding right"),
        gr.Text(label="prompt"),
    ],
)

# Combine both tools into one tabbed app and start the server.
demo = gr.TabbedInterface(
    interface_list=[rmbg_tab, outpaint_tab],
    tab_names=["remove background", "outpainting"],
    title="Utilities that require GPU",
)
demo.launch()