Spaces: Saad0KH (Running on Zero)

Saad0KH committed · commit a70e8db · verified · 1 parent: 3f5f533

Update app.py
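
Below, a minimal client sketch for the /generate_mask route this commit adds. The JSON keys ('image', 'categorie') and the 'mask_image' response field come from the diff below; the host/port, file names, and the exact base64 format expected by process_image are assumptions:

    import base64
    import requests

    # Encode the person photo as base64 (assumed input format of process_image).
    with open("person.jpg", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode("utf-8")

    # POST to the new route; localhost:7860 is a placeholder for a local run.
    resp = requests.post(
        "http://localhost:7860/generate_mask",
        json={"image": b64_image, "categorie": "upper_body"},
    )
    resp.raise_for_status()

    # On success the route returns HTTP 200 with {'mask_image': <base64 image>}.
    with open("mask.png", "wb") as f:
        f.write(base64.b64decode(resp.json()["mask_image"]))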

Files changed (1): app.py (+144, -107)
app.py CHANGED
@@ -34,26 +34,39 @@ from torchvision.transforms.functional import to_pil_image
 
 app = Flask(__name__)
 
 # Base paths for the models
 base_path = 'yisol/IDM-VTON'
 
 # Load the models
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-unet = UNet2DConditionModel.from_pretrained(base_path, subfolder="unet", torch_dtype=torch.float16).to(device)
-tokenizer_one = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer", use_fast=False)
-tokenizer_two = AutoTokenizer.from_pretrained(base_path, subfolder="tokenizer_2", use_fast=False)
+unet = UNet2DConditionModel.from_pretrained(
+    base_path,
+    subfolder="unet",
+    torch_dtype=torch.float16,
+    force_download=False
+)
+tokenizer_one = AutoTokenizer.from_pretrained(
+    base_path,
+    subfolder="tokenizer",
+    use_fast=False,
+    force_download=False
+)
+tokenizer_two = AutoTokenizer.from_pretrained(
+    base_path,
+    subfolder="tokenizer_2",
+    use_fast=False,
+    force_download=False
+)
 noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16).to(device)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16).to(device)
-vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16).to(device)
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16).to(device)
+text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
 
-parsing_model = Parsing(0).to(device)
-openpose_model = OpenPose(0).to(device)
+parsing_model = Parsing(0)
+openpose_model = OpenPose(0)
 
 # Prepare the Tryon pipeline
 pipe = TryonPipeline.from_pretrained(
     base_path,
     unet=unet,
@@ -66,11 +79,12 @@ pipe = TryonPipeline.from_pretrained(
     scheduler=noise_scheduler,
     image_encoder=image_encoder,
     torch_dtype=torch.float16,
-).to(device)
+    force_download=False
+)
 pipe.unet_encoder = UNet_Encoder
 
 # Image transformations
-tensor_transform = transforms.Compose([
+tensor_transfrom = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5]),
 ])
@@ -82,11 +96,13 @@ def pil_to_binary_mask(pil_image, threshold=0):
     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
     mask[binary_mask] = 1
     return Image.fromarray((mask * 255).astype(np.uint8))
+
+
 
 def get_image_from_url(url):
     try:
         response = requests.get(url)
-        response.raise_for_status()
+        response.raise_for_status()  # Check for HTTP errors
         img = Image.open(BytesIO(response.content))
         return img
     except Exception as e:
@@ -117,7 +133,12 @@ def save_image(img):
     return unique_name
 
 @spaces.GPU
 def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed, categorie='upper_body'):
+    device = "cuda"
+    openpose_model.preprocessor.body_estimation.model.to(device)
+    pipe.to(device)
+    pipe.unet_encoder.to(device)
+
     garm_img = garm_img.convert("RGB").resize((768, 1024))
     human_img_orig = dict["background"].convert("RGB")
 
@@ -138,12 +159,11 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     if is_checked:
         keypoints = openpose_model(human_img.resize((384, 512)))
         model_parse, _ = parsing_model(human_img.resize((384, 512)))
         mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
         mask = mask.resize((768, 1024))
     else:
-        mask = dict['layers'][0].convert("RGB").resize((768, 1024))
-
-        mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
+        mask = dict['layers'][0].convert("RGB").resize((768, 1024))  # pil_to_binary_mask(dict['layers'][0].convert("RGB").resize((768, 1024)))
+        mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transfrom(human_img)
     mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
 
     human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
@@ -158,58 +178,66 @@ def start_tryon(dict, garm_img, garment_des, is_checked, is_checked_crop, denois
     with torch.cuda.amp.autocast():
         prompt = "model is wearing " + garment_des
         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = pipe.encode_prompt(
-            prompt,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=True,
-            negative_prompt=negative_prompt,
-        )
-
-        prompt_c = "a photo of " + garment_des
-        negative_prompt_c = "monochrome, lowres, bad anatomy, worst quality, low quality, change color"
-        prompt = [prompt_c] if not isinstance(prompt_c, list) else prompt_c
-        negative_prompt = [negative_prompt_c] if not isinstance(negative_prompt_c, list) else negative_prompt_c
-
-        (
-            prompt_embeds_c,
-            _,
-            _,
-            _,
-        ) = pipe.encode_prompt(
-            prompt,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=False,
-            negative_prompt=negative_prompt,
-        )
-
-        pose_img_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
-        garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
-        generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
-
-        images = pipe(
-            prompt_embeds=prompt_embeds.to(device),
-            negative_prompt_embeds=negative_prompt_embeds.to(device),
-            pooled_prompt_embeds=pooled_prompt_embeds.to(device),
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device),
-            num_inference_steps=denoise_steps,
-            generator=generator,
-            strength=1.5,
-            pose_img=pose_img_tensor.to(device),
-            text_embeds_cloth=prompt_embeds_c.to(device),
-            cloth=garm_tensor.to(device),
-            mask_image=mask,
-            image=human_img,
-            height=1024,
-            width=768,
-        )
-
-        final_image = images[0] if isinstance(images, list) else images
-        return encode_image_to_base64(final_image)
+        with torch.inference_mode():
+            (
+                prompt_embeds,
+                negative_prompt_embeds,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+            ) = pipe.encode_prompt(
+                prompt,
+                num_images_per_prompt=1,
+                do_classifier_free_guidance=True,
+                negative_prompt=negative_prompt,
+            )
+
+            prompt = "a photo of " + garment_des
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, change color"
+            if not isinstance(prompt, list):
+                prompt = [prompt] * 1
+            if not isinstance(negative_prompt, list):
+                negative_prompt = [negative_prompt] * 1
+            with torch.inference_mode():
+                (
+                    prompt_embeds_c,
+                    _,
+                    _,
+                    _,
+                ) = pipe.encode_prompt(
+                    prompt,
+                    num_images_per_prompt=1,
+                    do_classifier_free_guidance=False,
+                    negative_prompt=negative_prompt,
+                )
+
+                pose_img = tensor_transfrom(pose_img).unsqueeze(0).to(device, torch.float16)
+                garm_tensor = tensor_transfrom(garm_img).unsqueeze(0).to(device, torch.float16)
+                generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
+                images = pipe(
+                    prompt_embeds=prompt_embeds.to(device, torch.float16),
+                    negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
+                    pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
+                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
+                    num_inference_steps=denoise_steps,
+                    generator=generator,
+                    strength=1.5,
+                    pose_img=pose_img.to(device, torch.float16),
+                    text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
+                    cloth=garm_tensor.to(device, torch.float16),
+                    mask_image=mask,
+                    image=human_img,
+                    height=1024,
+                    width=768,
+                    ip_adapter_image=garm_img.resize((768, 1024)),
+                    guidance_scale=1.5,
+                )[0]
+
+    if is_checked_crop:
+        out_img = images[0].resize(crop_size)
+        human_img_orig.paste(out_img, (int(left), int(top)))
+        return human_img_orig, mask_gray
+    else:
+        return images[0], mask_gray, mask
 
 
 @app.route('/tryon-v2', methods=['POST'])
@@ -286,42 +314,51 @@ def tryon():
         'mask_image': mask_base64
     })
 
-@spaces.GPU
-@app.route('/get_mask', methods=['POST'])
-def get_mask():
-    try:
-        # Get the body image from the request
-        data = request.json
-        img_file = process_image(data['image'])
-        img = img_file.convert("RGB").resize((384, 512))
-        categorie = request.form.get('categorie', 'upper_body')  # Parameter with a default value
-
-        # Run keypoint detection
-        keypoints = openpose_model(img)  # Uses your model
-        model_parse, _ = parsing_model(img)  # Uses your model
-
-        # Move the model and the images to the same device
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device)  # Convert and move the image
-
-        # Make sure the model is on the same device
-        parsing_model.to(device)
-
-        # Get the mask
-        mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
-
-        # Convert the mask to an image (if needed)
-        mask_gray = (1 - transforms.ToTensor()(mask_gray)) * tensor_transform(img_tensor)
-        mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
-
-        # Convert the image to base64 if needed for the response
-        img_byte_arr = io.BytesIO()
-        mask_gray.save(img_byte_arr, format='PNG')
-        img_byte_arr.seek(0)
-        return jsonify({'mask': img_byte_arr.getvalue().decode('latin1')})
+@spaces.GPU
+def generate_mask(human_img, categorie='upper_body'):
+    device = "cuda"
+    openpose_model.preprocessor.body_estimation.model.to(device)
+    pipe.to(device)
+
+    try:
+        # Resize the image for the model
+        human_img_resized = human_img.convert("RGB").resize((384, 512))
+
+        # Generate the keypoints and the mask
+        keypoints = openpose_model(human_img_resized)
+        model_parse, _ = parsing_model(human_img_resized)
+        mask, _ = get_mask_location('hd', categorie, model_parse, keypoints)
+
+        # Resize the mask back to the original image size
+        mask_resized = mask.resize(human_img.size)
+
+        return mask_resized
+    except Exception as e:
+        logging.error(f"Error generating mask: {e}")
+        raise e
+
+@app.route('/generate_mask', methods=['POST'])
+def generate_mask_api():
+    try:
+        # Get the image data from the request
+        data = request.json
+        base64_image = data.get('image')
+        categorie = data.get('categorie', 'upper_body')
+
+        # Decode the image from base64
+        human_img = process_image(base64_image)
+
+        # Generate the mask
+        mask_resized = generate_mask(human_img, categorie)
+
+        # Encode the mask as base64 for the response
+        mask_base64 = encode_image_to_base64(mask_resized)
+
+        return jsonify({
+            'mask_image': mask_base64
+        }), 200
     except Exception as e:
-        print(e)
+        logging.error(f"Error generating mask: {e}")
         return jsonify({'error': str(e)}), 500
 
 # Route index
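
For context on the device changes in this diff: the Space runs on ZeroGPU ("Running on Zero"), where no GPU exists at import time, so the commit loads the models on CPU and moves them to CUDA only inside @spaces.GPU-decorated functions. A minimal sketch of that pattern, assuming the spaces package and a single diffusers model rather than the app's full pipeline:

    import spaces
    import torch
    from diffusers import AutoencoderKL

    # Loaded on CPU at import time; ZeroGPU attaches a GPU only per call.
    vae = AutoencoderKL.from_pretrained(
        "yisol/IDM-VTON", subfolder="vae", torch_dtype=torch.float16
    )

    @spaces.GPU
    def decode(latents):
        # A GPU is available only for the duration of this call.
        vae.to("cuda")
        with torch.inference_mode():
            return vae.decode(latents.to("cuda", torch.float16)).sample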