# Open Text Embeddings API server (FastAPI).
import os
from typing import List, Optional, Union

from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
# Shared router; mounted onto the application in create_app().
router = APIRouter()
# Intended default embedding model (not referenced in the visible code —
# presumably meant as the fallback for the MODEL env var; TODO confirm).
DEFAULT_MODEL_NAME = "intfloat/e5-large-v2"
# E5-family models expect these literal prefixes on passages vs. queries.
E5_EMBED_INSTRUCTION = "passage: "
E5_QUERY_INSTRUCTION = "query: "
# BGE models prepend a retrieval instruction to queries only; the English
# and Chinese variants use different instruction text.
BGE_EN_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
BGE_ZH_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"
def create_app():
    """Construct the FastAPI application.

    Installs a fully-permissive CORS policy and mounts the module-level
    ``router`` carrying the embeddings endpoints.
    """
    application = FastAPI(
        title="Open Text Embeddings API",
        version="0.0.2",
    )
    # Allow any origin, method, and header so browsers anywhere may call us.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    application.include_router(router)
    return application
class CreateEmbeddingRequest(BaseModel):
    """OpenAI-compatible request body for the embeddings endpoint."""
    # Model identifier. "text-embedding-ada-002" is treated as "use the
    # server-configured model" (see _create_embedding).
    model: Optional[str] = Field(
        description="The model to use for generating embeddings.")
    # Either one string or a batch of strings to embed.
    input: Union[str, List[str]] = Field(description="The input to embed.")
    # Accepted for OpenAI API compatibility; not read by the visible code.
    user: Optional[str]
    class Config:
        # Example rendered in the generated OpenAPI docs.
        schema_extra = {
            "example": {
                "input": "The food was delicious and the waiter...",
            }
        }
class Embedding(BaseModel):
    """A single embedding vector for one input string."""
    embedding: List[float]
class CreateEmbeddingResponse(BaseModel):
    """Response body: one Embedding per input, in input order."""
    data: List[Embedding]
# Lazily-initialized embeddings backend; loaded on the first request and
# reused for every request thereafter.
embeddings = None

def _create_embedding(
    request: CreateEmbeddingRequest
):
    """Embed the request input, loading the model lazily on first call.

    The backend is cached in the module-level ``embeddings`` global, so the
    model resolved for the *first* request serves all subsequent requests.

    Returns a CreateEmbeddingResponse with one Embedding per input string.
    """
    global embeddings
    if embeddings is None:
        # "text-embedding-ada-002" is the default model name many OpenAI
        # clients send; treat it as "use the server-configured model".
        if request.model and request.model != "text-embedding-ada-002":
            model_name = request.model
        else:
            # Fall back to DEFAULT_MODEL_NAME instead of raising KeyError
            # when the MODEL env var is not set.
            model_name = os.environ.get("MODEL", DEFAULT_MODEL_NAME)
        print("Loading model:", model_name)
        # bool() on a non-empty string is always True (bool("false") is True),
        # so parse the env var explicitly: only common truthy spellings
        # enable normalization.
        normalize = os.environ.get("NORMALIZE_EMBEDDINGS", "").strip().lower() in (
            "1", "true", "yes", "on")
        encode_kwargs = {
            "normalize_embeddings": normalize
        }
        print("encode_kwargs", encode_kwargs)
        if "e5" in model_name:
            # E5 models need distinct passage/query instruction prefixes.
            embeddings = HuggingFaceInstructEmbeddings(model_name=model_name,
                                                       embed_instruction=E5_EMBED_INSTRUCTION,
                                                       query_instruction=E5_QUERY_INSTRUCTION,
                                                       encode_kwargs=encode_kwargs)
        elif model_name.startswith("BAAI/bge-") and model_name.endswith("-en"):
            # English BGE: instruction is applied to queries only.
            embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
                                                  query_instruction=BGE_EN_QUERY_INSTRUCTION,
                                                  encode_kwargs=encode_kwargs)
        elif model_name.startswith("BAAI/bge-") and model_name.endswith("-zh"):
            # Chinese BGE: same scheme, Chinese instruction text.
            embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
                                                  query_instruction=BGE_ZH_QUERY_INSTRUCTION,
                                                  encode_kwargs=encode_kwargs)
        else:
            embeddings = HuggingFaceEmbeddings(
                model_name=model_name, encode_kwargs=encode_kwargs)
    if isinstance(request.input, str):
        # Single string: embed as a query (applies the query instruction).
        return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(request.input))])
    else:
        # Batch: embed as documents/passages.
        data = [Embedding(embedding=embedding)
                for embedding in embeddings.embed_documents(request.input)]
        return CreateEmbeddingResponse(data=data)
async def create_embedding(
    request: CreateEmbeddingRequest
):
    """Async endpoint wrapper around the blocking embedding work.

    Model loading and inference in _create_embedding are CPU/IO-blocking, so
    run them in Starlette's threadpool instead of on the event loop.

    NOTE: run_in_threadpool takes the *callable* and its arguments separately;
    the earlier attempt passed the already-computed result
    (_create_embedding(request)) and then awaited a CreateEmbeddingResponse,
    which raised "TypeError: 'CreateEmbeddingResponse' object is not callable".
    """
    return await run_in_threadpool(_create_embedding, request)