Spaces:
Running
on
A10G
Running
on
A10G
lock with async wait
Browse files
app.py
CHANGED
@@ -17,6 +17,8 @@ import uuid
|
|
17 |
import logging
|
18 |
from fastapi import FastAPI, Request, HTTPException
|
19 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
20 |
|
21 |
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
|
22 |
|
@@ -24,7 +26,7 @@ MAX_SEED = np.iinfo(np.int32).max
|
|
24 |
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
|
25 |
SPACE_ID = os.environ.get("SPACE_ID", "")
|
26 |
DEV = os.environ.get("DEV", "0") == "1"
|
27 |
-
os.environ[
|
28 |
|
29 |
DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
|
30 |
IMGS_PATH = DB_PATH / "imgs"
|
@@ -32,6 +34,7 @@ DB_PATH.mkdir(exist_ok=True, parents=True)
|
|
32 |
IMGS_PATH.mkdir(exist_ok=True, parents=True)
|
33 |
|
34 |
database = Database(DB_PATH)
|
|
|
35 |
|
36 |
dtype = torch.bfloat16
|
37 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
@@ -120,21 +123,25 @@ app.add_middleware(
|
|
120 |
|
121 |
|
122 |
@app.get("/image")
|
123 |
-
async def generate_image(
|
|
|
|
|
124 |
cached_img = database.check(prompt, negative_prompt, seed)
|
125 |
if cached_img:
|
126 |
logging.info(f"Image found in cache: {cached_img[0]}")
|
127 |
return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
|
128 |
|
129 |
logging.info(f"Image not found in cache, generating new image")
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
|
|
|
|
138 |
|
139 |
return StreamingResponse(img_io, media_type="image/jpeg")
|
140 |
|
|
|
17 |
import logging
|
18 |
from fastapi import FastAPI, Request, HTTPException
|
19 |
from fastapi.middleware.cors import CORSMiddleware
|
20 |
+
from asyncio import Lock
|
21 |
+
|
22 |
|
23 |
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
|
24 |
|
|
|
26 |
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
|
27 |
SPACE_ID = os.environ.get("SPACE_ID", "")
|
28 |
DEV = os.environ.get("DEV", "0") == "1"
|
29 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
30 |
|
31 |
DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
|
32 |
IMGS_PATH = DB_PATH / "imgs"
|
|
|
34 |
IMGS_PATH.mkdir(exist_ok=True, parents=True)
|
35 |
|
36 |
database = Database(DB_PATH)
|
37 |
+
generate_lock = Lock()
|
38 |
|
39 |
dtype = torch.bfloat16
|
40 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
123 |
|
124 |
|
125 |
@app.get("/image")
|
126 |
+
async def generate_image(
|
127 |
+
prompt: str, negative_prompt: str = "", seed: int = 2134213213
|
128 |
+
):
|
129 |
cached_img = database.check(prompt, negative_prompt, seed)
|
130 |
if cached_img:
|
131 |
logging.info(f"Image found in cache: {cached_img[0]}")
|
132 |
return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
|
133 |
|
134 |
logging.info(f"Image not found in cache, generating new image")
|
135 |
+
async with generate_lock:
|
136 |
+
|
137 |
+
pil_image = generate(prompt, negative_prompt, seed)
|
138 |
+
img_id = str(uuid.uuid4())
|
139 |
+
img_path = IMGS_PATH / f"{img_id}.jpg"
|
140 |
+
pil_image.save(img_path)
|
141 |
+
img_io = io.BytesIO()
|
142 |
+
pil_image.save(img_io, "JPEG")
|
143 |
+
img_io.seek(0)
|
144 |
+
database.insert(prompt, negative_prompt, str(img_path), seed)
|
145 |
|
146 |
return StreamingResponse(img_io, media_type="image/jpeg")
|
147 |
|