File size: 3,264 Bytes
cf1378b
b89e93c
cf1378b
46d3354
b89e93c
46d3354
b4cb033
 
cf1378b
 
 
d7228b2
cf1378b
 
 
 
 
 
 
 
d7228b2
03d0a1d
 
d7228b2
cf1378b
46d3354
 
ec90435
 
b4cb033
cf1378b
 
b89e93c
cf1378b
 
 
21d171d
cf1378b
 
ec90435
cf1378b
 
b89e93c
 
 
cf1378b
46d3354
 
cf1378b
b89e93c
 
ec90435
cf1378b
 
b89e93c
cf1378b
 
 
 
 
 
 
 
 
 
 
 
 
46d3354
cf1378b
b4cb033
cf1378b
 
b4cb033
cf1378b
ec90435
cf1378b
 
46d3354
cf1378b
 
b89e93c
 
cf1378b
 
 
 
21d171d
b4cb033
cf1378b
21d171d
cf1378b
 
 
 
 
 
b89e93c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from fastapi import FastAPI, UploadFile, Form, File, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from gradio_client import Client, file
import aiofiles
import os
import shutil
import base64
import traceback

# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Fully permissive CORS policy: every origin, method and header is accepted,
# with credentials enabled.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Gradio client for the try-on Space, created once when the module is imported.
# Alternative Space kept for reference: "kadirnar/IDM-VTON".
client = Client("yisol/IDM-VTON")

# Working directories: incoming model images and copied prediction results.
UPLOAD_FOLDER = 'static/uploads'
RESULT_FOLDER = 'static/results'
for _folder in (UPLOAD_FOLDER, RESULT_FOLDER):
    os.makedirs(_folder, exist_ok=True)

@app.post("/")
async def hello():
    """Return a fixed status payload showing that the service is running."""
    status = {"Wearon": "wearon model is running"}
    return status

def _safe_upload_name(raw_name):
    """Return the last path component of a client-supplied file name, or ''."""
    # Treat backslashes as separators too, so "..\\..\\x" is reduced to "x"
    # regardless of the platform the server runs on.
    name = os.path.basename((raw_name or "").replace("\\", "/"))
    return "" if name in {".", ".."} else name


@app.post("/process")
async def predict(product_image_url: str = Form(...), model_image: UploadFile = File(...)):
    """Run the virtual try-on model and return the result as a base64 image.

    Args:
        product_image_url: http(s) URL of the garment image, sent as a form field.
        model_image: Uploaded photo of the person, sent as a multipart file.

    Returns:
        A 200 ``JSONResponse`` of the form ``{"image": "<base64 string>"}``.

    Raises:
        HTTPException: 400 when the upload has no usable file name or the
            garment URL is not http(s); 500 when saving, predicting or
            reading the result fails.
    """
    # Validate before entering the try block: previously the 400 raised here
    # was caught by the generic handler below and re-reported as a 500.
    upload_name = _safe_upload_name(model_image.filename if model_image else None)
    if not upload_name:
        raise HTTPException(status_code=400, detail="No model image file provided")

    # gradio_client's file() accepts local paths as well as URLs; only allow
    # real URLs so a caller cannot make the server upload its own files.
    if not product_image_url.lower().startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="product_image_url must be an http(s) URL")

    # Only the sanitised base name is joined, so the write stays inside UPLOAD_FOLDER.
    # NOTE(review): concurrent requests using the same file name share this path.
    filename = os.path.join(UPLOAD_FOLDER, upload_name)

    try:
        # Save the uploaded file to the upload directory
        async with aiofiles.open(filename, "wb") as buffer:
            content = await model_image.read()
            await buffer.write(content)

        full_filename = os.path.normpath(os.path.join(os.getcwd(), filename))

        print("Product image =", product_image_url)
        print("Model image =", full_filename)

        # Perform prediction.
        # NOTE(review): this is a synchronous call inside an async handler and
        # appears to block the event loop while it runs; consider a thread pool.
        result = client.predict(
            dict={"background": file(full_filename), "layers": [], "composite": None},
            garm_img=file(product_image_url),
            garment_des="Hello!!",
            is_checked=True,
            is_checked_crop=False,
            denoise_steps=30,
            seed=42,
            api_name="/tryon"
        )

        print(result)
        # Extract the path of the first output image
        output_image_path = result[0]

        # Copy the output image to the RESULT_FOLDER
        output_image_filename = os.path.basename(output_image_path)
        local_output_path = os.path.join(RESULT_FOLDER, output_image_filename)
        shutil.copy(output_image_path, local_output_path)

        # Encode the output image in base64
        async with aiofiles.open(local_output_path, "rb") as image_file:
            encoded_image = base64.b64encode(await image_file.read()).decode('utf-8')

        # Return the output image in JSON format
        return JSONResponse(content={"image": encoded_image}, status_code=200)

    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewriting them as 500s.
        raise
    except Exception as e:
        # Single logging point; the traceback is no longer printed twice.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the uploaded file on success and on failure alike.
        try:
            os.remove(filename)
        except FileNotFoundError:
            pass  # the upload was never written
        except OSError:
            # Cleanup is best effort and must not replace the real response.
            traceback.print_exc()

@app.get("/uploads/{filename}")
async def uploaded_file(filename: str):
    """Serve a file from the upload folder by its bare file name.

    Args:
        filename: Name of the file inside ``UPLOAD_FOLDER``.

    Returns:
        A ``FileResponse`` for the requested file.

    Raises:
        HTTPException: 404 when the name is not a plain file name or no
            regular file with that name exists.
    """
    # Accept only a bare file name: reject ".", ".." and anything containing a
    # separator (backslashes included, which act as separators on Windows).
    bare_name = os.path.basename(filename.replace("\\", "/"))
    if filename != bare_name or filename in {".", ".."}:
        raise HTTPException(status_code=404, detail="File not found")

    file_path = os.path.join(UPLOAD_FOLDER, filename)
    # isfile rather than exists: a directory would make FileResponse fail with a 500.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path)