File size: 3,227 Bytes
74b889b
 
 
 
 
 
 
 
 
 
 
 
 
20ba605
74b889b
 
5562491
74b889b
 
 
 
 
 
 
 
 
 
 
20ba605
74b889b
 
 
 
 
 
 
 
 
 
 
 
 
 
48cf048
74b889b
 
 
 
 
 
 
 
 
48cf048
74b889b
20ba605
 
74b889b
 
 
 
 
 
 
 
 
 
 
 
 
 
48cf048
 
 
74b889b
 
 
48cf048
74b889b
 
 
 
 
 
 
 
 
 
 
 
 
 
48cf048
74b889b
 
 
 
 
 
48cf048
74b889b
 
 
 
 
 
 
 
 
 
 
5562491
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import markdown2
import sys
import uvicorn

from pathlib import Path
from typing import Union, Optional

from fastapi import FastAPI
from pydantic import BaseModel, Field
from fastapi.responses import HTMLResponse
from tclogger import logger, OSEnver

from transforms.embed import JinaAIEmbedder, JinaAIOnnxEmbedder
from configs.constants import AVAILABLE_MODELS

# App configuration file: configs/info.json under the parent of this file's
# directory. OSEnver supplies the values read below (app_name, version,
# server, port).
info_path = Path(__file__).parents[1].joinpath("configs", "info.json")
ENVER = OSEnver(info_path)


class EmbeddingApp:
    """FastAPI application that serves text embeddings.

    Routes registered by ``setup_routes``:

    - ``GET /models``: the names in ``AVAILABLE_MODELS``.
    - ``POST /embedding``: embedding(s) for the posted text(s).
    - ``GET /readme``: README.md rendered to HTML (excluded from the schema).

    The Swagger UI is served at ``/``.
    """

    def __init__(self) -> None:
        self.app = FastAPI(
            docs_url="/",
            title=ENVER["app_name"],
            # -1 collapses/hides the models (schemas) section in Swagger UI.
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=ENVER["version"],
        )
        # A single embedder instance is shared by every request.
        self.embedder = JinaAIOnnxEmbedder()
        self.setup_routes()

    def get_available_models(self):
        """Return the model names advertised by this service."""
        return AVAILABLE_MODELS

    def get_readme(self):
        """Read README.md (parent of this file's directory) and return it as HTML."""
        readme_path = Path(__file__).parents[1] / "README.md"
        with open(readme_path, "r", encoding="utf-8") as rf:
            readme_str = rf.read()
        readme_html = markdown2.markdown(
            readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
        )
        return readme_html

    class CalcEmbeddingPostItem(BaseModel):
        """Request body for ``POST /embedding``."""

        # Required. It used to default to None, which let a request without
        # text reach the embedder; now such requests fail validation (422).
        text: Union[str, list[str]] = Field(
            ...,
            description="Input text(s) to embed",
        )
        # NOTE: accepted but currently not used by calc_embedding.
        model: Optional[str] = Field(
            default=AVAILABLE_MODELS[0],
            description="Embedding model name",
        )

    def calc_embedding(self, item: CalcEmbeddingPostItem):
        """Embed ``item.text`` with the shared embedder.

        Args:
            item: request body; ``text`` is a string or a list of strings.

        Returns:
            A single vector (list of floats) when the embedder yields exactly
            one row -- this includes a one-element list input -- otherwise a
            list of vectors. An empty result is returned as ``[]``.
        """
        logger.note(f"> Encoding text: [{item.text}]", end=" ")
        # TODO: honor `item.model`. Model switching (compare against
        # self.embedder.model, then call self.embedder.switch_model) is
        # disabled, so every request uses the embedder built in __init__.
        embeddings = self.embedder.encode(item.text).tolist()
        if not embeddings:
            # Guard: indexing embeddings[0] below would raise IndexError.
            logger.success("[0]")
            return embeddings
        logger.success(f"[{len(embeddings[0])}]")
        # Kept for backward compatibility: a lone result is unwrapped.
        if len(embeddings) == 1:
            return embeddings[0]
        return embeddings

    def setup_routes(self) -> None:
        """Register the bound handler methods as routes on ``self.app``."""
        self.app.get(
            "/models",
            summary="Get available models",
        )(self.get_available_models)

        self.app.post(
            "/embedding",
            summary="Calculate embedding for input text(s)",
        )(self.calc_embedding)

        self.app.get(
            "/readme",
            summary="README of Embed API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)


class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the Embed API server.

    Defaults for ``--server`` and ``--port`` are read from ``ENVER``. The
    command line (``sys.argv[1:]``) is parsed during construction, and the
    resulting namespace is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        server_default = ENVER["server"]
        self.add_argument(
            "-s", "--server",
            type=str,
            default=server_default,
            help=f"Server IP ({server_default}) for Embed API",
        )

        port_default = ENVER["port"]
        self.add_argument(
            "-p", "--port",
            type=int,
            default=port_default,
            help=f"Server Port ({port_default}) for Embed API",
        )

        cli_tokens = sys.argv[1:]
        self.args = self.parse_args(cli_tokens)


# Module-level ASGI app; the entry point below hands uvicorn the import
# string "__main__:app" so it can locate this object by name.
app = EmbeddingApp().app

if __name__ == "__main__":
    # Usage: python -m apps.app [-s SERVER] [-p PORT]
    args = ArgParser().args
    uvicorn.run(
        "__main__:app",
        host=args.server,
        port=args.port,
    )