import os
import json
import logging

import torch
from ts.torch_handler.base_handler import BaseHandler
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class SentenceTransformerHandler(BaseHandler):
    """TorchServe handler that serves sentence embeddings.

    Loads a SentenceTransformer model from the model directory, extracts
    ``texts`` from incoming JSON requests, and returns one normalized
    embedding (plus its dimension) per input text.
    """

    def __init__(self):
        super().__init__()
        # Set to True by initialize() once the model is loaded.
        self.initialized = False

    def initialize(self, context):
        """Load the SentenceTransformer model onto the worker's device.

        Args:
            context: TorchServe context; ``system_properties`` supplies
                ``model_dir`` and (on GPU hosts) this worker's ``gpu_id``.
        """
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")

        # Load the model
        logger.info("Loading model from %s", model_dir)
        # Pin this worker to its assigned GPU rather than always cuda:0,
        # so multiple workers spread across devices on multi-GPU hosts.
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available():
            if gpu_id is not None:
                self.device = torch.device(f"cuda:{gpu_id}")
            else:
                self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.model = SentenceTransformer(model_dir, device=self.device)

        # Inference only: eval mode and no autograd bookkeeping.
        self.model.eval()
        torch.set_grad_enabled(False)

        logger.info(
            "Model loaded successfully. Embedding dimension: %s",
            self.model.get_sentence_embedding_dimension(),
        )
        self.initialized = True

    def preprocess(self, requests):
        """Extract texts from the requests.

        Accepts payloads under either ``"data"`` or ``"body"``, as str,
        bytes, or an already-parsed dict. Each payload is expected to be
        JSON of the form ``{"texts": "..."} `` or ``{"texts": [...]}``.

        Returns:
            A flat list of all texts across all requests.
        """
        input_batch = []
        for request in requests:
            data = request.get("data")
            if data is None:
                data = request.get("body")
            if data is None:
                # Malformed request with neither "data" nor "body": skip
                # it instead of crashing the whole batch.
                continue
            # TorchServe delivers request bodies as bytes/bytearray;
            # decode before JSON parsing.
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            input_data = json.loads(data) if isinstance(data, str) else data
            texts = input_data.get("texts", [])
            # Allow a single string as shorthand for a one-element list.
            if isinstance(texts, str):
                texts = [texts]
            input_batch.extend(texts)
        return input_batch

    def inference(self, input_batch):
        """Generate embeddings for the input texts.

        Args:
            input_batch: Flat list of strings from preprocess().

        Returns:
            A 2-D numpy array of L2-normalized embeddings, one row per text.
        """
        embeddings = self.model.encode(
            input_batch,
            batch_size=32,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        return embeddings

    def postprocess(self, inference_output):
        """Format the output as expected by Vertex AI.

        Returns one dict per input *text* with the embedding vector and
        its dimensionality.

        NOTE(review): this emits one entry per text, not per request, so
        a request carrying multiple texts yields more responses than
        requests — verify this matches the serving frontend's contract.
        """
        return [{
            "embeddings": embeddings.tolist(),
            "metadata": {
                "embedding_dim": len(embeddings)
            }
        } for embeddings in inference_output]