# Spaces:
# Runtime error
# Runtime error
import asyncio | |
import base64 | |
import logging | |
from io import BytesIO | |
from pathlib import Path | |
import uvicorn | |
from config import Config | |
from fastapi import FastAPI | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from PIL import Image | |
from pydantic import BaseModel | |
from wrapper import StreamDiffusionWrapper | |
# Reuse the "uvicorn" logger so messages emitted here go through the same
# logger object that the server itself uses.
logger = logging.getLogger("uvicorn")

# Two directory levels above this file.
PROJECT_DIR = Path(__file__).parent.parent
class PredictInputModel(BaseModel):
    """
    The request body accepted by the /api/predict endpoint.
    """

    # Text prompt to generate images for.
    prompt: str
class PredictResponseModel(BaseModel):
    """
    The response body returned by the /api/predict endpoint.
    """

    # One base64-encoded image per list entry.
    base64_images: list[str]
class UpdatePromptResponseModel(BaseModel):
    """
    The response model for the /update_prompt endpoint.

    NOTE(review): no /update_prompt route is registered in the code visible
    in this file -- confirm whether this model is still needed.
    """

    # The prompt carried in the response.
    prompt: str
class Api:
    """
    HTTP front end around a ``StreamDiffusionWrapper``.

    Registers ``POST /api/predict`` on a FastAPI application, enables CORS,
    and serves the static files under ``../view/build`` at ``/``. The most
    recent prompt and its images are kept so that a sufficiently similar
    prompt can be answered without running the model again.
    """

    def __init__(self, config: Config) -> None:
        """
        Initialize the API.

        Parameters
        ----------
        config : Config
            The configuration.
        """
        self.config = config
        self.stream_diffusion = StreamDiffusionWrapper(
            model_id=config.model_id,
            lcm_lora_id=config.lcm_lora_id,
            vae_id=config.vae_id,
            device=config.device,
            dtype=config.dtype,
            t_index_list=config.t_index_list,
            warmup=config.warmup,
            safety_checker=config.safety_checker,
        )

        self.app = FastAPI()
        # The API route is registered before the catch-all static mount at
        # "/" below, so "/api/predict" is matched first.
        self.app.add_api_route(
            "/api/predict",
            self._predict,
            methods=["POST"],
            response_model=PredictResponseModel,
        )
        # NOTE(review): wildcard origins combined with allow_credentials=True
        # is very permissive -- confirm this is intended outside local demos.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        # The directory is relative to the process working directory.
        self.app.mount(
            "/", StaticFiles(directory="../view/build", html=True), name="public"
        )

        self._predict_lock = asyncio.Lock()
        self._update_prompt_lock = asyncio.Lock()

        # Cache of the last generation. ``last_images`` starts as a
        # placeholder holding a single empty string.
        self.last_prompt: str = ""
        self.last_images: list[str] = [""]

    async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
        """
        Predict an image and return.

        Requests are serialized by ``_predict_lock``. When images have
        already been generated and the new prompt is closer to the previous
        one than ``config.levenstein_distance_threshold``, the previous
        images are returned instead of running the model.

        Parameters
        ----------
        inp : PredictInputModel
            The input.

        Returns
        -------
        PredictResponseModel
            The prediction result.
        """
        async with self._predict_lock:
            # The initial placeholder [""] is not a real result: without this
            # guard a short first prompt (distance to "" below the threshold)
            # would be answered with an empty image.
            has_cached_images = any(self.last_images)
            if has_cached_images and (
                self._calc_levenstein_distance(inp.prompt, self.last_prompt)
                < self.config.levenstein_distance_threshold
            ):
                logger.info("Using cached images")
                return PredictResponseModel(base64_images=self.last_images)

            # NOTE(review): stream_diffusion is called synchronously inside
            # this coroutine; if it blocks, the event loop is stalled for the
            # duration of the call -- confirm this is acceptable.
            images = [
                self._pil_to_base64(image)
                for image in self.stream_diffusion(inp.prompt)
            ]

            # Commit prompt and images together, and only after generation
            # succeeded, so a failure never leaves the cache keyed to a
            # prompt whose images were not produced.
            self.last_prompt = inp.prompt
            self.last_images = images
            return PredictResponseModel(base64_images=images)

    def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
        """
        Convert a PIL image to base64.

        Parameters
        ----------
        image : Image.Image
            The PIL image.
        format : str
            The image format, by default "JPEG".

        Returns
        -------
        str
            The base64-encoded image as ASCII text.
        """
        buffered = BytesIO()
        # Convert to RGB before encoding in the requested format.
        image.convert("RGB").save(buffered, format=format)
        return base64.b64encode(buffered.getvalue()).decode("ascii")

    def _base64_to_pil(self, base64_image: str) -> Image.Image:
        """
        Convert a base64 image to PIL.

        Parameters
        ----------
        base64_image : str
            The base64 image. Text up to and including a "base64," marker
            is stripped before decoding.

        Returns
        -------
        Image.Image
            The PIL image, converted to RGB.
        """
        if "base64," in base64_image:
            base64_image = base64_image.split("base64,")[1]
        return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")

    def _calc_levenstein_distance(self, a: str, b: str) -> int:
        """
        Calculate the Levenstein distance between two strings.

        Parameters
        ----------
        a : str
            The first string.
        b : str
            The second string.

        Returns
        -------
        int
            The Levenstein distance.
        """
        if a == b:
            return 0
        if not a:
            return len(b)
        if not b:
            return len(a)

        # Only the previous row of the DP table is needed to compute the
        # current one, so keep two rows instead of the full matrix.
        previous_row = list(range(len(b) + 1))
        for i, a_char in enumerate(a, start=1):
            current_row = [i]
            for j, b_char in enumerate(b, start=1):
                cost = 0 if a_char == b_char else 1
                current_row.append(
                    min(
                        previous_row[j] + 1,  # deletion
                        current_row[j - 1] + 1,  # insertion
                        previous_row[j - 1] + cost,  # substitution
                    )
                )
            previous_row = current_row
        return previous_row[-1]
# Script entry point: build the configuration and serve the API with uvicorn.
if __name__ == "__main__":
    from config import Config

    config = Config()

    # NOTE(review): the application is passed as an instance. uvicorn
    # documents that an import string is required when more than one worker
    # is requested -- confirm config.workers is 1 or switch to an import
    # string.
    uvicorn.run(
        Api(config).app,
        host=config.host,
        port=config.port,
        workers=config.workers,
    )