# mask_inpaint/controlnetSD.py
# (Hugging Face Spaces page residue preserved as a comment:
#  author: spdraptor, commit message: "test", commit hash: 02af9e3)
from utils_inpaint import resize_image_dimensions, make_inpaint_condition
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
import spaces
import time
from PIL import Image
import numpy as np
# Target inference device; assumes a CUDA GPU is present (the @spaces.GPU
# decorator below requests one from the Spaces runtime).
device = torch.device('cuda')
# Cache the pipeline at module level: the ControlNet + Stable Diffusion weights
# are several GB, so reloading them from disk on every request dominates latency.
_PIPELINE = None


def _get_inpaint_pipeline():
    """Build the ControlNet inpainting pipeline once and reuse it thereafter."""
    global _PIPELINE
    if _PIPELINE is None:
        # load ControlNet
        controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
        # pass ControlNet to the pipeline
        _PIPELINE = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "fluently/Fluently-v4-inpainting", controlnet=controlnet
        )
    return _PIPELINE


@spaces.GPU(duration=50)
def mask_based_updating2(init_image_file, mask_image_file, prompt, strength=0.9, guidance_scale=9, num_inference_steps=100):
    """Inpaint the masked region of an image using a ControlNet-guided SD pipeline.

    Args:
        init_image_file: Source image as a numpy array (converted via
            ``Image.fromarray``; presumably H x W x C uint8 — TODO confirm caller).
        mask_image_file: Mask as a numpy array; white marks the area to repaint,
            black the area to keep as is.
        prompt: Text prompt describing the desired content of the masked area.
        strength: Denoising strength forwarded to the pipeline.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of diffusion steps.

    Returns:
        PIL.Image.Image: The inpainted image, resized back to the original
        input resolution.
    """
    start_time = time.time()
    pipeline = _get_inpaint_pipeline()
    pipeline.to(device)

    init_image = Image.fromarray(init_image_file).convert("RGB")
    mask_image = Image.fromarray(mask_image_file).convert("1")

    # Remember the caller's resolution so the result can be scaled back at the end.
    width, height = init_image.size
    width_new, height_new = resize_image_dimensions(original_resolution_wh=init_image.size)
    init_image = init_image.resize((width_new, height_new), Image.LANCZOS)
    # NEAREST keeps the mask binary; interpolating filters would create grey pixels.
    mask_image = mask_image.resize((width_new, height_new), Image.NEAREST)

    # image and mask_image are PIL images.
    # The mask structure is white for inpainting and black for keeping as is.
    control_image = make_inpaint_condition(init_image, mask_image)
    print("para: ", strength, guidance_scale, num_inference_steps)
    negative_prompt = "ugly, deformed, nsfw, disfigured, worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting, bad anatomy, faint, unrealistic, Cartoon, drawing"
    image = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    ).images[0]
    image = image.resize((width, height), Image.LANCZOS)

    print(f'Time taken by inpainting model: {time.time() - start_time}')
    # Release cached CUDA allocations so the Spaces GPU slot is not held full.
    torch.cuda.empty_cache()
    return image