# Page metadata captured with this file (Hugging Face file view):
# author Den-d3j2d · commit a356454 "Update app.py" (verified) · 585 bytes
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from src.model.encoder import ProdFeatureEncoder
from src.config.config import ModelConfig
# ASGI application object; the route decorator on encode_text registers on it.
app: FastAPI = FastAPI()
# Response schema for the /encode_text route: a single field holding the
# embedding as a flat list of floats.
class EmbeddingOutput(BaseModel):
    embedding: List[float]
# Build the encoder once at import time so every request reuses the same
# instance rather than constructing a new one per call.
config = ModelConfig()
model = ProdFeatureEncoder(config=config)

# This service only runs inference. If the encoder is a torch Module, put it
# in eval mode so layers such as dropout / batch-norm behave deterministically.
# NOTE(review): ProdFeatureEncoder may already do this internally — calling
# eval() again is idempotent for nn.Module, so this is safe either way.
if isinstance(model, torch.nn.Module):
    model.eval()
# Single-worker pool for model inference. The forward pass is synchronous, so
# it must not run on the event loop. One worker keeps encoder calls strictly
# one at a time, matching the original behaviour where each call ran inline.
# This avoids introducing concurrent use of the encoder.
_inference_executor = ThreadPoolExecutor(max_workers=1)


def _encode_blocking(text: str):
    """Run the encoder on ``text`` with autograd disabled; return ``.tolist()``.

    Executed on ``_inference_executor``'s worker thread. ``torch.no_grad()`` is
    entered here rather than in the coroutine because PyTorch grad mode is
    thread-local. A context opened on the event-loop thread would not apply
    to the worker.
    """
    with torch.no_grad():
        embedding = model(text)
    # NOTE(review): assumes the encoder returns a 1-D tensor for one string.
    # A (1, dim) result would produce a nested list and fail validation
    # against EmbeddingOutput.embedding (List[float]). Confirm against
    # ProdFeatureEncoder.
    return embedding.tolist()


@app.get("/encode_text/{text}", response_model=EmbeddingOutput)
async def encode_text(text: str):
    """Encode ``text`` and return ``{"embedding": [...]}``.

    The model call is offloaded to ``_inference_executor`` and awaited. The
    event loop stays free to serve other requests while inference runs.
    Exceptions raised by the encoder propagate out of the ``await`` unchanged.

    NOTE: ``text`` is a single URL path segment. It cannot contain ``/`` and
    its length is bounded by URL limits. Long inputs would need a query or
    body parameter instead.
    """
    loop = asyncio.get_running_loop()
    embedding = await loop.run_in_executor(
        _inference_executor, _encode_blocking, text
    )
    return {"embedding": embedding}