from datetime import datetime
from typing import Callable, Optional

import torch
from fastapi import APIRouter
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

from logger import log
from config import TEST_MODE

# Run on the first CUDA device when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

router = APIRouter()


class SentenceEmbeddingsInput(BaseModel):
    """Request body for ``POST /sentence-embeddings``."""

    inputs: list[str]  # texts to embed; one embedding is returned per text
    model: str  # key into ``sentence_embeddings_mapping``
    parameters: dict  # passed to the embedding function and logged


class SentenceEmbeddingsOutput(BaseModel):
    """Response body: the endpoint sets either ``embeddings`` or ``error``."""

    embeddings: Optional[list[list[float]]] = None
    error: Optional[str] = None


@router.post('/sentence-embeddings')
def sentence_embeddings(inputs: SentenceEmbeddingsInput):
    """Embed ``inputs.inputs`` with the model named by ``inputs.model``.

    Failures are reported in the response body (the ``error`` field) rather
    than raised. Two cases produce an error response: an unknown model name,
    and any exception raised while embedding or logging.

    On success the request is logged via ``log``. The model's entry in
    ``loaded_models_last_updated`` is also set to the current time.
    """
    start_time = datetime.now()
    fn = sentence_embeddings_mapping.get(inputs.model)
    if not fn:
        return SentenceEmbeddingsOutput(
            error=f'No sentence embeddings model found for {inputs.model}'
        )
    try:
        embeddings = fn(inputs.inputs, inputs.parameters)
        log({
            "task": "sentence_embeddings",
            "model": inputs.model,
            "start_time": start_time.isoformat(),
            "time_taken": (datetime.now() - start_time).total_seconds(),
            "inputs": inputs.inputs,
            "outputs": embeddings,
            "parameters": inputs.parameters,
        })
        loaded_models_last_updated[inputs.model] = datetime.now()
        return SentenceEmbeddingsOutput(
            embeddings=embeddings
        )
    except Exception as e:
        # API boundary: report the failure to the client in the response body
        # instead of letting the exception propagate out of the endpoint.
        return SentenceEmbeddingsOutput(
            error=str(e)
        )


def generic_sentence_embeddings(
    model_name: str,
) -> Callable[[list[str], dict], list[list[float]]]:
    """Build an embedding function for the Hugging Face model ``model_name``.

    The returned function loads the tokenizer and model on first use. It
    caches them in ``loaded_models`` under ``model_name`` so that later calls
    reuse the already-loaded weights.
    """

    def process_texts(texts: list[str], parameters: dict) -> list[list[float]]:
        """Return one L2-normalised embedding (as a list of floats) per text.

        ``parameters`` is part of the call interface but is not used here.
        """
        if TEST_MODE:
            # Fixed stub vectors so tests run without downloading weights.
            # Build independent row lists rather than N references to one list.
            return [[0.1, 0.2] for _ in texts]
        if not texts:
            # Nothing to embed; skip the tokenizer and model forward pass.
            return []

        cached = loaded_models.get(model_name)
        if cached is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModel.from_pretrained(model_name).to(device)
            # Cache under model_name so the lookup above hits on later calls.
            loaded_models[model_name] = (tokenizer, model)
        else:
            tokenizer, model = cached

        # Tokenize sentences
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        # Assumption to confirm: model_output[0] is the last hidden state,
        # shaped (batch, seq, hidden). [:, 0] then takes the first-token
        # ([CLS]) vector of each text.
        embeddings = model_output[0][:, 0]
        # normalize embeddings to unit L2 norm along the feature dimension
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.tolist()

    return process_texts


# Lazily populated cache: model name -> (tokenizer, model).
loaded_models: dict = {}

# Model name -> time of the most recent successful request for that model.
# NOTE(review): the original comment here was truncated ("Polling every X
# minutes to ..."). Presumably some poller uses this to unload idle entries
# from loaded_models, but no such poller is visible in this file; confirm.
loaded_models_last_updated: dict = {}

# Accepted values of SentenceEmbeddingsInput.model -> embedding function.
sentence_embeddings_mapping = {
    'BAAI/bge-base-en-v1.5': generic_sentence_embeddings('BAAI/bge-base-en-v1.5'),
    'BAAI/bge-large-en-v1.5': generic_sentence_embeddings('BAAI/bge-large-en-v1.5'),
}