# gpu-utils / app.py
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps
from sam2.sam2_image_predictor import SAM2ImagePredictor
import numpy as np
import json
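# Allow TF32 matmuls for faster float32 inference on Ampere+ GPUs.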
torch.set_float32_matmul_precision("high")
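# BiRefNet: image segmentation model used here for background removal.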
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
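# FLUX.1 Fill pipeline, shared by the outpainting and inpainting endpoints.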
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")
def prepare_image_and_mask(
image,
padding_top=0,
padding_bottom=0,
padding_left=0,
padding_right=0,
):
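    """Pad the image with a white border and build the matching mask
    (black = keep the original pixels, white = region for FLUX to fill)."""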
image = load_img(image).convert("RGB")
# expand image (left,top,right,bottom)
background = ImageOps.expand(
image,
border=(padding_left, padding_top, padding_right, padding_bottom),
fill="white",
)
mask = Image.new("RGB", image.size, "black")
mask = ImageOps.expand(
mask,
border=(padding_left, padding_top, padding_right, padding_bottom),
fill="white",
)
return background, mask
def outpaint(
image,
padding_top=0,
padding_bottom=0,
padding_left=0,
padding_right=0,
prompt="",
    num_inference_steps=50,
    guidance_scale=28,
):
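    """Extend the image past its borders: pad it, then let FLUX.1 Fill
    synthesize the padded region, guided by the prompt."""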
background, mask = prepare_image_and_mask(
image, padding_top, padding_bottom, padding_left, padding_right
)
result = pipe(
prompt=prompt,
height=background.height,
width=background.width,
image=background,
mask_image=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
result = result.convert("RGBA")
return result
def inpaint(
image,
mask,
prompt="",
    num_inference_steps=50,
    guidance_scale=28,
):
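    """Fill the white regions of the mask with FLUX.1 Fill, guided by the prompt."""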
background = image.convert("RGB")
mask = mask.convert("L")
result = pipe(
prompt=prompt,
height=background.height,
width=background.width,
image=background,
mask_image=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
result = result.convert("RGBA")
return result
def rmbg(image=None, url=None):
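    """Remove the background: predict an alpha matte with BiRefNet and
    attach it to the image as an alpha channel."""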
if image is None:
image = url
image = load_img(image).convert("RGB")
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to("cuda")
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
def mask_generation(image=None, d=None):
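    """Segment the image with SAM 2.1 (hiera-large) from point prompts passed
    as a JSON string, e.g. '{"input_points": [[x, y]], "input_labels": [1]}'."""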
    d = json.loads(d)  # parse the JSON point-prompt string instead of unsafe eval
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
predictor.set_image(image)
input_point = np.array(d["input_points"])
input_label = np.array(d["input_labels"])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
out = []
for i in range(len(masks)):
        # Convert the (H, W) mask to an 8-bit grayscale image.
        m = Image.fromarray((masks[i] * 255).astype(np.uint8)).convert("L")
        # Composite the input with its mask for the gallery preview.
        comp = Image.composite(image, m, m)
        out.append((comp, f"image {i}"))
return out
@spaces.GPU
def main(*args):
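    """Single GPU entry point: the first argument selects the tool to run,
    so all four endpoints share one @spaces.GPU allocation."""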
api_num = args[0]
args = args[1:]
if api_num == 1:
return rmbg(*args)
elif api_num == 2:
return outpaint(*args)
elif api_num == 3:
return inpaint(*args)
elif api_num == 4:
return mask_generation(*args)
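# Each tab pins its dispatch id in a non-interactive gr.Number so that `main`
# knows which tool to run.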
rmbg_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(1, interactive=False),
"image",
gr.Text("", label="url"),
],
outputs=["image"],
api_name="rmbg",
examples=[[1, "./assets/Inpainting mask.png", ""]],
cache_examples=False,
description="pass an image or a url of an image",
)
outpaint_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(2, interactive=False),
gr.Image(label="image", type="pil"),
gr.Number(label="padding top"),
gr.Number(label="padding bottom"),
gr.Number(label="padding left"),
gr.Number(label="padding right"),
gr.Text(label="prompt"),
gr.Number(value=50, label="num_inference_steps"),
gr.Number(value=28, label="guidance_scale"),
],
outputs=["image"],
api_name="outpainting",
examples=[[2, "./assets/rocket.png", 100, 0, 0, 0, "", 50, 28]],
cache_examples=False,
)
inpaint_tab = gr.Interface(
fn=main,
inputs=[
gr.Number(3, interactive=False),
gr.Image(label="image", type="pil"),
gr.Image(label="mask", type="pil"),
gr.Text(label="prompt"),
gr.Number(value=50, label="num_inference_steps"),
gr.Number(value=28, label="guidance_scale"),
],
outputs=["image"],
api_name="inpaint",
    examples=[[3, "./assets/rocket.png", "./assets/Inpainting mask.png", "", 50, 28]],
cache_examples=False,
description="it is recommended that you use https://github.com/la-voliere/react-mask-editor when creating an image mask in JS and then inverse it before sending it to this space",
)
sam2_tab = gr.Interface(
main,
inputs=[
gr.Number(4, interactive=False),
gr.Image(type="pil"),
        gr.Text(label="point prompts (JSON)"),
],
outputs=gr.Gallery(),
examples=[
[
4,
"./assets/truck.jpg",
'{"input_points": [[500, 375], [1125, 625]], "input_labels": [1, 0]}',
]
],
api_name="sam2",
cache_examples=False,
)
demo = gr.TabbedInterface(
[rmbg_tab, outpaint_tab, inpaint_tab, sam2_tab],
["remove background", "outpainting", "inpainting", "sam2"],
title="Utilities that require GPU",
)
demo.launch()
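# A minimal client-side sketch for calling these endpoints remotely, assuming
# the Space is deployed under the id "not-lain/gpu-utils" (hypothetical) and a
# recent gradio_client that provides handle_file:
#
#     from gradio_client import Client, handle_file
#
#     client = Client("not-lain/gpu-utils")  # hypothetical Space id
#     result = client.predict(
#         1,                                   # api selector: 1 = rmbg
#         handle_file("./assets/rocket.png"),  # image
#         "",                                  # url (unused when an image is given)
#         api_name="/rmbg",
#     )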