from typing import Any, Dict, List

import torch
import transformers

class EndpointHandler:
    def __init__(self, path=""):
        # `path` is accepted for compatibility with the Inference Endpoints
        # handler interface but unused here; weights load from the Hub by id.
        model_id = 'meta-llama/Llama-2-13b-chat-hf'
        model_config = transformers.AutoConfig.from_pretrained(
            model_id
        )
        # device_map='auto' lets accelerate place the 13B weights across the
        # available GPUs (spilling to CPU if necessary).
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            config=model_config,
            device_map='auto'
        )
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_id,
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        # The request payload arrives as a dict; fall back to the raw
        # payload if there is no "input" key.
        inputs = data.pop("input", data)
        # Return plain nested lists so the response is JSON-serializable.
        return self.embed(inputs).tolist()

    def embed(self, text):
        with torch.no_grad():
            # Move the tokenized input onto the same device as the model;
            # with device_map='auto' the weights may live on a GPU.
            encoded_input = self.tokenizer(text, return_tensors="pt").to(self.model.device)
            model_output = self.model(**encoded_input, output_hidden_states=True)
            # hidden_states is a tuple of (num_layers + 1) tensors, each of
            # shape (batch, seq_len, hidden_size); keep the last four layers.
            last_four_layers = model_output.hidden_states[-4:]
            # Average over the four layers, then mean-pool over the tokens,
            # yielding one (batch, hidden_size) embedding per input text.
            return torch.stack(last_four_layers).mean(dim=0).mean(dim=1)
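
# A minimal local smoke test, included as an illustrative assumption: the
# Inference Endpoints runtime normally instantiates EndpointHandler itself
# and posts JSON payloads to it, so this block is never run in production.
if __name__ == "__main__":
    handler = EndpointHandler()
    embedding = handler({"input": "Hello, world!"})
    # One embedding per input text; Llama-2-13b has hidden_size 5120.
    print(len(embedding), len(embedding[0]))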