# -*- coding: utf-8 -*-
"""deploy_1
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/15bRa4lN0gamY1gSoZhpUGDp61rmTJ0Eg
# Installing Modules
"""
!pip install mediapipe
!pip install --upgrade diffusers[torch]
!pip install transformers
!pip install accelerate
!pip install git+https://github.com/huggingface/diffusers
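# Note: the class below expects the MediaPipe selfie multiclass segmenter at
# "/content/selfie_multiclass_256x256.tflite", but the script never downloads it.
# A minimal sketch of fetching it, assuming MediaPipe's published model URL
# (verify the URL before relying on it):
!wget -q -O /content/selfie_multiclass_256x256.tflite https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite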
"""# Importing Modules"""
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from google.colab.patches import cv2_imshow
import math
import numpy as np
from PIL import Image
from cv2 import kmeans, TERM_CRITERIA_MAX_ITER, TERM_CRITERIA_EPS, KMEANS_RANDOM_CENTERS
from numpy import float32
from matplotlib.pyplot import scatter, show
import matplotlib.pyplot as plt
import requests
from transformers import pipeline
import torch
import PIL
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDPMScheduler
from diffusers.utils import load_image
"""# Stable Diffusion and ControlNet Pipeline"""
# Stable Diffusion Controlnet Pipeline Class
class StableDiffusionControlnetPipeline:
    def __init__(self):
        # Paths / model identifiers
        self.SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH = "/content/selfie_multiclass_256x256.tflite"
        self.CONTROLNET_PATH = "lllyasviel/control_v11p_sd15_inpaint"
        self.MODEL_PATH = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        self.device = "cuda"

        # Classifier used to detect the dominant hair colour of the input image
        self.hair_color_pipeline = pipeline("image-classification", model="enzostvs/hair-color")

        # ControlNet inpaint model plus the ControlNet-aware inpainting pipeline
        # (the plain StableDiffusionInpaintPipeline does not accept control_image)
        self.controlnet = ControlNetModel.from_pretrained(
            self.CONTROLNET_PATH, torch_dtype=torch.float16
        ).to(self.device)
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            self.MODEL_PATH,
            controlnet=self.controlnet,
            safety_checker=None,
            requires_safety_checker=False,
            torch_dtype=torch.float16
        ).to(self.device)
        pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
        self.pipe = pipe
    def get_hair_dominant_color(self, image_path):
        # Classify the hair colour; if the top prediction is "completely bald",
        # fall back to the second most likely label.
        hair_img = Image.open(image_path).convert('RGB')
        results = self.hair_color_pipeline.predict(hair_img)
        first_hair_color = results[0]["label"]
        second_hair_color = results[1]["label"]
        if first_hair_color != "completely bald":
            return first_hair_color
        return second_hair_color
    def make_inpaint_condition(self, image, image_mask):
        # Build the ControlNet inpaint conditioning image: masked pixels are set to -1
        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
        assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
        image[image_mask > 0.5] = -1.0  # set as masked pixel
        image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        return image
    def roundUp(self, value, multiple):
        # Round value up to the nearest multiple (Stable Diffusion expects dimensions divisible by 8);
        # values that are already exact multiples are returned unchanged
        if value % multiple == 0:
            return value
        return value + multiple - (value % multiple)
    def stable_diffusion_controlnet(self, image_path):
        # Build the hair-root mask, detect the hair colour, then inpaint the roots
        HAIR_ROOT_MASK_PATH = self.create_hair_root_mask(image_path, self.SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH)
        HAIR_COLOR = self.get_hair_dominant_color(image_path)
        PROMPT = f"({HAIR_COLOR} root:1.2), raw photo, high detail"
        NEGATIVE_PROMPT = "black hair root"
        init_image = load_image(image_path)
        mask_image = load_image(HAIR_ROOT_MASK_PATH)
        height = self.roundUp(init_image.height, 8)
        width = self.roundUp(init_image.width, 8)
        generator = torch.Generator(device=self.device).manual_seed(1)
        control_image = self.make_inpaint_condition(init_image, mask_image)
        new_image = self.pipe(
            prompt=PROMPT,
            image=init_image,
            mask_image=mask_image,
            num_inference_steps=40,
            generator=generator,
            control_image=control_image,
            negative_prompt=NEGATIVE_PROMPT,
            strength=1,
            height=height,
            width=width,
            padding_mask_crop=40,
            guidance_scale=3.5
        ).images
        hair_root_edited_img = new_image[0]
        hair_root_edited_img.save("new_img_modified.jpg")
        return hair_root_edited_img
    def view_result(self, init_image, touched_up_image):
        # Show the original and touched-up images side by side
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].imshow(init_image)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        axes[1].imshow(touched_up_image)
        axes[1].set_title('Hair Root Touched-up')
        axes[1].axis('off')
        plt.show()
    def resize_and_show(self, image, INPUT_HEIGHT=512, INPUT_WIDTH=512):
        # Resize for display while preserving the aspect ratio
        h, w = image.shape[:2]
        if h < w:
            img = cv2.resize(image, (INPUT_WIDTH, math.floor(h / (w / INPUT_WIDTH))))
        else:
            img = cv2.resize(image, (math.floor(w / (h / INPUT_HEIGHT)), INPUT_HEIGHT))
        cv2_imshow(img)
    def create_hair_root_mask(self, image_path, SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH):
        BG_COLOR = (0, 0, 0)  # Background RGB Color
        MASK_COLOR = (255, 255, 255)  # Mask RGB Color
        HAIR_CLASS_INDEX = 1  # Index of the hair class in the multiclass segmenter output
        N_CLUSTERS = 3
        img = cv2.imread(image_path)

        # Segment the selfie with MediaPipe's multiclass segmenter and keep only the hair class
        base_options = python.BaseOptions(model_asset_path=SELFIE_MULTICLASS_SEGMENTER_MODEL_PATH)
        options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
        with vision.ImageSegmenter.create_from_options(options) as segmenter:
            image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
            segmentation_result = segmenter.segment(image)
            category_mask = segmentation_result.category_mask
            image_data = image.numpy_view()
            fg_image = np.zeros(image_data.shape, dtype=np.uint8)
            fg_image[:] = MASK_COLOR
            bg_image = np.zeros(image_data.shape, dtype=np.uint8)
            bg_image[:] = BG_COLOR
            condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) == HAIR_CLASS_INDEX
            output_image = np.where(condition, fg_image, bg_image)
            cv2.imwrite("hair_mask.png", output_image)
        # Keep only the hair pixels from the original image; everything else becomes gray
        hair_mask_cropped = cv2.bitwise_and(img, output_image)
        coords = np.where(output_image != [255, 255, 255])
        background = np.full(img.shape, 128, dtype=np.uint8)  # gray background color
        hair_mask_cropped[coords[0], coords[1], coords[2]] = background[coords[0], coords[1], coords[2]]
        rgb_img_hair_mask_cropped = cv2.cvtColor(hair_mask_cropped, cv2.COLOR_BGR2RGB)
        pillow_img = Image.fromarray(rgb_img_hair_mask_cropped)
        pillow_img.save("hair_mask_cropped.jpg")

        # Cluster the hair-region colours with k-means; the smallest cluster is assumed
        # to correspond to the regrown (differently coloured) roots
        img_data = rgb_img_hair_mask_cropped.reshape(-1, 3)
        criteria = (TERM_CRITERIA_MAX_ITER + TERM_CRITERIA_EPS, 100, 0.2)
        compactness, labels, centers = kmeans(data=img_data.astype(float32), K=N_CLUSTERS, bestLabels=None,
                                              criteria=criteria, attempts=10, flags=KMEANS_RANDOM_CENTERS)
        labels = labels.flatten()
        number_labels = np.bincount(labels)
        minimum_cluster_class = number_labels.argmin()
        # Build a binary mask: white for the smallest cluster (the roots), black elsewhere
        masked_image = np.copy(rgb_img_hair_mask_cropped)
        masked_image = masked_image.reshape((-1, 3))
        for i in range(0, len(number_labels)):
            masked_image[labels == i] = [0, 0, 0]
        masked_image[labels == minimum_cluster_class] = [255, 255, 255]
        masked_image = masked_image.reshape(rgb_img_hair_mask_cropped.shape)
        cv2.imwrite("hair_root_mask.jpg", masked_image)

        # Fill the contours of the root mask so it becomes a solid region
        hair_root_mask_img = cv2.imread('hair_root_mask.jpg')
        gray = cv2.cvtColor(hair_root_mask_img, cv2.COLOR_BGR2GRAY)
        ret, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, hierarchy = cv2.findContours(binary, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE)
        image_copy = hair_root_mask_img.copy()
        image_copy = cv2.drawContours(image_copy, contours, -1, (255, 255, 255), thickness=3, lineType=cv2.LINE_4)
        cv2.fillPoly(image_copy, pts=contours, color=(255, 255, 255))

        # Black out the lower part of the mask so only the top (root) region is inpainted
        (h, w) = image_copy.shape[:2]
        cut_pixel = int((w // 2) * 0.25)
        chin_point = ((w // 2) - cut_pixel, (h // 2) - cut_pixel)
        image_copy[chin_point[0]:, :] = [0, 0, 0]
        cv2.imwrite("hair_root_mask_mdf.png", image_copy)
        HAIR_ROOT_MASK_PATH = "/content/hair_root_mask_mdf.png"
        return HAIR_ROOT_MASK_PATH
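# A minimal usage sketch (kept commented out so the script still runs top to bottom):
# run the class directly on a local image without the Gradio UI defined below.
# "/content/sample_face.jpg" is a hypothetical placeholder path.
# sb_pipeline = StableDiffusionControlnetPipeline()
# touched_up = sb_pipeline.stable_diffusion_controlnet("/content/sample_face.jpg")
# sb_pipeline.view_result(load_image("/content/sample_face.jpg"), touched_up)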
"""# Installing Gradio"""
!pip install gradio --upgrade
"""## Calling the StableDiffusionControlnetPipeline for Gradio Interface
"""
import numpy as np
import gradio as gr
# Assuming StableDiffusionControlnetPipeline class is already defined
# Instantiate the pipeline once so the models are not reloaded on every request
SB_ControlNet_pipeline = StableDiffusionControlnetPipeline()

# Define the function for Gradio
def process_image(input_img):
    # Save the uploaded PIL image to a temporary file and run the touch-up pipeline on it
    temp_image_path = "/tmp/uploaded_image.jpg"
    input_img.save(temp_image_path)
    output_img = SB_ControlNet_pipeline.stable_diffusion_controlnet(temp_image_path)
    return output_img

# Create a Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs="image",
    title="Hair Root Touch Up using AI!",
    description="Upload an image to edit hair roots using Stable Diffusion ControlNet :)"
)

# Launch the Gradio interface
iface.launch(debug=True)
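# Note: if no public URL appears when running in Colab, iface.launch(share=True, debug=True)
# can be used instead; share is a standard Gradio launch parameter.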