import os
import re
import asyncio
import requests
from flask import Flask, request, jsonify, send_file
from PIL import Image
from io import BytesIO
import torch
import base64
import io
import logging
import gradio as gr
import numpy as np
import spaces
import uuid
import random
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    AutoTokenizer,
)
from diffusers import DDPMScheduler, AutoencoderKL
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
from torchvision.transforms.functional import to_pil_image

app = Flask(__name__)

# Base repository path for every model component.
base_path = 'yisol/IDM-VTON'

# --- Model loading (executed once, at import time) ---
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
    force_download=False,
)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    use_fast=False,
    force_download=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    use_fast=False,
    force_download=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(
    base_path, subfolder="text_encoder", torch_dtype=torch.float16
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path, subfolder="text_encoder_2", torch_dtype=torch.float16
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path, subfolder="image_encoder", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path, subfolder="unet_encoder", torch_dtype=torch.float16
)
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Assemble the try-on pipeline from the components loaded above.
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
    force_download=False,
)
pipe.unet_encoder = UNet_Encoder

# PIL image -> tensor: ToTensor scales to [0, 1], Normalize(0.5, 0.5) maps that to [-1, 1].
# (The misspelled name is kept because other code may reference it.)
tensor_transfrom = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Seconds to wait for a remote image before giving up.
URL_FETCH_TIMEOUT = 30

# save_image() names files "<uuid4>.webp"; get_image() only serves names of that shape.
_IMAGE_ID_PATTERN = re.compile(
    r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\.webp"
)


def pil_to_binary_mask(pil_image, threshold=0):
    """Return an 'L' image that is 255 where the grayscale value exceeds *threshold*, else 0."""
    np_image = np.array(pil_image)
    grayscale_image = Image.fromarray(np_image).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    return Image.fromarray(binary_mask.astype(np.uint8) * 255)


def get_image_from_url(url, timeout=URL_FETCH_TIMEOUT):
    """Download *url* and open the payload as a PIL image.

    Args:
        url: http(s) URL of the image.
        timeout: seconds before the request is abandoned (previously unbounded,
            so a slow host could hang the worker forever).

    Raises:
        requests.RequestException: network failure or non-2xx HTTP status.
        PIL.UnidentifiedImageError: the payload is not a readable image.
    """
    # NOTE(review): the URL comes straight from the request body, so the server
    # will fetch arbitrary (including internal) addresses — an SSRF vector.
    # Consider an allow-list of hosts.
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # surface HTTP errors as exceptions
        return Image.open(BytesIO(response.content))
    except Exception as e:
        logging.error("Error fetching image from URL: %s", e)
        raise


def decode_image_from_base64(base64_str):
    """Decode a base64 string into a PIL image.

    A ``data:image/...;base64,`` prefix is accepted and stripped. Without that,
    b64decode would silently keep the prefix's alphanumeric characters as
    payload and corrupt the image.
    """
    try:
        if base64_str.startswith('data:') and ',' in base64_str:
            base64_str = base64_str.split(',', 1)[1]
        img_data = base64.b64decode(base64_str)
        return Image.open(BytesIO(img_data))
    except Exception as e:
        logging.error("Error decoding image: %s", e)
        raise


def encode_image_to_base64(img):
    """Serialize *img* as PNG and return it as a base64 text string."""
    try:
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        logging.error("Error encoding image: %s", e)
        raise


def save_image(img):
    """Save *img* as a lossless WEBP named ``<uuid4>.webp`` in the CWD and return that name.

    NOTE(review): files are never deleted, so disk usage grows with every request.
    """
    unique_name = str(uuid.uuid4()) + ".webp"
    img.save(unique_name, format="WEBP", lossless=True)
    return unique_name


@spaces.GPU
def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
    """Run the IDM-VTON pipeline to dress the person in ``dict['background']`` with *garm_img*.

    Args:
        dict: editor-style payload. The parameter name shadows the builtin but is
            kept for interface compatibility. ``dict['background']`` is the person
            image; ``dict['layers'][0]`` is the user mask, read only when
            *is_checked* is false.
        garm_img: garment PIL image.
        garment_des: garment text description inserted into the prompts
            (None is treated as "").
        is_checked: compute the mask automatically (OpenPose + human parsing)
            instead of reading it from ``dict['layers'][0]``.
        is_checked_crop: centre-crop the person to 3:4 before inference and
            paste the result back into the original-size image.
        denoise_steps: number of inference steps.
        seed: generator seed, or None for an unseeded run.
        categorie: garment category forwarded to get_mask_location.

    Returns:
        ``(output_image, mask_gray, mask)`` in every branch. Previously the crop
        branch returned only two values, which broke callers that unpack three,
        and vice versa.
    """
    device = "cuda"
    openpose_model.preprocessor.body_estimation.model.to(device)
    pipe.to(device)
    pipe.unet_encoder.to(device)

    garment_des = garment_des or ''
    garm_img = garm_img.convert("RGB").resize((768, 1024))
    human_img_orig = dict["background"].convert("RGB")

    if is_checked_crop:
        # Centre-crop to the largest 3:4 region. Integer offsets keep the crop box
        # and the paste-back position identical: fractional offsets were rounded
        # by crop() but truncated by paste(), misaligning the result by up to 1px.
        width, height = human_img_orig.size
        target_width = int(min(width, height * (3 / 4)))
        target_height = int(min(height, width * (4 / 3)))
        left = (width - target_width) // 2
        top = (height - target_height) // 2
        cropped_img = human_img_orig.crop((left, top, left + target_width, top + target_height))
        crop_size = cropped_img.size
        human_img = cropped_img.resize((768, 1024))
    else:
        human_img = human_img_orig.resize((768, 1024))

    if is_checked:
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
        mask = mask.resize((768, 1024))
    else:
        # User-supplied mask; binarisation via pil_to_binary_mask is deliberately disabled.
        mask = dict['layers'][0].convert("RGB").resize((768, 1024))

    # Preview image: the person with the masked region greyed out.
    mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
    mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)

    # DensePose segmentation of the person, used as the pose conditioning image.
    human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
    human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")
    args = apply_net.create_argument_parser().parse_args((
        'show',
        './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
        './ckpt/densepose/model_final_162be9.pkl',
        'dp_segm',
        '-v',
        '--opts',
        'MODEL.DEVICE',
        'cuda',
    ))
    pose_img = args.func(args, human_img_arg)
    pose_img = pose_img[:, :, ::-1]  # reverse the channel order (BGR <-> RGB)
    pose_img = Image.fromarray(pose_img).resize((768, 1024))

    with torch.no_grad():
        with torch.cuda.amp.autocast():
            # Prompt embeddings for the person/try-on branch (with CFG negatives).
            prompt = "model is wearing " + garment_des
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            with torch.inference_mode():
                (
                    prompt_embeds,
                    negative_prompt_embeds,
                    pooled_prompt_embeds,
                    negative_pooled_prompt_embeds,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=True,
                    negative_prompt=negative_prompt,
                )

            # Prompt embeddings for the garment branch (no CFG).
            prompt = ["a photo of " + garment_des]
            negative_prompt = ["monochrome, lowres, bad anatomy, worst quality, low quality , change color"]
            with torch.inference_mode():
                (
                    prompt_embeds_c,
                    _,
                    _,
                    _,
                ) = pipe.encode_prompt(
                    prompt,
                    num_images_per_prompt=1,
                    do_classifier_free_guidance=False,
                    negative_prompt=negative_prompt,
                )

            pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
            generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=denoise_steps,
                generator=generator,
                # NOTE(review): diffusers inpaint pipelines document strength in
                # [0, 1]; confirm 1.5 is intended (vs 1.0).
                strength=1.5,
                pose_img=pose_img.to(device, torch.float16),
                text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
                cloth=garm_tensor.to(device, torch.float16),
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=1.5,
            )[0]

    if is_checked_crop:
        out_img = images[0].resize(crop_size)
        human_img_orig.paste(out_img, (left, top))
        return human_img_orig, mask_gray, mask
    return images[0], mask_gray, mask


def clear_gpu_memory():
    """Release cached CUDA allocations and wait for pending kernels (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def process_image(image_data):
    """Load a PIL image from an http(s) URL or from a base64 string.

    Raises:
        ValueError: *image_data* is not a string.
    """
    if not isinstance(image_data, str):
        raise ValueError("Image data must be a URL or base64 string")
    if image_data.startswith(('http://', 'https://')):
        return get_image_from_url(image_data)  # download from the URL
    return decode_image_from_base64(image_data)  # decode inline base64


def _parse_tryon_request(data, default_seed=None):
    """Validate a /tryon or /tryon-v2 JSON body and load its images.

    Args:
        data: parsed JSON body (anything other than a dict is rejected).
        default_seed: seed used when the body has none; None means a fresh
            random seed.

    Returns:
        A dict with the start_tryon inputs plus the optional ``mask_image``.

    Raises:
        KeyError: a required image field is missing.
        ValueError / TypeError: a value is malformed.
        OSError: an image could not be fetched or decoded (requests and PIL
            errors both derive from OSError).
    """
    if not isinstance(data, dict):
        raise ValueError("Request body must be a JSON object")
    human_image = process_image(data['human_image'])
    garment_image = process_image(data['garment_image'])
    mask_image = process_image(data['mask_image']) if data.get('mask_image') else None

    denoise_steps = int(data.get('denoise_steps', 30))
    if denoise_steps < 1:
        raise ValueError("denoise_steps must be >= 1")

    seed = data.get('seed')
    if seed is None:
        seed = default_seed if default_seed is not None else random.randint(0, 9999999)

    return {
        'human_image': human_image,
        'garment_image': garment_image,
        'mask_image': mask_image,
        # Concatenated into the prompt by start_tryon, so never None.
        'description': str(data.get('description') or ''),
        'use_auto_mask': data.get('use_auto_mask', True),
        'use_auto_crop': data.get('use_auto_crop', False),
        'denoise_steps': denoise_steps,
        'seed': int(seed),
        'categorie': data.get('categorie', 'upper_body'),
    }


def _bad_request(error):
    """Log *error* and return a JSON 400 response describing it."""
    if isinstance(error, KeyError):
        message = f"Missing field: {error.args[0]}"
    else:
        message = str(error)
    logging.warning("Rejected request: %s", message)
    return jsonify({'error': message}), 400


@app.route('/tryon-v2', methods=['POST'])
def tryon_v2():
    """Run a try-on, save the results to disk and return their ids.

    Response: ``{'image_id', 'mask_gray_id', 'mask_id'}``, each retrievable
    via /api/get_image/<id>.
    """
    try:
        params = _parse_tryon_request(request.get_json(silent=True))
    except (KeyError, ValueError, TypeError, OSError) as e:
        return _bad_request(e)
    if not params['use_auto_mask'] and params['mask_image'] is None:
        # start_tryon would otherwise crash calling .convert() on None.
        return _bad_request(ValueError("mask_image is required when use_auto_mask is false"))

    human_dict = {
        'background': params['human_image'],
        'layers': [params['mask_image']] if not params['use_auto_mask'] else None,
        'composite': None,
    }
    output_image, mask_gray, mask = start_tryon(
        human_dict,
        params['garment_image'],
        params['description'],
        params['use_auto_mask'],
        params['use_auto_crop'],
        params['denoise_steps'],
        params['seed'],
        params['categorie'],
    )
    return jsonify({
        'image_id': save_image(output_image),
        'mask_gray_id': save_image(mask_gray),
        'mask_id': save_image(mask),
    })


@app.route('/tryon', methods=['POST'])
def tryon():
    """Run a try-on and return the output and mask preview inline as base64 PNG."""
    try:
        params = _parse_tryon_request(request.get_json(silent=True), default_seed=42)
    except (KeyError, ValueError, TypeError, OSError) as e:
        return _bad_request(e)

    # NOTE(review): when no 'mask_image' is sent, the manual-mask path falls back
    # to the person image itself (the original behaviour, kept for compatibility).
    manual_mask = params['mask_image'] if params['mask_image'] is not None else params['human_image']
    human_dict = {
        'background': params['human_image'],
        'layers': [manual_mask] if not params['use_auto_mask'] else None,
        'composite': None,
    }
    clear_gpu_memory()
    output_image, mask_gray, _ = start_tryon(
        human_dict,
        params['garment_image'],
        params['description'],
        params['use_auto_mask'],
        params['use_auto_crop'],
        params['denoise_steps'],
        params['seed'],
        params['categorie'],
    )
    return jsonify({
        'output_image': encode_image_to_base64(output_image),
        'mask_image': encode_image_to_base64(mask_gray),
    })


@app.route('/get_mask', methods=['POST'])
def get_mask():
    """Compute the automatic garment mask for a person image and return a greyed-out preview."""
    data = request.get_json(silent=True)
    try:
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")
        img = process_image(data['image']).convert("RGB").resize((384, 512))
        # The body is JSON, so the category must come from it; request.form is
        # empty for JSON requests, which previously pinned this to 'upper_body'.
        categorie = data.get('categorie', 'upper_body')
    except (KeyError, ValueError, TypeError, OSError) as e:
        return _bad_request(e)

    try:
        keypoints = openpose_model(img)
        model_parse, _ = parsing_model(img)
        mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)

        # Same preview formula as start_tryon: grey out where the binary mask is set.
        mask_preview = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(img)
        mask_preview = to_pil_image((mask_preview + 1.0) / 2.0)

        img_byte_arr = io.BytesIO()
        mask_preview.save(img_byte_arr, format='PNG')
        # NOTE: raw PNG bytes decoded as latin-1 (lossless but unusual; base64
        # would be the conventional encoding). Kept for client compatibility.
        return jsonify({'mask': img_byte_arr.getvalue().decode('latin1')})
    except Exception as e:
        logging.exception("get_mask failed")
        return jsonify({'error': str(e)}), 500


@app.route('/', methods=['GET'])
def index():
    """Welcome / health-check endpoint."""
    return 'Welcome to IDM VTON API'


@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Serve an image previously written by save_image.

    The route now declares ``<image_id>``; without it Flask called this view with
    no argument. Only names shaped like save_image's output (``<uuid4>.webp``)
    are accepted. Anything else, e.g. ``app.py``, gets a 404, because the id is
    otherwise handed straight to the filesystem.
    """
    if not _IMAGE_ID_PATTERN.fullmatch(image_id):
        return jsonify({'error': 'Image not found'}), 404
    try:
        # save_image writes relative to the CWD; resolve the same way explicitly.
        return send_file(os.path.abspath(image_id), mimetype='image/webp')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404


if __name__ == "__main__":
    app.run(debug=False, host="0.0.0.0", port=7860)