import base64
import logging
import random
import uuid
from io import BytesIO

import numpy as np
import requests
import spaces
import torch
from flask import Flask, request, jsonify, send_file
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    AutoTokenizer,
)
from diffusers import DDPMScheduler, AutoencoderKL
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation

import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from utils_mask import get_mask_location

app = Flask(__name__)

# Base path for the pretrained IDM-VTON weights on the Hugging Face Hub
base_path = 'yisol/IDM-VTON'

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"

unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16).to(device)
tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16).to(device)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16).to(device)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16).to(device)
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16).to(device)

# Parsing and OpenPose take a GPU index and manage their own device placement
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Assemble the try-on pipeline
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to(device)
pipe.unet_encoder = UNet_Encoder

# Image transformation: PIL -> tensor normalized to [-1, 1]
tensor_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a black-and-white binary mask."""
    grayscale_image = pil_image.convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    mask = np.zeros(binary_mask.shape, dtype=np.uint8)
    mask[binary_mask] = 1
    return Image.fromarray((mask * 255).astype(np.uint8))


def get_image_from_url(url):
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        logging.error(f"Error fetching image from URL: {e}")
        raise


def decode_image_from_base64(base64_str):
    try:
        img_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(img_data))
    except Exception as e:
        logging.error(f"Error decoding image: {e}")
        raise
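# Usage sketch for these image helpers (encode_image_to_base64 is defined just
# below; the URL is hypothetical, for illustration only):
#
#   img = get_image_from_url("https://example.com/person.jpg")  # PIL.Image
#   b64 = encode_image_to_base64(img)                           # PNG -> base64 str
#   img2 = decode_image_from_base64(b64)                        # base64 str -> PIL.Image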
def encode_image_to_base64(img):
    try:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        logging.error(f"Error encoding image: {e}")
        raise


def save_image(img):
    unique_name = str(uuid.uuid4()) + ".webp"
    img.save(unique_name, format="WEBP", lossless=True)
    return unique_name


@spaces.GPU
def start_tryon(person_dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = person_dict["background"].convert("RGB")

    if is_checked_crop:
        # Center-crop the person image to a 3:4 aspect ratio before resizing
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) / 2
        top = (height - target_height) / 2
        right = (width + target_width) / 2
        bottom = (height + target_height) / 2
        cropped_img = human_img_orig.crop((left, top, right, bottom))
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    if is_checked:
        # Auto-mask: derive the garment region from pose keypoints + human parsing
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        # Manual mask: use the first editor layer supplied by the client
        mask = person_dict['layers'][0].convert("RGB").resize((768, 1024))
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose conditioning image
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
    args = apply_net.create_argument_parser().parse_args((
        'show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
        './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v',
        '--opts', 'MODEL.DEVICE', 'cuda'))
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # BGR -> RGB
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Prompt embeddings for the try-on UNet
            prompt = "model is wearing " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                prompt,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

            # Separate prompt embeddings for the garment encoder branch
            prompt_c = "a photo of " + garment_des
            negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality, change color"
            prompt = [prompt_c] if not isinstance(prompt_c, list) else prompt_c
            negative_prompt = [negative_prompt_c] if not isinstance(negative_prompt_c, list) else negative_prompt_c
            (
                prompt_embeds_c,
                _,
                _,
                _,
            ) = pipe.encode_prompt(
                prompt,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=negative_prompt,
            )

            pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None

            images = pipe(
                prompt_embeds=prompt_embeds.to(device),
                negative_prompt_embeds=negative_prompt_embeds.to(device),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device),
                num_inference_steps=denoise_steps,
                generator=generator,
                strength=1.5,
                pose_img=pose_img_tensor.to(device),
                text_embeds_cloth=prompt_embeds_c.to(device),
                cloth=garm_tensor.to(device),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
            )

    final_image = images[0] if isinstance(images, list) else images
    # Return PIL images; the routes below handle encoding and saving
    return final_image, mask_gray, mask
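# ---------------------------------------------------------------------------
# HTTP API
#
# Both try-on routes accept a JSON body whose image fields are either a plain
# URL or a base64-encoded image (see process_image below). A sketch of a
# /tryon-v2 payload, with hypothetical values:
#
#   {
#       "human_image": "https://example.com/person.jpg",
#       "garment_image": "<base64 PNG>",
#       "description": "red cotton t-shirt",
#       "use_auto_mask": true,
#       "use_auto_crop": false,
#       "denoise_steps": 30,
#       "seed": 42,
#       "categorie": "upper_body",
#       "mask_image": "<base64 PNG, only needed when use_auto_mask is false>"
#   }
# ---------------------------------------------------------------------------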
@app.route('/tryon-v2', methods=['POST'])
def tryon_v2():
    data = request.json
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])
    description = data.get('description')
    use_auto_mask = data.get('use_auto_mask', True)
    use_auto_crop = data.get('use_auto_crop', False)
    denoise_steps = int(data.get('denoise_steps', 30))
    seed = int(data.get('seed', random.randint(0, 9999999)))
    categorie = data.get('categorie', 'upper_body')

    mask_image = None
    if 'mask_image' in data:
        mask_image = process_image(data['mask_image'])

    human_dict = {
        'background': human_image,
        'layers': [mask_image] if not use_auto_mask else None,
        'composite': None
    }
    output_image, mask_gray_image, mask = start_tryon(
        human_dict, garment_image, description, use_auto_mask,
        use_auto_crop, denoise_steps, seed, categorie)

    return jsonify({
        'image_id': save_image(output_image),
        'mask_gray_id': save_image(mask_gray_image),
        'mask_id': save_image(mask)
    })


def clear_gpu_memory():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def process_image(image_data):
    # Accept either a URL or a base64-encoded image
    if image_data.startswith('http://') or image_data.startswith('https://'):
        return get_image_from_url(image_data)        # download from URL
    else:
        return decode_image_from_base64(image_data)  # decode base64 payload


@app.route('/tryon', methods=['POST'])
def tryon():
    data = request.json
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])
    description = data.get('description')
    use_auto_mask = data.get('use_auto_mask', True)
    use_auto_crop = data.get('use_auto_crop', False)
    denoise_steps = int(data.get('denoise_steps', 30))
    seed = int(data.get('seed', 42))
    categorie = data.get('categorie', 'upper_body')

    human_dict = {
        'background': human_image,
        'layers': [human_image] if not use_auto_mask else None,
        'composite': None
    }
    clear_gpu_memory()
    output_image, mask_gray_image, _ = start_tryon(
        human_dict, garment_image, description, use_auto_mask,
        use_auto_crop, denoise_steps, seed, categorie)

    return jsonify({
        'output_image': encode_image_to_base64(output_image),
        'mask_image': encode_image_to_base64(mask_gray_image)
    })


@app.route('/get_mask', methods=['POST'])
@spaces.GPU
def get_mask():
    try:
        # Read the body image from the request
        data = request.json
        img = process_image(data['image']).convert("RGB").resize((384, 512))
        categorie = data.get('categorie', 'upper_body')

        # Run keypoint detection and human parsing
        keypoints = openpose_model(img)
        model_parse, _ = parsing_model(img)

        # Compute the mask for the requested garment category
        mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)

        # Blend the grayscale mask over the body image
        mask_gray = (1 - transforms.ToTensor()(mask_gray)) * tensor_transform(img)
        mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

        # Return the mask as base64
        return jsonify({'mask': encode_image_to_base64(mask_gray)})
    except Exception as e:
        logging.exception(e)
        return jsonify({'error': str(e)}), 500
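# Example request against /get_mask (a minimal sketch; assumes the server is
# reachable on localhost:7860 and that the image URL is a hypothetical stand-in):
#
#   curl -X POST http://localhost:7860/get_mask \
#        -H "Content-Type: application/json" \
#        -d '{"image": "https://example.com/person.jpg", "categorie": "upper_body"}'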
# Index route
@app.route('/', methods=['GET'])
def index():
    return 'Welcome to IDM VTON API'


# Retrieve a generated image by the id returned from /tryon-v2
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    # save_image() writes files to the working directory, so the id is the path
    image_path = image_id
    try:
        return send_file(image_path, mimetype='image/webp')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=7860)
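# End-to-end usage sketch (hypothetical host, URLs, and file names; shown only
# to illustrate the /tryon-v2 -> /api/get_image round trip):
#
#   import requests
#
#   payload = {
#       "human_image": "https://example.com/person.jpg",
#       "garment_image": "https://example.com/shirt.jpg",
#       "description": "red cotton t-shirt",
#   }
#   ids = requests.post("http://localhost:7860/tryon-v2", json=payload).json()
#   img = requests.get(f"http://localhost:7860/api/get_image/{ids['image_id']}")
#   with open("result.webp", "wb") as f:
#       f.write(img.content)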