File size: 2,368 Bytes
e76f4a8
 
 
 
 
 
 
 
 
 
02af9e3
e76f4a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from utils_inpaint import resize_image_dimensions, make_inpaint_condition
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
import spaces
import time
from PIL import Image
import numpy as np

device = torch.device('cuda')

# Negative prompt used when the caller does not pass one explicitly.
DEFAULT_NEGATIVE_PROMPT = (
    "ugly, deformed, nsfw, disfigured, worst quality, normal quality, low quality, "
    "low res, blurry, text, watermark, logo, banner, extra digits, cropped, "
    "jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, "
    "horror, geometry, mutation, disgusting, bad anatomy, faint, unrealistic, "
    "Cartoon, drawing"
)


def _load_inpaint_pipeline():
    """Build the SD1.5 ControlNet inpaint pipeline (with the inpaint ControlNet) on ``device``."""
    controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
    pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "fluently/Fluently-v4-inpainting", controlnet=controlnet
    )
    # pipeline.enable_model_cpu_offload() is an alternative to .to(device) when VRAM is tight.
    pipeline.to(device)
    return pipeline


@spaces.GPU(duration=50)
def mask_based_updating2(init_image_file, mask_image_file, prompt, strength=0.9, guidance_scale=9,
                         num_inference_steps=100, negative_prompt=None):
    """Repaint the masked region of an image with a ControlNet-guided inpainting pipeline.

    Args:
        init_image_file: array accepted by ``PIL.Image.fromarray`` holding the source image
            (presumably an HxWxC uint8 array from the UI -- TODO confirm against caller).
        mask_image_file: array accepted by ``PIL.Image.fromarray``; white marks the area to
            inpaint, black the area to keep as is. Pixels are thresholded at 128.
        prompt: text prompt describing the content to paint into the masked area.
        strength: forwarded to the pipeline's ``strength`` argument.
        guidance_scale: forwarded to the pipeline's ``guidance_scale`` argument.
        num_inference_steps: forwarded to the pipeline's ``num_inference_steps`` argument.
        negative_prompt: optional override for ``DEFAULT_NEGATIVE_PROMPT``.

    Returns:
        The first generated ``PIL.Image``, resized back to the input image's original size.
    """
    import gc

    start_time = time.time()
    if negative_prompt is None:
        negative_prompt = DEFAULT_NEGATIVE_PROMPT

    # Prepare the inputs before loading any model, so bad inputs fail fast and cheaply.
    init_image = Image.fromarray(init_image_file).convert("RGB")
    # Pillow's default RGB/L -> "1" conversion applies Floyd-Steinberg dithering, which
    # turns grey / anti-aliased mask pixels into speckle noise; threshold cleanly instead.
    mask_image = Image.fromarray(mask_image_file).convert("1", dither=Image.Dither.NONE)
    width, height = init_image.size
    # Working resolution is chosen by the project helper resize_image_dimensions.
    width_new, height_new = resize_image_dimensions(original_resolution_wh=init_image.size)
    init_image = init_image.resize((width_new, height_new), Image.LANCZOS)
    # NEAREST keeps the mask strictly binary after resizing.
    mask_image = mask_image.resize((width_new, height_new), Image.NEAREST)
    control_image = make_inpaint_condition(init_image, mask_image)
    print("para: ", strength, guidance_scale, num_inference_steps)

    pipeline = None
    try:
        # NOTE(review): both models are (re)loaded on every call, which likely dominates
        # latency; consider loading once at module scope if the spaces.GPU runtime allows it.
        pipeline = _load_inpaint_pipeline()
        image = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            mask_image=mask_image,
            control_image=control_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0]
    finally:
        # Drop the last reference to the weights *before* emptying the CUDA cache;
        # otherwise empty_cache() cannot release the memory they still occupy.
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()

    image = image.resize((width, height), Image.LANCZOS)
    print(f'Time taken by inpainting model: {time.time() - start_time}')
    return image