radames commited on
Commit
9bba95c
1 Parent(s): 96b49cf

lock with async wait

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -17,6 +17,8 @@ import uuid
17
  import logging
18
  from fastapi import FastAPI, Request, HTTPException
19
  from fastapi.middleware.cors import CORSMiddleware
 
 
20
 
21
  logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
22
 
@@ -24,7 +26,7 @@ MAX_SEED = np.iinfo(np.int32).max
24
  USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
25
  SPACE_ID = os.environ.get("SPACE_ID", "")
26
  DEV = os.environ.get("DEV", "0") == "1"
27
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
28
 
29
  DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
30
  IMGS_PATH = DB_PATH / "imgs"
@@ -32,6 +34,7 @@ DB_PATH.mkdir(exist_ok=True, parents=True)
32
  IMGS_PATH.mkdir(exist_ok=True, parents=True)
33
 
34
  database = Database(DB_PATH)
 
35
 
36
  dtype = torch.bfloat16
37
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -120,21 +123,25 @@ app.add_middleware(
120
 
121
 
122
  @app.get("/image")
123
- async def generate_image(prompt: str, negative_prompt: str = "", seed: int = 2134213213):
 
 
124
  cached_img = database.check(prompt, negative_prompt, seed)
125
  if cached_img:
126
  logging.info(f"Image found in cache: {cached_img[0]}")
127
  return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
128
 
129
  logging.info(f"Image not found in cache, generating new image")
130
- pil_image = generate(prompt, negative_prompt, seed)
131
- img_id = str(uuid.uuid4())
132
- img_path = IMGS_PATH / f"{img_id}.jpg"
133
- pil_image.save(img_path)
134
- img_io = io.BytesIO()
135
- pil_image.save(img_io, "JPEG")
136
- img_io.seek(0)
137
- database.insert(prompt, negative_prompt, str(img_path), seed)
 
 
138
 
139
  return StreamingResponse(img_io, media_type="image/jpeg")
140
 
 
17
  import logging
18
  from fastapi import FastAPI, Request, HTTPException
19
  from fastapi.middleware.cors import CORSMiddleware
20
+ from asyncio import Lock
21
+
22
 
23
  logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
24
 
 
26
  USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
27
  SPACE_ID = os.environ.get("SPACE_ID", "")
28
  DEV = os.environ.get("DEV", "0") == "1"
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
30
 
31
  DB_PATH = Path("/data/cache") if SPACE_ID else Path("./cache")
32
  IMGS_PATH = DB_PATH / "imgs"
 
34
  IMGS_PATH.mkdir(exist_ok=True, parents=True)
35
 
36
  database = Database(DB_PATH)
37
+ generate_lock = Lock()
38
 
39
  dtype = torch.bfloat16
40
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
123
 
124
 
125
  @app.get("/image")
126
+ async def generate_image(
127
+ prompt: str, negative_prompt: str = "", seed: int = 2134213213
128
+ ):
129
  cached_img = database.check(prompt, negative_prompt, seed)
130
  if cached_img:
131
  logging.info(f"Image found in cache: {cached_img[0]}")
132
  return StreamingResponse(open(cached_img[0], "rb"), media_type="image/jpeg")
133
 
134
  logging.info(f"Image not found in cache, generating new image")
135
+ async with generate_lock:
136
+
137
+ pil_image = generate(prompt, negative_prompt, seed)
138
+ img_id = str(uuid.uuid4())
139
+ img_path = IMGS_PATH / f"{img_id}.jpg"
140
+ pil_image.save(img_path)
141
+ img_io = io.BytesIO()
142
+ pil_image.save(img_io, "JPEG")
143
+ img_io.seek(0)
144
+ database.insert(prompt, negative_prompt, str(img_path), seed)
145
 
146
  return StreamingResponse(img_io, media_type="image/jpeg")
147