import torch
import transformers
from typing import Any, Dict


class EndpointHandler:
    """Hugging Face Inference Endpoint handler that serves Llama-2-13b-chat
    as an *embedding* model: the request text is tokenized, run through the
    causal LM with hidden states enabled, and the mean of the last four
    hidden layers (then the mean over the sequence dimension) is returned.
    """

    def __init__(self, path: str = ""):
        # NOTE(review): `path` is the local model dir the Endpoint runtime
        # passes in, but it is ignored here — the hub model id is always
        # used. Confirm this is intentional (it forces a hub download).
        model_id = "meta-llama/Llama-2-13b-chat-hf"
        model_config = transformers.AutoConfig.from_pretrained(model_id)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            device_map="auto",  # shard/offload across available devices
        )
        self.model.eval()  # inference mode: disable dropout etc.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> torch.Tensor:
        """Entry point invoked by the Endpoint runtime.

        Expects a payload of the form ``{"input": <str or list[str]>}``;
        if the ``"input"`` key is absent the whole payload is treated as
        the text (original fallback behavior, preserved).
        """
        inputs = data.pop("input", data)
        return self.embed(inputs)

    def embed(self, text) -> torch.Tensor:
        """Return a ``(batch, hidden_size)`` embedding tensor for *text*.

        The embedding is the mean over the last four hidden layers,
        then the mean over the sequence (token) dimension.
        """
        with torch.no_grad():
            encoded_input = self.tokenizer(text, return_tensors="pt")
            # BUG FIX: with device_map="auto" the model weights may be on
            # GPU while tokenizer output is on CPU; move inputs to the
            # model's device to avoid a device-mismatch RuntimeError.
            encoded_input = {
                k: v.to(self.model.device) for k, v in encoded_input.items()
            }
            model_output = self.model(**encoded_input, output_hidden_states=True)
            # Use the named field rather than the positional index [2],
            # which is fragile across transformers versions and
            # return_dict settings.
            last_four_layers = model_output.hidden_states[-4:]
            return torch.stack(last_four_layers).mean(dim=0).mean(dim=1)