wearon / main.py
Bhushan26's picture
Update main.py
03d0a1d verified
raw
history blame contribute delete
No virus
3.26 kB
import base64
import os
import shutil
import traceback
import uuid

import aiofiles
from fastapi import FastAPI, UploadFile, Form, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from gradio_client import Client, file
# ASGI application object that every route below is registered on.
app = FastAPI()

# CORS: every origin, method and header is accepted, with credentials enabled.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Gradio client bound to "yisol/IDM-VTON"; it is created once at import time
# and reused by the /process endpoint.
client = Client("yisol/IDM-VTON")
# Alternative source kept for reference: Client("kadirnar/IDM-VTON")

# Working directories for incoming uploads and for copies of model outputs.
UPLOAD_FOLDER = "static/uploads"
RESULT_FOLDER = "static/results"
for _directory in (UPLOAD_FOLDER, RESULT_FOLDER):
    os.makedirs(_directory, exist_ok=True)
@app.post("/")
async def hello():
    """Return a fixed payload confirming that the service is running."""
    status_payload = {"Wearon": "wearon model is running"}
    return status_payload
def _safe_extension(original_name):
    """Return only the extension (e.g. ".png") of a client-supplied file name.

    Directory components are dropped, so the result can never contain a path
    separator. A missing name yields an empty string.
    """
    # Treat backslashes as separators too, so Windows-style names are stripped
    # the same way regardless of the host OS.
    base_name = os.path.basename((original_name or "").replace("\\", "/"))
    return os.path.splitext(base_name)[1]


@app.post("/process")
async def predict(product_image_url: str = Form(...), model_image: UploadFile = File(...)):
    """Run the "/tryon" prediction for an uploaded model image and a garment URL.

    The upload is written to UPLOAD_FOLDER under a server-generated name,
    passed to the remote client together with ``product_image_url``, and the
    first output image is copied into RESULT_FOLDER and returned base64-encoded.

    Args:
        product_image_url: http(s) URL of the garment image.
        model_image: Uploaded image of the person.

    Returns:
        JSONResponse with body ``{"image": <base64 string>}`` and status 200.

    Raises:
        HTTPException: 400 when the upload is missing or the URL is not
            http(s); 500 (with the error text as detail) when saving,
            prediction or result handling fails.
    """
    # Input validation happens outside the try block so that these 400s are
    # not converted into 500s by the generic handler below.
    if not model_image:
        raise HTTPException(status_code=400, detail="No model image file provided")
    if not product_image_url.lower().startswith(("http://", "https://")):
        # The value is handed to gradio_client's file(), which also accepts
        # local paths; restrict it to remote URLs so a request cannot point
        # at files on this server.
        raise HTTPException(status_code=400, detail="product_image_url must be an http(s) URL")

    # A per-request id keeps concurrent requests from overwriting each other's
    # files and removes any influence of the client-supplied file name on the
    # path that is written to.
    request_id = uuid.uuid4().hex
    upload_path = os.path.join(UPLOAD_FOLDER, request_id + _safe_extension(model_image.filename))

    try:
        # Save the uploaded file to the upload directory
        async with aiofiles.open(upload_path, "wb") as buffer:
            await buffer.write(await model_image.read())

        full_filename = os.path.normpath(os.path.join(os.getcwd(), upload_path))
        print("Product image =", product_image_url)
        print("Model image =", full_filename)

        # client.predict is a blocking call; run it in a worker thread so the
        # event loop keeps serving other requests while the model runs.
        result = await run_in_threadpool(
            client.predict,
            dict={"background": file(full_filename), "layers": [], "composite": None},
            garm_img=file(product_image_url),
            garment_des="Hello!!",
            is_checked=True,
            is_checked_crop=False,
            denoise_steps=30,
            seed=42,
            api_name="/tryon",
        )
        print(result)

        # The first element of the result is the path of the output image.
        output_image_path = result[0]
        local_output_path = os.path.join(
            RESULT_FOLDER, f"{request_id}_{os.path.basename(output_image_path)}"
        )
        shutil.copy(output_image_path, local_output_path)

        # Encode the output image in base64
        async with aiofiles.open(local_output_path, "rb") as image_file:
            encoded_image = base64.b64encode(await image_file.read()).decode("utf-8")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the uploaded file whether or not processing succeeded.
        try:
            os.remove(upload_path)
        except FileNotFoundError:
            pass

    # Return the output image in JSON format
    return JSONResponse(content={"image": encoded_image}, status_code=200)
@app.get("/uploads/{filename}")
async def uploaded_file(filename: str):
    """Serve a single file stored directly under UPLOAD_FOLDER.

    Args:
        filename: Name of the file, taken from the URL path.

    Returns:
        FileResponse for the requested file.

    Raises:
        HTTPException: 404 when the name does not resolve to a regular file
            inside UPLOAD_FOLDER.
    """
    upload_root = os.path.realpath(UPLOAD_FOLDER)
    candidate = os.path.realpath(os.path.join(upload_root, filename))

    # Reject names such as ".." (or backslash sequences on Windows) that
    # resolve outside the upload directory.
    inside_root = candidate.startswith(upload_root + os.sep)

    # isfile (not exists) so that a directory is reported as 404 instead of
    # being handed to FileResponse.
    if not inside_root or not os.path.isfile(candidate):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(candidate)