File size: 3,110 Bytes
003d203
 
 
0e6c023
 
 
aa16383
dcfda89
0e6c023
 
 
 
 
 
 
 
 
 
 
 
 
 
 
003d203
aa16383
 
 
 
 
 
 
dcfda89
 
 
 
aa16383
dcfda89
 
 
 
 
 
aa16383
dcfda89
 
aa16383
 
 
 
 
dcfda89
 
 
 
 
aa16383
 
 
dcfda89
aa16383
 
 
 
 
dcfda89
 
 
 
aa16383
dcfda89
aa16383
 
 
 
dcfda89
aa16383
 
 
 
e7bef73
 
0e6c023
 
 
 
 
 
 
 
 
 
 
 
dcfda89
 
991bda2
 
 
 
aa16383
 
 
dcfda89
aa16383
 
 
dcfda89
 
 
 
 
 
 
 
 
 
 
aa16383
0e6c023
 
aa16383
 
94386db
0e6c023
 
003d203
0e6c023
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import spaces
import torch
from loadimg import load_img
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from diffusers import FluxFillPipeline
from PIL import Image, ImageOps

# Let float32 matmuls use faster reduced-precision internal math ("high").
torch.set_float32_matmul_precision("high")

# BiRefNet background-removal model, loaded once at import time and placed on GPU.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)
birefnet.to("cuda")

# BiRefNet preprocessing: fixed 1024x1024 resize, conversion to a tensor,
# then per-channel normalisation with the ImageNet mean/std.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# FLUX.1 Fill pipeline used by the outpainting tab, in bfloat16 on GPU.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")


def prepare_image_and_mask(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
):
    """Pad *image* with white borders and build the matching outpaint mask.

    Args:
        image: an image (or image reference) accepted by ``load_img``.
        padding_top, padding_bottom, padding_left, padding_right: number of
            pixels to add on each side. Float values (e.g. from sliders) are
            truncated to ``int``.

    Returns:
        ``(background, mask)``. ``background`` is the RGB image expanded with
        a white border. ``mask`` has exactly the same size: black over the
        original pixels (kept) and white over the added border (the region
        the fill pipeline should generate).
    """
    image = load_img(image).convert("RGB")
    # ImageOps.expand takes the border as (left, top, right, bottom).
    border = (
        int(padding_left),
        int(padding_top),
        int(padding_right),
        int(padding_bottom),
    )
    background = ImageOps.expand(image, border=border, fill="white")
    # Expand the mask with the *same* border so it matches `background`
    # pixel-for-pixel. A fixed border would ignore the requested padding and
    # yield a mask whose size differs from the padded image.
    mask = Image.new("RGB", image.size, "black")
    mask = ImageOps.expand(mask, border=border, fill="white")
    return background, mask


def inpaint(
    image,
    padding_top=0,
    padding_bottom=0,
    padding_left=0,
    padding_right=0,
    prompt="",
    progress=gr.Progress(track_tqdm=True),
    num_inference_steps=28,
    guidance_scale=30,
    seed=None,
):
    """Outpaint *image* by generating content in the requested padding.

    Args:
        image: input image passed to ``prepare_image_and_mask``.
        padding_top, padding_bottom, padding_left, padding_right: pixels to
            add on each side before generation.
        prompt: text prompt for the FLUX fill pipeline.
        progress: Gradio progress tracker (not used directly here).
        num_inference_steps: denoising steps forwarded to the pipeline.
        guidance_scale: guidance scale forwarded to the pipeline.
        seed: optional integer seed. When ``None`` (default) no generator is
            passed, so each call samples fresh noise.

    Returns:
        The generated image converted to RGBA.
    """
    background, mask = prepare_image_and_mask(
        image, padding_top, padding_bottom, padding_left, padding_right
    )

    # Only build a seeded generator on request. None keeps the pipeline's
    # default (unseeded) sampling.
    generator = (
        torch.Generator(device="cuda").manual_seed(int(seed))
        if seed is not None
        else None
    )

    result = pipe(
        prompt=prompt,
        height=background.height,
        width=background.width,
        image=background,
        mask_image=mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    # RGBA output, matching the RGBA image the background-removal path returns.
    return result.convert("RGBA")


def rmbg(image, url):
    """Remove the background of an image with BiRefNet.

    Args:
        image: uploaded image. Takes precedence when not None.
        url: fallback image reference used when *image* is None.

    Returns:
        The RGB input with an alpha channel added. The alpha is the sigmoid
        of BiRefNet's last output, resized back to the input size.

    Raises:
        gr.Error: if neither an image nor a non-blank URL was provided.
    """
    if image is None:
        image = url
    # `url` comes from a text box and may be empty; fail with a clear message
    # instead of an opaque loader error.
    if image is None or (isinstance(image, str) and not image.strip()):
        raise gr.Error("Please provide an image or an image URL.")
    image = load_img(image).convert("RGB")
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Bring the 1024x1024 prediction back to the original image size.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image


@spaces.GPU
def main(*args, **kwargs):
    """Single GPU-decorated entry point shared by both tabs.

    Routing is by positional-argument count. Exactly two positional
    arguments (image, url) go to ``rmbg``, and keyword arguments are not
    forwarded there. Any other call is handed to ``inpaint`` with all
    positional and keyword arguments.
    """
    # NOTE(review): Gradio only sees this *args/**kwargs signature, so the
    # gr.Progress default on `inpaint` is presumably never injected by
    # Gradio. Confirm if progress reporting is expected.
    routed_to_rmbg = len(args) == 2
    if not routed_to_rmbg:
        return inpaint(*args, **kwargs)
    return rmbg(*args)


# "remove background" tab: an image upload plus a URL text fallback.
rmbg_tab = gr.Interface(
    fn=main,
    inputs=["image", "text"],
    outputs=["image"],
    api_name="rmbg",
)

# "outpainting" tab: image, four padding sliders (created in this order),
# and a text prompt.
_padding_sliders = [
    gr.Slider(label=slider_label)
    for slider_label in (
        "padding top",
        "padding bottom",
        "padding left",
        "padding right",
    )
]
outpaint_tab = gr.Interface(
    fn=main,
    inputs=["image", *_padding_sliders, gr.Text(label="prompt")],
    outputs=["image"],
    api_name="outpainting",
)

# Combine both tabs into one app and start serving it.
demo = gr.TabbedInterface(
    [rmbg_tab, outpaint_tab],
    ["remove background", "outpainting"],
    title="Utilities that require GPU",
)


demo.launch()