RoniFinTech commited on
Commit
3e81177
1 Parent(s): afc8995
cache/local_cache.py CHANGED
@@ -1,5 +1,8 @@
1
  from datetime import datetime, timedelta
2
  from functools import wraps
 
 
 
3
 
4
  CACHE_SIZE = 50
5
 
@@ -7,8 +10,7 @@ _cache = {}
7
  _cache_time = {}
8
 
9
 
10
- def ttl_cache(key_name, ttl_secs=20):
11
-
12
  def decorator(func):
13
  @wraps(func)
14
  async def wrapper(*args, **kwargs):
@@ -22,7 +24,10 @@ def ttl_cache(key_name, ttl_secs=20):
22
  del _cache[key]
23
  del _cache_time[key]
24
  else:
25
- return _cache[key]
 
 
 
26
 
27
  # Call the actual function if not in cache or expired
28
  response = await func(*args, **kwargs)
 
1
  from datetime import datetime, timedelta
2
  from functools import wraps
3
+ from io import BytesIO
4
+
5
+ from fastapi.responses import StreamingResponse
6
 
7
  CACHE_SIZE = 50
8
 
 
10
  _cache_time = {}
11
 
12
 
13
+ def ttl_cache(key_name, media_type=None, ttl_secs=20):
 
14
  def decorator(func):
15
  @wraps(func)
16
  async def wrapper(*args, **kwargs):
 
24
  del _cache[key]
25
  del _cache_time[key]
26
  else:
27
+ if media_type == 'image/png':
28
+ return StreamingResponse(BytesIO(_cache[key]), media_type=media_type)
29
+ else:
30
+ return _cache[key]
31
 
32
  # Call the actual function if not in cache or expired
33
  response = await func(*args, **kwargs)
routers/intference/stable_diffusion.py CHANGED
@@ -32,7 +32,7 @@ refiner.enable_attention_slicing()
32
 
33
 
34
  @router.get("/generate")
35
- @ttl_cache(key_name='prompt', ttl_secs=20)
36
  async def generate(prompt: str):
37
  """
38
  generate image
 
32
 
33
 
34
  @router.get("/generate")
35
+ @ttl_cache(key_name='prompt', media_type="image/png", ttl_secs=20)
36
  async def generate(prompt: str):
37
  """
38
  generate image