import logging
import os
import time
import cv2
from diffusers import StableDiffusionPipeline
import gradio as gr
# import mediapipe as mp
import numpy as np
import PIL.Image
import torch.cuda
from transformers import pipeline
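
# Use hf_transfer for faster model downloads from the Hugging Face Hub
# (only takes effect if the optional hf_transfer package is installed)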
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    force=True)
LOG = logging.getLogger(__name__)
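
# SegFormer (b0) fine-tuned on the ADE20K dataset; its "person" class is
# used to separate the selfie subject from the original background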
LOG.info("Loading image segmentation model")
seg_kwargs = {
"task": "image-segmentation",
"model": "nvidia/segformer-b0-finetuned-ade-512-512"
}
img_segmentation_model = pipeline(**seg_kwargs)
# mp_selfie_segmentation = mp.solutions.selfie_segmentation
# img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
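
# Stable Diffusion v1.5 generates the replacement background from the text
# prompt; without a GPU it falls back to CPU and is considerably slower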
LOG.info("Loading diffusion model")
diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
if torch.cuda.is_available():
LOG.info("Moving diffusion model to GPU")
diffusion.to('cuda')
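
# Pipeline stages: preprocess the selfie, extract a person mask, generate
# a background from the prompt, then blend the two together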
def image_preprocess(image: PIL.Image.Image):
    LOG.info("Preprocessing image %s", image)
    start = time.time()
    # image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    image = resize_image(image)
    # image = np.array(image)
    # # Convert RGB to BGR
    # image = image[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
    return image
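
# Stable Diffusion requires dimensions that are multiples of 8, so the image
# is scaled to fit within 512x512 and both sides are rounded down accordingly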
def resize_image(image: PIL.Image.Image):
    width, height = image.size
    ratio = max(width / 512, height / 512)
    width = int(width / ratio) // 8 * 8
    height = int(height / ratio) // 8 * 8
    image = image.resize((width, height))
    return image
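
# Keep the highest-scoring "person" segment (if any), then threshold,
# dilate and feather its mask so the selfie blends smoothly at the edges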
def extract_selfie_mask(threshold, image):
    LOG.info("Extracting selfie mask")
    start = time.time()
    segments = img_segmentation_model(image)
    kept = None
    for s in segments:
        if s['score'] is None:
            s['score'] = 1
        if s['label'] == 'person' and s['score'] > 0.99:
            if not kept or kept['score'] < s['score']:
                kept = s
    if not kept:
        LOG.info("No person found in the photo, skipping")
        # Single-channel float mask, as cv2.blendLinear expects
        mask = np.zeros((image.size[1], image.size[0]), dtype='float32')
    else:
        # The pipeline returns the mask as a PIL image with values in 0-255;
        # normalize to [0, 1] so the slider threshold applies as intended
        mask = np.array(kept['mask'], dtype='float32') / 255
        cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
        # Grow the mask slightly, then feather its edge for a smooth blend
        cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
        cv2.blur(mask, (10, 10), dst=mask)
    elapsed = time.time() - start
    LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
    return mask
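
# Height and width come from the preprocessed selfie, so the generated
# background matches it exactly; NSFW outputs are replaced with a black image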
def generate_background(prompt, num_inference_steps, height, width):
    LOG.info("Generating background")
    start = time.time()
    background = diffusion(
        prompt=prompt,
        num_inference_steps=int(num_inference_steps),
        height=height,
        width=width
    )
    nsfw = background.nsfw_content_detected[0]
    background = background.images[0]
    if nsfw:
        LOG.info('NSFW detected, skipping')
        background = np.zeros((height, width, 3), dtype='uint8')
    else:
        background = np.array(background)
        # Convert RGB to BGR
        background = background[:, :, ::-1].copy()
    elapsed = time.time() - start
    LOG.info("Background generated, elapsed %.2f seconds", elapsed)
    return background
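
# cv2.blendLinear computes selfie * mask + background * (1 - mask) per pixel,
# using the feathered single-channel float mask as the blend weights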
def merge_selfie_and_background(selfie, background, mask):
    LOG.info("Merging extracted selfie and generated background")
    selfie = np.array(selfie)
    # Convert RGB to BGR
    selfie = selfie[:, :, ::-1].copy()
    cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
    selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
    selfie = PIL.Image.fromarray(selfie)
    return selfie
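
# Gradio callback: runs the three stages and re-raises any error so it
# surfaces in the UI instead of failing silently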
def demo(threshold, image, prompt, num_inference_steps):
    LOG.info("Processing image")
    try:
        image = image_preprocess(image)
        mask = extract_selfie_mask(threshold, image)
        # PIL's size is (width, height): pass height first, as the
        # diffusion pipeline expects
        background = generate_background(prompt, num_inference_steps,
                                         image.size[1], image.size[0])
        output = merge_selfie_and_background(image, background, mask)
    except Exception as e:
        LOG.error("Some unexpected error occurred")
        LOG.exception(e)
        raise
    return output

iface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(minimum=0.1, maximum=1, step=0.05, label="Selfie segmentation threshold",
                  value=0.8),
        gr.Image(type='pil', label="Upload your selfie"),
        gr.Text(value="a photo of the Eiffel tower on the right side",
                label="Background description"),
        gr.Slider(minimum=5, maximum=100, step=5, label="Diffusion inference steps",
                  value=50)
    ],
    outputs=[
        gr.Image(label="Invent yourself a life :)")
    ])
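
# Launch the app; use the commented variant below to bind a specific
# host and port instead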
# iface.launch(server_name="0.0.0.0", server_port=6443)
iface.launch()