Bhushan26 commited on
Commit
d7228b2
·
verified ·
1 Parent(s): efa2ff8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +74 -69
main.py CHANGED
@@ -1,122 +1,127 @@
1
- from quart import Quart, request, jsonify, send_from_directory, websocket
2
- from gradio_client import Client, file
3
- from quart_cors import cors
4
- import os
5
- import traceback
6
  import shutil
7
  import base64
8
- import asyncio
9
- import httpx
10
- from tenacity import retry, stop_after_attempt, wait_fixed
 
 
11
 
12
- app = Quart(__name__)
13
- cors(app)
14
 
 
 
15
  # Directory to save uploaded and processed files
16
- UPLOAD_FOLDER = 'static/uploads'
17
- RESULT_FOLDER = 'static/results'
18
- if not os.path.exists(UPLOAD_FOLDER):
19
- os.makedirs(UPLOAD_FOLDER)
20
- if not os.path.exists(RESULT_FOLDER):
21
- os.makedirs(RESULT_FOLDER)
22
 
23
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
24
  app.config['RESULT_FOLDER'] = RESULT_FOLDER
25
 
26
- @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
27
- async def initialize_client():
28
- return Client("yisol/IDM-VTON")
29
 
30
- client = None
31
 
32
- @app.before_serving
33
- async def startup():
34
- global client
35
- try:
36
- client = await initialize_client()
37
- print("Client initialized successfully")
38
- except Exception as e:
39
- print(f"Failed to initialize client: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  @app.route('/process', methods=['POST'])
42
- async def predict():
43
- global client
44
- if client is None:
45
- return jsonify(error='Client not initialized. Please try again later.'), 503
46
 
47
  try:
48
- form = await request.form
49
- files = await request.files
50
-
51
  # Get the product image URL from the request
52
- product_image_url = form.get('product_image_url')
 
 
 
53
 
54
  # Handle the uploaded model image
55
- if 'model_image' not in files:
 
56
  return jsonify(error='No model image file provided'), 400
57
 
58
- model_image = files['model_image']
59
  if model_image.filename == '':
 
60
  return jsonify(error='No selected file'), 400
61
 
62
  # Save the uploaded file to the upload directory
63
  filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
64
- await model_image.save(filename)
65
 
66
- base_path = os.getcwd()
67
- full_filename = os.path.normpath(os.path.join(base_path, filename))
68
 
69
- print("Product image = ", product_image_url)
70
- print("Model image = ", full_filename)
71
 
72
- # Perform prediction
73
- try:
74
- result = await asyncio.to_thread(client.predict,
75
- dict={"background": file(full_filename), "layers": [], "composite": None},
76
- garm_img=file(product_image_url),
77
- garment_des="Hello!!",
78
- is_checked=True,
79
- is_checked_crop=False,
80
- denoise_steps=30,
81
- seed=42,
82
- api_name="/tryon"
83
- )
84
- except Exception as e:
85
- traceback.print_exc()
86
- raise
87
 
88
- print(result)
89
- # Extract the path of the first output image
90
  output_image_path = result[0]
 
 
 
91
 
92
  # Copy the output image to the RESULT_FOLDER
93
  output_image_filename = os.path.basename(output_image_path)
94
  local_output_path = os.path.join(app.config['RESULT_FOLDER'], output_image_filename)
95
  shutil.copy(output_image_path, local_output_path)
96
 
 
97
  # Remove the uploaded file after processing
98
- os.remove(filename)
 
99
 
100
  # Encode the output image in base64
101
  with open(local_output_path, "rb") as image_file:
102
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
103
 
104
  # Return the output image in JSON format
 
105
  return jsonify(image=encoded_image), 200
106
 
107
  except Exception as e:
 
108
  traceback.print_exc()
109
  return jsonify(error=str(e)), 500
110
 
111
  @app.route('/uploads/<filename>')
112
- async def uploaded_file(filename):
113
- return await send_from_directory(app.config['UPLOAD_FOLDER'], filename)
114
 
115
- @app.websocket('/ws')
116
- async def ws():
117
- while True:
118
- data = await websocket.receive()
119
- await websocket.send(f"Echo: {data}")
120
 
121
  if __name__ == '__main__':
122
- app.run(host='0.0.0.0', port=5000)
 
 
 
 
 
 
1
  import shutil
2
  import base64
3
+ from gradio_client import Client, file
4
+
5
+
6
+
7
+
8
 
 
 
9
 
10
+ app = Flask(__name__)
11
+ CORS(app)
12
  # Directory to save uploaded and processed files
13
+ UPLOAD_FOLDER = tempfile.mkdtemp()
14
+ RESULT_FOLDER = tempfile.mkdtemp()
 
 
 
 
15
 
16
  app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
17
  app.config['RESULT_FOLDER'] = RESULT_FOLDER
18
 
 
 
 
19
 
 
20
 
21
+
22
+ def predict_with_timeout(model_image_path, product_image_url, timeout=600):
23
+ result = [None] # Mutable object to store the result
24
+
25
+ def target():
26
+ try:
27
+
28
+ result[0] = client.predict(
29
+ dict({"background": file(model_image_path), "layers": [], "composite": None}),
30
+ garm_img=file(product_image_url),
31
+ seed=42,
32
+ api_name="/tryon"
33
+ )
34
+
35
+ except Exception as e:
36
+
37
+ result[0] = str(e)
38
+
39
+ thread = threading.Thread(target=target)
40
+ thread.start()
41
+ thread.join(timeout)
42
+
43
+ if thread.is_alive():
44
+ return "Prediction timed out after {} seconds".format(timeout)
45
+
46
+ if isinstance(result[0], Exception):
47
+ return str(result[0]) # Return the error message
48
+ return result[0]
49
+
50
+ @app.route('/')
51
+ def index():
52
+
53
+ return {'message': 'This is a wearon API'}
54
 
55
  @app.route('/process', methods=['POST'])
56
+ def predict():
 
 
 
57
 
58
  try:
 
 
 
59
  # Get the product image URL from the request
60
+ product_image_url = request.form.get('product_image_url')
61
+ if not product_image_url:
62
+
63
+ return jsonify(error='No product image URL provided'), 400
64
 
65
  # Handle the uploaded model image
66
+ if 'model_image' not in request.files:
67
+
68
  return jsonify(error='No model image file provided'), 400
69
 
70
+ model_image = request.files['model_image']
71
  if model_image.filename == '':
72
+
73
  return jsonify(error='No selected file'), 400
74
 
75
  # Save the uploaded file to the upload directory
76
  filename = os.path.join(app.config['UPLOAD_FOLDER'], model_image.filename)
77
+ model_image.save(filename)
78
 
79
+ full_filename = os.path.abspath(filename)
 
80
 
81
+ print("Product image URL:", product_image_url)
82
+ print("Model image path:", full_filename)
83
 
84
+ # Perform prediction with a timeout
85
+ result = predict_with_timeout(full_filename, product_image_url)
86
+ if isinstance(result, str):
87
+
88
+ return jsonify(error=result), 500
89
+
90
+ print("Prediction result:", result)
 
 
 
 
 
 
 
 
91
 
92
+ # Check if the result contains a valid path
 
93
  output_image_path = result[0]
94
+ if not os.path.exists(output_image_path):
95
+ return jsonify(error='Output image file not found: {}'.format(output_image_path)), 500
96
+
97
 
98
  # Copy the output image to the RESULT_FOLDER
99
  output_image_filename = os.path.basename(output_image_path)
100
  local_output_path = os.path.join(app.config['RESULT_FOLDER'], output_image_filename)
101
  shutil.copy(output_image_path, local_output_path)
102
 
103
+
104
  # Remove the uploaded file after processing
105
+ os.remove(full_filename)
106
+
107
 
108
  # Encode the output image in base64
109
  with open(local_output_path, "rb") as image_file:
110
  encoded_image = base64.b64encode(image_file.read()).decode('utf-8')
111
 
112
  # Return the output image in JSON format
113
+
114
  return jsonify(image=encoded_image), 200
115
 
116
  except Exception as e:
117
+
118
  traceback.print_exc()
119
  return jsonify(error=str(e)), 500
120
 
121
  @app.route('/uploads/<filename>')
122
+ def uploaded_file(filename):
 
123
 
124
+ return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
 
 
 
 
125
 
126
  if __name__ == '__main__':
127
+ app.run(host='0.0.0.0', port=7860)