Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
601528d
·
verified ·
1 Parent(s): 60362b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -314,6 +314,7 @@ def tryon():
314
  'mask_image': mask_base64
315
  })
316
 
 
317
  @app.route('/get_mask', methods=['POST'])
318
  def get_mask():
319
  try:
@@ -327,20 +328,28 @@ def get_mask():
327
  keypoints = openpose_model(img) # Utilise votre modèle
328
  model_parse, _ = parsing_model(img) # Utilise votre modèle
329
 
 
 
 
 
 
 
 
330
  # Obtenir le masque
331
- mask, mask_gray = get_mask_location('hd', categorie , model_parse, keypoints)
332
-
333
  # Convertir le masque en image (si nécessaire)
334
- mask_gray = (1 - transforms.ToTensor()(mask_gray)) * tensor_transfrom(img)
335
  mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
336
 
337
  # Convertir l'image en base64 si besoin pour le retour
338
  img_byte_arr = io.BytesIO()
339
  mask_gray.save(img_byte_arr, format='PNG')
340
  img_byte_arr.seek(0)
341
- return jsonify({'mask': img_byte_arr.getvalue().decode('latin1')}) # Utiliser une méthode appropriée pour l'encodage
342
 
343
  except Exception as e:
 
344
  return jsonify({'error': str(e)}), 500
345
 
346
  # Route index
 
314
  'mask_image': mask_base64
315
  })
316
 
317
+ @spaces.GPU
318
  @app.route('/get_mask', methods=['POST'])
319
  def get_mask():
320
  try:
 
328
  keypoints = openpose_model(img) # Utilise votre modèle
329
  model_parse, _ = parsing_model(img) # Utilise votre modèle
330
 
331
+ # Déplacer le modèle et les images sur le même dispositif
332
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
333
+ img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(device) # Convertir et déplacer l'image
334
+
335
+ # Assurez-vous que le modèle est sur le même dispositif
336
+ parsing_model.to(device)
337
+
338
  # Obtenir le masque
339
+ mask, mask_gray = get_mask_location('hd', categorie, model_parse, keypoints)
340
+
341
  # Convertir le masque en image (si nécessaire)
342
+ mask_gray = (1 - transforms.ToTensor()(mask_gray)) * tensor_transform(img_tensor)
343
  mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
344
 
345
  # Convertir l'image en base64 si besoin pour le retour
346
  img_byte_arr = io.BytesIO()
347
  mask_gray.save(img_byte_arr, format='PNG')
348
  img_byte_arr.seek(0)
349
+ return jsonify({'mask': img_byte_arr.getvalue().decode('latin1')})
350
 
351
  except Exception as e:
352
+ print(e)
353
  return jsonify({'error': str(e)}), 500
354
 
355
  # Route index