from typing import Dict, List, Any

import torch
from auto_gptq import AutoGPTQForCausalLM
from loguru import logger
from transformers import AutoTokenizer

# Guardrails for prompt length and requested generation length.
MAX_INPUT_TOKEN_LENGTH = 4000
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
class EndpointHandler:
    def __init__(self, path=""):
        # Load the GPTQ-quantized model and its tokenizer from the repository path.
        self.model = AutoGPTQForCausalLM.from_quantized(path, device_map="auto", use_safetensors=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def get_input_token_length(self, message: str) -> int:
        # Count prompt tokens without adding special tokens.
        input_ids = self.tokenizer([message], return_tensors="np", add_special_tokens=False)["input_ids"]
        return input_ids.shape[-1]
    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Cap the requested generation length.
        parameters["max_new_tokens"] = parameters.pop("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        if parameters["max_new_tokens"] > MAX_MAX_NEW_TOKENS:
            logger.error(f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})")
            return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]

        # Reject prompts that exceed the input token budget.
        input_token_length = self.get_input_token_length(inputs)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
            return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]

        logger.info(f"inputs: {inputs}")

        # Tokenize the prompt, generate, and decode the full output sequence.
        model_inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**model_inputs, **parameters)
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]