"""Custom endpoint handler that serves a GPTQ-quantized causal language model."""

from typing import Any, Dict, List

import torch
from auto_gptq import AutoGPTQForCausalLM
from loguru import logger
from transformers import AutoTokenizer

# Guardrails for the endpoint: cap prompt length and generation length.
MAX_INPUT_TOKEN_LENGTH = 4000
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

class EndpointHandler:
    def __init__(self, path=""):
        # Load the GPTQ-quantized checkpoint and its tokenizer from the repository path.
        self.model = AutoGPTQForCausalLM.from_quantized(path, device_map="auto", use_safetensors=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def get_input_token_length(self, message: str) -> int:
        """Return the number of tokens in `message`, excluding special tokens."""
        input_ids = self.tokenizer([message], return_tensors='np', add_special_tokens=False)['input_ids']
        return input_ids.shape[-1]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Fall back to the default generation length when the caller does not set one.
        parameters.setdefault("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)

        if parameters["max_new_tokens"] > MAX_MAX_NEW_TOKENS:
            logger.error(f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})")
            return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]

        input_token_length = self.get_input_token_length(inputs)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
            return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]

        logger.info(f"inputs: {inputs}")
        input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)

        # Generation only needs the forward pass; disable autograd bookkeeping.
        with torch.inference_mode():
            outputs = self.model.generate(**input_ids, **parameters)

        # Note: this decodes the full sequence, so the prompt is echoed back along with the completion.
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return [{"generated_text": prediction}]
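

# A minimal local smoke-test sketch, not part of the endpoint contract: it assumes a
# GPTQ-quantized checkpoint is available on disk, and the path below is a placeholder
# rather than the repository this handler was written for.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical path to a quantized checkpoint
    payload = {"inputs": "Hello, world!", "parameters": {"max_new_tokens": 32}}
    # The handler returns a list with a single {"generated_text": ...} dict on success.
    print(handler(payload))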