"""Gradio demo: virtual try-on of soccer jerseys built on the IDM-VTON pipeline."""

# NOTE: import order is kept exactly as in the original script because some of
# these packages (e.g. `spaces`, `torch`) may have import-time side effects.
import gradio as gr
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler, AutoencoderKL
from typing import List
import torch
import os
from transformers import AutoTokenizer
import spaces
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
from torchvision.transforms.functional import to_pil_image

# (width, height) fed to the diffusion pipeline.
MODEL_SIZE = (768, 1024)
# (width, height) fed to the OpenPose / human-parsing / DensePose preprocessors.
PREPROCESS_SIZE = (384, 512)
NEGATIVE_PROMPT = "monochrome, lowres, bad anatomy, worst quality, low quality"


def pil_to_binary_mask(pil_image, threshold=0):
    """Return a black/white "L" image: 255 where grayscale > threshold, else 0.

    Args:
        pil_image: A PIL image (anything ``np.array`` can turn into an image
            array is accepted, as before).
        threshold: Grayscale value a pixel must exceed to be part of the mask.

    Returns:
        A single-channel PIL image containing only the values 0 and 255.
    """
    grayscale = Image.fromarray(np.array(pil_image)).convert("L")
    binary_mask = np.array(grayscale) > threshold
    # Vectorized equivalent of the former per-pixel Python loop.
    return Image.fromarray(binary_mask.astype(np.uint8) * 255)


base_path = 'yisol/IDM-VTON'
example_path = os.path.join(os.path.dirname(__file__), 'example')

unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")

text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)

parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Inference only: freeze every network once.
for _frozen_module in (
    UNet_Encoder,
    image_encoder,
    vae,
    unet,
    text_encoder_one,
    text_encoder_two,
):
    _frozen_module.requires_grad_(False)

# PIL image -> tensor normalized from [0, 1] to [-1, 1].
tensor_transfrom = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder


def _center_crop_box(width, height):
    """Return an integer (left, top, right, bottom) centered 3:4 crop box."""
    target_width = int(min(width, height * (3 / 4)))
    target_height = int(min(height, width * (4 / 3)))
    left = (width - target_width) // 2
    top = (height - target_height) // 2
    return left, top, left + target_width, top + target_height


@spaces.GPU
def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
    """Run the try-on pipeline for one person image and one garment image.

    Args:
        dict: Value of the ``gr.ImageEditor`` component; its ``"background"``
            entry is the person photo and ``"layers"[0]`` the hand-drawn mask.
            (The name shadows the builtin but is kept for compatibility.)
        garm_img: PIL image of the garment.
        garment_des: Free-text description of the garment (may be empty).
        is_checked: If true, build the mask automatically from OpenPose
            keypoints and human parsing; otherwise use the drawn layer.
        is_checked_crop: If true, center-crop the person photo to 3:4 first
            and paste the result back into the original photo.
        denoise_steps: Number of inference steps passed to the pipeline.
        seed: Seed for the torch generator, or ``None`` for no fixed seed.

    Returns:
        Tuple ``(result_image, masked_preview)`` of PIL images.

    Raises:
        gr.Error: If the person image, the garment image, or (in manual-mask
            mode) the drawn mask layer is missing.
    """
    if dict is None or dict.get("background") is None:
        raise gr.Error("Please upload a photo of a person first.")
    if garm_img is None:
        raise gr.Error("Please upload or select a garment image first.")
    if denoise_steps is None:
        raise gr.Error("Please provide the number of denoising steps.")
    garment_des = garment_des or ""
    # gr.Number may deliver floats; the scheduler and generator need ints.
    num_steps = int(denoise_steps)

    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garm_img = garm_img.convert("RGB").resize(MODEL_SIZE)
    human_img_orig = dict["background"].convert("RGB")

    if is_checked_crop:
        # Integer box so the crop origin and the later paste position agree.
        left, top, right, bottom = _center_crop_box(*human_img_orig.size)
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        crop_size = cropped_img.size
        human_img = cropped_img.resize(MODEL_SIZE)
    else:
        human_img = human_img_orig.resize(MODEL_SIZE)

    if is_checked:
        small_human = human_img.resize(PREPROCESS_SIZE)
        keypoints = openpose_model(small_human)
        model_parse, _ = parsing_model(small_human)
        mask, _ = get_mask_location('hd', "upper_body", model_parse, keypoints)
        mask = mask.resize(MODEL_SIZE)
    else:
        layers = dict.get("layers")
        if not layers:
            raise gr.Error(
                "No mask was drawn. Draw a mask on the photo or enable the auto-generated mask."
            )
        mask = pil_to_binary_mask(layers[0].convert("RGB").resize(MODEL_SIZE))

    # Preview image: the person with the masked region replaced by mid-grey
    # (masked pixels become 0 in [-1, 1] space, i.e. 0.5 after rescaling).
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    human_img_arg = _apply_exif_orientation(human_img.resize(PREPROCESS_SIZE))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

    args = apply_net.create_argument_parser().parse_args(
        (
            'show',
            './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
            './ckpt/densepose/model_final_162be9.pkl',
            'dp_segm',
            '-v',
            '--opts',
            'MODEL.DEVICE',
            'cuda',
        )
    )
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # reverse channel order (input was BGR)
    pose_img = Image.fromarray(pose_img).resize(MODEL_SIZE)

    with torch.no_grad(), torch.cuda.amp.autocast():
        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                "model is wearing " + garment_des,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=NEGATIVE_PROMPT,
            )
            # The garment prompt is passed as single-element lists, as before.
            (
                prompt_embeds_c,
                _,
                _,
                _,
            ) = pipe.encode_prompt(
                ["a photo of " + garment_des],
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=[NEGATIVE_PROMPT],
            )

        pose_tensor = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
        garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
        generator = (
            torch.Generator(device).manual_seed(int(seed)) if seed is not None else None
        )
        images = pipe(
            prompt_embeds=prompt_embeds.to(device, torch.float16),
            negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
            num_inference_steps=num_steps,
            generator=generator,
            strength=1.0,
            pose_img=pose_tensor,
            text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
            cloth=garm_tensor,
            mask_image=mask,
            image=human_img,
            height=1024,
            width=768,
            ip_adapter_image=garm_img,  # already resized to MODEL_SIZE above
            guidance_scale=2.0,
        )[0]

    if is_checked_crop:
        # Scale the result back to the crop's size and paste it where it came from.
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (left, top))
        return human_img_orig, mask_gray
    return images[0], mask_gray


# Sorted so the example galleries have a stable order (os.listdir order is arbitrary).
garm_list = sorted(os.listdir(os.path.join(example_path, "cloth")))
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]

human_list = sorted(os.listdir(os.path.join(example_path, "human")))
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]

# ImageEditor examples need the editor's dict shape; only the background is set.
human_ex_list = [
    {'background': ex_human, 'layers': None, 'composite': None}
    for ex_human in human_list_path
]

# NOTE(review): the component nesting below was reconstructed from a
# whitespace-collapsed source; confirm the layout against the running app.
image_blocks = gr.Blocks().queue()
with image_blocks as demo:
    gr.Markdown("## Soccer Jersey Try-On ⚽👕")
    gr.Markdown("Experience the thrill of Euro and Copa America 2024 by virtually trying on the jerseys of the final four teams in each competition. Simply upload your photo and select the jersey of your choice to see how you look!")
    gr.Markdown("Credits to the [IDM-VTON](https://huggingface.co./spaces/yisol/IDM-VTON) project for the inspiration and source code.")
    with gr.Row():
        with gr.Column():
            imgs = gr.ImageEditor(sources='upload', type="pil", label='Human. Mask with pen or use auto-masking', interactive=True)
            with gr.Row():
                is_checked = gr.Checkbox(label="Yes", info="Use auto-generated mask (Takes 5 seconds)", value=True)
            with gr.Row():
                is_checked_crop = gr.Checkbox(label="Yes", info="Use auto-crop & resizing", value=False)
            gr.Examples(
                inputs=imgs,
                examples_per_page=10,
                examples=human_ex_list,
            )

        with gr.Column():
            garm_img = gr.Image(label="Garment", sources='upload', type="pil")
            with gr.Row(elem_id="prompt-container"):
                with gr.Row():
                    prompt = gr.Textbox(placeholder="Description of garment (Soccer jersey)", show_label=False, elem_id="prompt")
            gr.Examples(
                inputs=garm_img,
                examples_per_page=8,
                examples=garm_list_path,
            )

        with gr.Column():
            masked_img = gr.Image(label="Masked image output", elem_id="masked-img", show_share_button=False)

        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img", show_share_button=False)

    with gr.Column():
        try_button = gr.Button(value="Try-on")
        with gr.Accordion(label="Advanced Settings", open=False):
            with gr.Row():
                denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
                seed = gr.Number(label="Seed", minimum=-1, maximum=2147483647, step=1, value=42)

    try_button.click(
        fn=start_tryon,
        inputs=[imgs, garm_img, prompt, is_checked, is_checked_crop, denoise_steps, seed],
        outputs=[image_out, masked_img],
        api_name='tryon',
    )

image_blocks.launch()