from typing import List, Optional, Union
from starlette.concurrency import run_in_threadpool
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from langchain.embeddings import (
    HuggingFaceEmbeddings,
    HuggingFaceInstructEmbeddings,
    HuggingFaceBgeEmbeddings,
)
import os

router = APIRouter()

DEFAULT_MODEL_NAME = "intfloat/e5-large-v2"
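# Instruction prefixes expected by E5 and BGE models; they are prepended to
# passages and queries so inputs match the format the models were trained on.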
E5_EMBED_INSTRUCTION = "passage: "
E5_QUERY_INSTRUCTION = "query: "
BGE_EN_QUERY_INSTRUCTION = "Represent this sentence for searching relevant passages: "
BGE_ZH_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"


def create_app():
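    """Build the FastAPI application: permissive CORS plus the embeddings router."""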
    app = FastAPI(
        title="Open Text Embeddings API",
        version="0.0.2",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    return app

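# A minimal sketch of serving this app, assuming uvicorn is installed (host and
# port values are illustrative):
#
#     import uvicorn
#     uvicorn.run(create_app(), host="0.0.0.0", port=8000)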

class CreateEmbeddingRequest(BaseModel):
    model: Optional[str] = Field(
        description="The model to use for generating embeddings.", default=None)
    input: Union[str, List[str]] = Field(description="The input to embed.")
    user: Optional[str] = Field(default=None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "input": "The food was delicious and the waiter...",
                }
            ]
        }
    }


class Embedding(BaseModel):
    embedding: List[float]


class CreateEmbeddingResponse(BaseModel):
    data: List[Embedding]


embeddings = None


def _create_embedding(
    model: Optional[str],
    input: Union[str, List[str]]
):
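    """Lazily load the embedding model on first use, then embed the input.

    A single string is embedded as a query; a list of strings is embedded as
    documents (passages).
    """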
    global embeddings

    if embeddings is None:
        if model and model != "text-embedding-ada-002":
            model_name = model
        else:
            # Fall back to the module default when the MODEL env var is unset.
            model_name = os.environ.get("MODEL", DEFAULT_MODEL_NAME)
        print("Loading model:", model_name)
        encode_kwargs = {
            # Only common truthy strings enable normalization; calling bool() on
            # the raw value would treat strings such as "false" or "0" as True.
            "normalize_embeddings":
                os.environ.get("NORMALIZE_EMBEDDINGS", "").strip().lower()
                in ("1", "true", "yes"),
        }
        print("encode_kwargs", encode_kwargs)
        if "e5" in model_name:
            embeddings = HuggingFaceInstructEmbeddings(model_name=model_name,
                                                       embed_instruction=E5_EMBED_INSTRUCTION,
                                                       query_instruction=E5_QUERY_INSTRUCTION,
                                                       encode_kwargs=encode_kwargs)
        elif model_name.startswith("BAAI/bge-") and model_name.endswith("-en"):
            embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
                                                  query_instruction=BGE_EN_QUERY_INSTRUCTION,
                                                  encode_kwargs=encode_kwargs)
        elif model_name.startswith("BAAI/bge-") and model_name.endswith("-zh"):
            embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
                                                  query_instruction=BGE_ZH_QUERY_INSTRUCTION,
                                                  encode_kwargs=encode_kwargs)
        else:
            embeddings = HuggingFaceEmbeddings(
                model_name=model_name, encode_kwargs=encode_kwargs)

    if isinstance(input, str):
        return CreateEmbeddingResponse(data=[Embedding(embedding=embeddings.embed_query(input))])
    else:
        data = [Embedding(embedding=embedding)
                for embedding in embeddings.embed_documents(input)]
        return CreateEmbeddingResponse(data=data)

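# A minimal sketch of configuring the model through environment variables before
# starting the server (the values shown are illustrative):
#
#     export MODEL=intfloat/e5-large-v2
#     export NORMALIZE_EMBEDDINGS=1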

@router.post(
    "/v1/embeddings",
    response_model=CreateEmbeddingResponse,
)
async def create_embedding(
    request: CreateEmbeddingRequest
):
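    """OpenAI-compatible embeddings endpoint.

    The blocking embedding call runs in a thread pool so it does not block the
    event loop; the "user" field is accepted for API compatibility but ignored.
    """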
    return await run_in_threadpool(
        _create_embedding, **request.dict(exclude={"user"})
    )
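

# A minimal sketch of calling the endpoint once the server is running (the URL,
# port, and input text are illustrative):
#
#     curl -X POST http://localhost:8000/v1/embeddings \
#          -H "Content-Type: application/json" \
#          -d '{"input": "The food was delicious and the waiter..."}'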