File size: 3,940 Bytes
b92d070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

import os
import threading
from typing import List, Optional, Union

from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool

# Router that collects this module's endpoints; create_app() mounts it on the
# FastAPI application.
router = APIRouter()

# Default embedding model name.
DEFAULT_MODEL_NAME = "intfloat/e5-large-v2"
# Instruction prefixes handed to the embeddings wrapper for models whose name
# contains "e5": one for documents ("passage: "), one for queries ("query: ").
E5_EMBED_INSTRUCTION = "passage: "
E5_QUERY_INSTRUCTION = "query: "
# Query instruction used for "BAAI/bge-*-en" models.
BGE_EN_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
# Query instruction used for "BAAI/bge-*-zh" models (Chinese; roughly "Generate
# a representation for this sentence for retrieving relevant articles:").
BGE_ZH_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"


def create_app():
    """Build and return the FastAPI application for the embeddings service.

    The application is titled "Open Text Embeddings API", gets a CORS
    middleware that accepts every origin, method and header (with
    credentials allowed), and serves the endpoints registered on the
    module-level ``router``.
    """
    # NOTE(review): wildcard origins together with allow_credentials=True is a
    # very permissive CORS policy -- confirm this is intended for deployment.
    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    application = FastAPI(title="Open Text Embeddings API", version="0.0.2")
    application.add_middleware(CORSMiddleware, **cors_options)
    application.include_router(router)
    return application


class CreateEmbeddingRequest(BaseModel):
    """Request body for ``POST /v1/embeddings``.

    Only ``input`` is mandatory; it is either a single string or a list of
    strings. ``model`` and ``user`` may be omitted.
    """

    # Explicit ``None`` defaults keep ``model`` and ``user`` optional on both
    # pydantic 1.x (where a bare ``Optional`` field is implicitly optional)
    # and pydantic 2.x (where a field without a default is required).
    model: Optional[str] = Field(
        default=None,
        description="The model to use for generating embeddings.")
    input: Union[str, List[str]] = Field(description="The input to embed.")
    user: Optional[str] = None

    class Config:
        # Example payload shown in the generated API schema.
        schema_extra = {
            "example": {
                "input": "The food was delicious and the waiter...",
            }
        }


class Embedding(BaseModel):
    """A single embedding vector, stored as a list of floats."""
    embedding: List[float]


class CreateEmbeddingResponse(BaseModel):
    """Response body of ``POST /v1/embeddings``: the embeddings produced for the request."""
    data: List[Embedding]


# Module-wide cache for the embedding model object. It starts empty, is filled
# by _create_embedding on first use, and is reused for every later call.
embeddings = None


# Serializes the one-time model load so that concurrent first calls (for
# example from worker threads) do not each load the model.
_EMBEDDINGS_LOCK = threading.Lock()

# Environment-variable values (compared case-insensitively, after stripping
# whitespace) that switch a flag off. Anything else switches it on.
_FALSE_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name):
    """Return True unless environment variable *name* is unset or holds a false-like value.

    A plain ``bool(value)`` would treat "0" or "false" as enabled because any
    non-empty string is truthy; those spellings are recognised as disabled here.
    """
    return os.environ.get(name, "").strip().lower() not in _FALSE_ENV_VALUES


def _load_embeddings(model_name):
    """Instantiate the embeddings wrapper that matches *model_name*.

    Names containing "e5" and "BAAI/bge-*" names ending in "-en" or "-zh" get
    their model-specific instruction prefixes; every other name uses the plain
    HuggingFace wrapper.
    """
    print("Loading model:", model_name)
    encode_kwargs = {
        "normalize_embeddings": _env_flag("NORMALIZE_EMBEDDINGS")
    }
    print("encode_kwargs", encode_kwargs)

    if "e5" in model_name:
        return HuggingFaceInstructEmbeddings(model_name=model_name,
                                             embed_instruction=E5_EMBED_INSTRUCTION,
                                             query_instruction=E5_QUERY_INSTRUCTION,
                                             encode_kwargs=encode_kwargs)
    if model_name.startswith("BAAI/bge-") and model_name.endswith("-en"):
        return HuggingFaceBgeEmbeddings(model_name=model_name,
                                        query_instruction=BGE_EN_QUERY_INSTRUCTION,
                                        encode_kwargs=encode_kwargs)
    if model_name.startswith("BAAI/bge-") and model_name.endswith("-zh"):
        return HuggingFaceBgeEmbeddings(model_name=model_name,
                                        query_instruction=BGE_ZH_QUERY_INSTRUCTION,
                                        encode_kwargs=encode_kwargs)
    return HuggingFaceEmbeddings(
        model_name=model_name, encode_kwargs=encode_kwargs)


def _create_embedding(
    request: CreateEmbeddingRequest
):
    """Compute embeddings for *request* and wrap them in a response model.

    The model is loaded lazily on the first call and cached in the module-level
    ``embeddings`` variable. The model name comes from ``request.model`` unless
    that is empty or equal to "text-embedding-ada-002"; in that case the
    ``MODEL`` environment variable is used, falling back to
    ``DEFAULT_MODEL_NAME`` when it is unset or empty. Once a model has been
    loaded, ``request.model`` on later calls is not consulted.

    Args:
        request: The parsed request; ``input`` is a string or a list of strings.

    Returns:
        A ``CreateEmbeddingResponse`` with one ``Embedding`` for a string input
        (via ``embed_query``), or one per vector returned by
        ``embed_documents`` for a list input.
    """
    global embeddings

    if embeddings is None:
        with _EMBEDDINGS_LOCK:
            # Re-check under the lock: another thread may have finished the
            # load while this one was waiting.
            if embeddings is None:
                if request.model and request.model != "text-embedding-ada-002":
                    model_name = request.model
                else:
                    model_name = os.environ.get("MODEL") or DEFAULT_MODEL_NAME
                embeddings = _load_embeddings(model_name)

    if isinstance(request.input, str):
        return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(request.input))])

    data = [Embedding(embedding=embedding)
            for embedding in embeddings.embed_documents(request.input)]
    return CreateEmbeddingResponse(data=data)


@router.post(
    "/v1/embeddings",
    response_model=CreateEmbeddingResponse,
)
async def create_embedding(
    request: CreateEmbeddingRequest
):
    """Handle ``POST /v1/embeddings`` by delegating to ``_create_embedding``.

    The synchronous ``_create_embedding`` call is run in a worker thread so it
    does not block the event loop while this request is being served.

    Args:
        request: The parsed request body.

    Returns:
        The ``CreateEmbeddingResponse`` produced by ``_create_embedding``.
    """
    # run_in_threadpool(func, *args) needs the callable and its arguments
    # passed separately. Passing ``_create_embedding(request)`` would run the
    # work inline and then try to call the returned response object.
    return await run_in_threadpool(_create_embedding, request)