File size: 5,925 Bytes
45b10c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
%%writefile handler.py
from typing import Dict, List, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, AutoencoderKL, StableDiffusionXLControlNetPipeline, AutoPipelineForText2Image
import torch
from diffusers.utils import load_image
import numpy as np
import cv2
import controlnet_hinter
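# NOTE: AutoPipelineForText2Image was added to the diffusers import above; it is not referenced by the code shown here.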
# This handler is GPU-only: refuse to start when CUDA is unavailable.
if not torch.cuda.is_available():
    raise ValueError("need to run on GPU")
device = torch.device('cuda')
# Mixed-precision dtype: bfloat16 when the CUDA compute-capability major
# version is exactly 8, float16 on every other GPU.
if torch.cuda.get_device_capability()[0] == 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
# Maps each supported control type to the ControlNet checkpoint id
# ("model_id") and the controlnet_hinter function ("hinter") that turns an
# input image into the matching control image.
CONTROLNET_MAPPING = {
    control_type: {"model_id": model_id, "hinter": hinter}
    for control_type, model_id, hinter in (
        ("canny_edge", "lllyasviel/sd-controlnet-canny", controlnet_hinter.hint_canny),
        ("pose", "lllyasviel/sd-controlnet-openpose", controlnet_hinter.hint_openpose),
        ("depth", "lllyasviel/sd-controlnet-depth", controlnet_hinter.hint_depth),
        ("scribble", "lllyasviel/sd-controlnet-scribble", controlnet_hinter.hint_scribble),
        ("segmentation", "lllyasviel/sd-controlnet-seg", controlnet_hinter.hint_segmentation),
        ("normal", "lllyasviel/sd-controlnet-normal", controlnet_hinter.hint_normal),
        ("hed", "lllyasviel/sd-controlnet-hed", controlnet_hinter.hint_hed),
        ("hough", "lllyasviel/sd-controlnet-mlsd", controlnet_hinter.hint_hough),
    )
}
class EndpointHandler():
    """Request handler that runs a Stable Diffusion + ControlNet pipeline.

    The handler keeps one ControlNet and one Stable Diffusion pipeline loaded
    and swaps either of them when a request asks for a different
    ``controlnet_type`` or ``stablediffusionid``.

    Relies on the module-level ``device``, ``dtype`` and
    ``CONTROLNET_MAPPING`` definitions.
    """

    def __init__(self, path=""):
        """Load the default ControlNet and Stable Diffusion pipeline.

        :param path: Accepted for interface compatibility; not used here.
        """
        # define default controlnet id and load controlnet
        self.control_type = "normal"
        self.controlnet = self._load_controlnet(self.control_type)
        # Load StableDiffusionControlNetPipeline
        self.stable_diffusion_id = "stablediffusionapi/disney-pixar-cartoon"
        self.pipe = self._load_pipeline(self.stable_diffusion_id)
        # No seeded generator is passed to the pipeline (the former
        # fixed-seed generator was removed).

    def _load_controlnet(self, control_type):
        """Load the ControlNet registered for ``control_type`` onto ``device``."""
        model_id = CONTROLNET_MAPPING[control_type]["model_id"]
        return ControlNetModel.from_pretrained(model_id, torch_dtype=dtype).to(device)

    def _load_pipeline(self, stable_diffusion_id):
        """Build a pipeline for ``stable_diffusion_id`` around ``self.controlnet``."""
        return StableDiffusionControlNetPipeline.from_pretrained(
            stable_diffusion_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)

    def __call__(self, data: Any) -> Any:
        """Generate one image from a prompt and a base64-encoded input image.

        :param data: A dictionary with the required keys ``inputs`` (prompt)
            and ``image`` (base64-encoded image), and the optional keys
            ``controlnet_type``, ``stablediffusionid``, ``negative_prompt``,
            ``num_inference_steps``, ``guidance_scale``, ``height``,
            ``width`` and ``controlnet_conditioning_scale``. Consumed keys
            are popped from the dictionary.
        :return: The first generated image (``out.images[0]``), or a
            dictionary with an ``error`` message when the request is invalid.
        """
        prompt = data.pop("inputs", None)
        image = data.pop("image", None)
        controlnet_type = data.pop("controlnet_type", None)
        stablediffusion_id = data.pop("stablediffusionid", None)

        # Validate before any model is (re)loaded so that a bad request can
        # neither trigger an expensive download nor crash later in decoding.
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and base64 encoded image."}
        if controlnet_type is not None and controlnet_type not in CONTROLNET_MAPPING:
            supported = ", ".join(CONTROLNET_MAPPING)
            return {"error": f"Unknown controlnet_type '{controlnet_type}'. Supported types: {supported}."}

        # Check if a new controlnet is provided
        if controlnet_type is not None and controlnet_type != self.control_type:
            print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
            # Load first and record the new type only afterwards, so a failed
            # load leaves the handler state consistent.
            new_controlnet = self._load_controlnet(controlnet_type)
            self.controlnet = new_controlnet
            self.pipe.controlnet = new_controlnet
            self.control_type = controlnet_type

        # Check if a new Stable Diffusion model is requested
        if stablediffusion_id is not None and stablediffusion_id != self.stable_diffusion_id:
            # Same ordering as above: only remember the id once loading worked.
            self.pipe = self._load_pipeline(stablediffusion_id)
            self.stable_diffusion_id = stablediffusion_id

        # hyperparameters (each key is popped exactly once; popping
        # "negative_prompt" a second time would reset it to None)
        negative_prompt = data.pop("negative_prompt", None)
        num_inference_steps = data.pop("num_inference_steps", 150)
        guidance_scale = data.pop("guidance_scale", 5)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)

        # process image
        image = self.decode_base64_image(image)
        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)

        # run inference pipeline
        out = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            guess_mode=True,
        )

        # return the first generated PIL image
        return out.images[0]

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64 string and open the resulting bytes with PIL.

        :param image_string: Base64-encoded image data.
        :return: The image returned by ``Image.open``.
        """
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
|