Update app.py
Browse files
app.py
CHANGED
@@ -314,6 +314,7 @@ def tryon():
|
|
314 |
'mask_image': mask_base64
|
315 |
})
|
316 |
|
|
|
317 |
@app.route('/get_mask', methods=['POST'])
|
318 |
def get_mask():
|
319 |
try:
|
@@ -327,20 +328,28 @@ def get_mask():
|
|
327 |
keypoints = openpose_model(img) # Utilise votre modèle
|
328 |
model_parse, _ = parsing_model(img) # Utilise votre modèle
|
329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
# Obtenir le masque
|
331 |
-
mask, mask_gray = get_mask_location('hd', categorie
|
332 |
-
|
333 |
# Convertir le masque en image (si nécessaire)
|
334 |
-
mask_gray = (1 - transforms.ToTensor()(mask_gray)) *
|
335 |
mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
|
336 |
|
337 |
# Convertir l'image en base64 si besoin pour le retour
|
338 |
img_byte_arr = io.BytesIO()
|
339 |
mask_gray.save(img_byte_arr, format='PNG')
|
340 |
img_byte_arr.seek(0)
|
341 |
-
return jsonify({'mask': img_byte_arr.getvalue().decode('latin1')})
|
342 |
|
343 |
except Exception as e:
|
|
|
344 |
return jsonify({'error': str(e)}), 500
|
345 |
|
346 |
# Route index
|
|
|
314 |
'mask_image': mask_base64
|
315 |
})
|
316 |
|
317 |
+
@spaces.GPU
|
318 |
@app.route('/get_mask', methods=['POST'])
|
319 |
def get_mask():
|
320 |
try:
|
|
|
328 |
keypoints = openpose_model(img) # Utilise votre modèle
|
329 |
model_parse, _ = parsing_model(img) # Utilise votre modèle
|
330 |
|
331 |
+
# Déplacer le modèle et les images sur le même dispositif
|
332 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
333 |
+
img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device) # Convertir et déplacer l'image
|
334 |
+
|
335 |
+
# Assurez-vous que le modèle est sur le même dispositif
|
336 |
+
parsing_model.to(device)
|
337 |
+
|
338 |
# Obtenir le masque
|
339 |
+
mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
|
340 |
+
|
341 |
# Convertir le masque en image (si nécessaire)
|
342 |
+
mask_gray = (1 - transforms.ToTensor()(mask_gray)) * tensor_transform(img_tensor)
|
343 |
mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
|
344 |
|
345 |
# Convertir l'image en base64 si besoin pour le retour
|
346 |
img_byte_arr = io.BytesIO()
|
347 |
mask_gray.save(img_byte_arr, format='PNG')
|
348 |
img_byte_arr.seek(0)
|
349 |
+
return jsonify({'mask': img_byte_arr.getvalue().decode('latin1')})
|
350 |
|
351 |
except Exception as e:
|
352 |
+
print(e)
|
353 |
return jsonify({'error': str(e)}), 500
|
354 |
|
355 |
# Route index
|