Update app.py
app.py
CHANGED
@@ -1,95 +1,246 @@
-tokenizer_one = None
-tokenizer_two = None
-noise_scheduler = None
-text_encoder_one = None
-text_encoder_two = None
-image_encoder = None
-vae = None
-UNet_Encoder = None
-
-# Load models once at startup
-def load_models():
-    global unet, tokenizer_one, tokenizer_two, noise_scheduler
-    global text_encoder_one, text_encoder_two, image_encoder, vae, UNet_Encoder
-
-    if tokenizer_one is None:
-        tokenizer_one = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-
-    if tokenizer_two is None:
-        tokenizer_two = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14-336")
+import os
from flask import Flask, request, jsonify
+from flask import send_file
+import logging
+import spaces
+from PIL import Image
+from io import BytesIO
+import base64
import torch
+import requests
+import numpy as np
+import uuid
from transformers import (
+    CLIPImageProcessor,
+    CLIPVisionModelWithProjection,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    AutoTokenizer
)
+from diffusers import DDPMScheduler, AutoencoderKL
+from utils_mask import get_mask_location
+from torchvision import transforms
+import apply_net
+from preprocess.humanparsing.run_parsing import Parsing
+from preprocess.openpose.run_openpose import OpenPose
+from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
+from torchvision.transforms.functional import to_pil_image
+# Project-local IDM-VTON modules (import paths assumed from the upstream repo layout)
+from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
+from src.unet_hacked_tryon import UNet2DConditionModel
+from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref

app = Flask(__name__)

+# Global variables used to store the loaded models
+models_loaded = False

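+# Load all IDM-VTON components once and cache them in module-level globals so
+# repeated requests reuse the same weights instead of reloading them.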
+def load_models():
+    global unet, tokenizer_one, tokenizer_two, noise_scheduler, text_encoder_one, text_encoder_two
+    global image_encoder, vae, UNet_Encoder, parsing_model, openpose_model, pipe
+    global models_loaded
+
+    if not models_loaded:
+        base_path = 'yisol/IDM-VTON'
+        unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16, force_download=False)
+        unet.requires_grad_(False)
+
+        tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False, force_download=False)
+        tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False, force_download=False)
+
+        noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
+        text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16, force_download=False)
+        text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16, force_download=False)
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16, force_download=False)
+        vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16, force_download=False)
+        UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16, force_download=False)
+
+        parsing_model = Parsing(0)
+        openpose_model = OpenPose(0)
+
+        UNet_Encoder.requires_grad_(False)
+        image_encoder.requires_grad_(False)
+        vae.requires_grad_(False)
+        unet.requires_grad_(False)
+        text_encoder_one.requires_grad_(False)
+        text_encoder_two.requires_grad_(False)
+
+        tensor_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+        pipe = TryonPipeline.from_pretrained(
+            base_path,
+            unet=unet,
+            vae=vae,
+            feature_extractor=CLIPImageProcessor(),
+            text_encoder=text_encoder_one,
+            text_encoder_2=text_encoder_two,
+            tokenizer=tokenizer_one,
+            tokenizer_2=tokenizer_two,
+            scheduler=noise_scheduler,
+            image_encoder=image_encoder,
+            torch_dtype=torch.float16,
+            force_download=False
+        )
+        pipe.unet_encoder = UNet_Encoder
+
+        models_loaded = True

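+# Image helpers: PIL/base64 conversion, URL fetching, result saving and GPU cleanup.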
+def pil_to_binary_mask(pil_image, threshold=0):
+    np_image = np.array(pil_image.convert("L"))  # convert to grayscale directly
+    binary_mask = np_image > threshold
+    mask = np.uint8(binary_mask * 255)
+    return Image.fromarray(mask)
+
+def get_image_from_url(url):
+    try:
+        response = requests.get(url)
+        response.raise_for_status()
+        return Image.open(BytesIO(response.content))
+    except Exception as e:
+        logging.error(f"Error fetching image from URL: {e}")
+        raise
+
+def decode_image_from_base64(base64_str):
+    try:
+        img_data = base64.b64decode(base64_str)
+        return Image.open(BytesIO(img_data))
+    except Exception as e:
+        logging.error(f"Error decoding image: {e}")
+        raise
+
+def encode_image_to_base64(img):
+    try:
+        buffered = BytesIO()
+        img.save(buffered, format="PNG")
+        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+    except Exception as e:
+        logging.error(f"Error encoding image: {e}")
+        raise
+
+def save_image(img):
+    unique_name = f"{uuid.uuid4()}.webp"
+    img.save(unique_name, format="WEBP", lossless=True)
+    return unique_name
+
+def clear_gpu_memory():
+    torch.cuda.empty_cache()
+    torch.cuda.ipc_collect()

+
@spaces.GPU
|
123 |
+
def start_tryon(human_dict, garment_image, garment_description, use_auto_mask, use_auto_crop, denoise_steps, seed, category='upper_body'):
|
124 |
+
device = "cuda"
|
125 |
+
openpose_model.preprocessor.body_estimation.model.to(device)
|
126 |
+
pipe.to(device)
|
127 |
+
pipe.unet_encoder.to(device)
|
128 |
+
|
129 |
+
garment_image = garment_image.convert("RGB").resize((768, 1024))
|
130 |
+
human_image_orig = human_dict["background"].convert("RGB")
|
131 |
+
|
132 |
+
if use_auto_crop:
|
133 |
+
width, height = human_image_orig.size
|
134 |
+
target_width = int(min(width, height * (3 / 4)))
|
135 |
+
target_height = int(min(height, width * (4 / 3)))
|
136 |
+
left, top = (width - target_width) / 2, (height - target_height) / 2
|
137 |
+
right, bottom = (width + target_width) / 2, (height + target_height) / 2
|
138 |
+
cropped_img = human_image_orig.crop((left, top, right, bottom)).resize((768, 1024))
|
139 |
+
else:
|
140 |
+
cropped_img = human_image_orig.resize((768, 1024))
|
141 |
+
|
142 |
+
if use_auto_mask:
|
143 |
+
keypoints = openpose_model(cropped_img.resize((384, 512)))
|
144 |
+
model_parse, _ = parsing_model(cropped_img.resize((384, 512)))
|
145 |
+
mask, mask_gray = get_mask_location('hd', category, model_parse, keypoints)
|
146 |
+
mask = mask.resize((768, 1024))
|
147 |
+
else:
|
148 |
+
mask = pil_to_binary_mask(human_dict['layers'][0].convert("RGB").resize((768, 1024)))
|
149 |
|
150 |
+
mask_gray = (1 - transforms.ToTensor()(mask)) * transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(cropped_img)
|
151 |
+
mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
|
152 |
+
|
153 |
+
human_image_arg = _apply_exif_orientation(cropped_img.resize((384, 512)))
|
154 |
+
human_image_arg = convert_PIL_to_numpy(human_image_arg, format="BGR")
|
155 |
+
|
156 |
+
args = apply_net.create_argument_parser().parse_args(
|
157 |
+
('show', './configs/densepose_rcnn_R_50_FPN_s1x.yaml', './ckpt/densepose/model_final_162be9.pkl', 'dp_segm', '-v', '--opts', 'MODEL.DEVICE', 'cuda'))
|
158 |
+
pose_image = args.func(args, human_image_arg)
|
159 |
+
pose_image = Image.fromarray(pose_image[:, :, ::-1]).resize((768, 1024))
|
160 |
+
|
161 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
162 |
+
prompt = "model is wearing " + garment_description
|
163 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
164 |
+
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
|
165 |
+
prompt, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt
|
166 |
+
)
|
167 |
+
|
168 |
+
prompt_c = "a photo of " + garment_description
|
169 |
+
negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
170 |
+
prompt_embeds_c, _, _, _ = pipe.encode_prompt(
|
171 |
+
prompt_c, num_images_per_prompt=1, do_classifier_free_guidance=False, negative_prompt=negative_prompt_c
|
172 |
+
)
|
173 |
+
|
174 |
+
pose_image = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(pose_image).unsqueeze(0).to(device, torch.float16)
|
175 |
+
garment_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])(garment_image).unsqueeze(0).to(device, torch.float16)
|
176 |
+
|
177 |
+
images = pipe(
|
178 |
+
prompt_embeds=prompt_embeds.to(device, torch.float16),
|
179 |
+
negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
|
180 |
+
pose_image=pose_image,
|
181 |
+
garment_image=garment_tensor,
|
182 |
+
mask_image=mask_gray.to(device, torch.float16),
|
183 |
+
generator=torch.Generator(device).manual_seed(seed),
|
184 |
+
num_inference_steps=denoise_steps
|
185 |
+
).images
|
186 |
+
|
187 |
+
if images:
|
188 |
+
output_image = images[0]
|
189 |
+
output_base64 = encode_image_to_base64(output_image)
|
190 |
+
mask_image = mask
|
191 |
+
mask_base64 = encode_image_to_base64(mask_image)
|
192 |
+
return output_image, mask_image
|
193 |
+
else:
|
194 |
+
raise ValueError("Failed to generate image")
|
195 |
+
|
196 |
+
|
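+# Flask API: GET /api/get_image/<image_id> serves a previously saved result,
+# POST /tryon runs the try-on on base64-encoded images and returns base64-encoded PNGs.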
+# Route to retrieve a generated image
+@app.route('/api/get_image/<image_id>', methods=['GET'])
+def get_image(image_id):
+    # Build the image path; the filename must match the one produced by save_image()
+    image_path = image_id
+
+    # Return the image
+    try:
+        return send_file(image_path, mimetype='image/webp')
+    except FileNotFoundError:
+        return jsonify({'error': 'Image not found'}), 404
+
+@app.route('/tryon', methods=['POST'])
+def tryon_handler():
+    try:
+        data = request.json
+        human_image = decode_image_from_base64(data['human_image'])
+        garment_image = decode_image_from_base64(data['garment_image'])
+        description = data.get('description', '')
+        use_auto_mask = data.get('use_auto_mask', True)
+        use_auto_crop = data.get('use_auto_crop', False)
+        denoise_steps = int(data.get('denoise_steps', 30))
+        seed = int(data.get('seed', 42))
+        category = data.get('category', 'upper_body')
+
+        human_dict = {
+            'background': human_image,
+            'layers': [human_image] if not use_auto_mask else None,
+            'composite': None
+        }
+        clear_gpu_memory()
+
+        output_image, mask_image = start_tryon(
+            human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, category
+        )
+
+        output_base64 = encode_image_to_base64(output_image)
+        mask_base64 = encode_image_to_base64(mask_image)
+
+        return jsonify({
+            'output_image': output_base64,
+            'mask_image': mask_base64
+        })
+    except Exception as e:
+        logging.error(f"Error in tryon_handler: {e}")
+        return jsonify({'error': str(e)}), 500
+
+if __name__ == "__main__":
+    load_models()  # Load the models at startup
    app.run(host='0.0.0.0', port=7860)
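
For reference, a minimal client sketch for the new POST /tryon endpoint (an illustration rather than part of this commit; it assumes the server is running locally on port 7860, and person.jpg / garment.jpg are placeholder input files):

import base64
import requests

def to_b64(path):
    # Read an image file and return its base64-encoded contents as a string.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "human_image": to_b64("person.jpg"),      # placeholder input file
    "garment_image": to_b64("garment.jpg"),   # placeholder input file
    "description": "a red t-shirt",
    "use_auto_mask": True,
    "use_auto_crop": False,
    "denoise_steps": 30,
    "seed": 42,
    "category": "upper_body",
}

response = requests.post("http://localhost:7860/tryon", json=payload, timeout=600)
response.raise_for_status()
result = response.json()

# The server returns base64-encoded PNGs for the try-on result and the mask.
with open("tryon_result.png", "wb") as f:
    f.write(base64.b64decode(result["output_image"]))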