from typing import Any, Dict, List

import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

# Upper bound on prompt length, in tokens counted WITHOUT special tokens
# (see EndpointHandler.get_input_token_length).
MAX_INPUT_TOKEN_LENGTH = 4000
# Largest ``max_new_tokens`` a request may ask for.
MAX_MAX_NEW_TOKENS = 2048
# ``max_new_tokens`` used when the request does not specify one.
DEFAULT_MAX_NEW_TOKENS = 1024


class EndpointHandler:
    """Serve text generation from a GPTQ-quantized causal language model.

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape and the
    ``"inputs"`` / ``"parameters"`` payload keys look like the Hugging Face
    Inference Endpoints custom-handler interface — confirm against the
    deployment configuration.
    """

    def __init__(self, path: str = "") -> None:
        """Load the quantized model and its tokenizer from *path*.

        Args:
            path: Location passed to both ``from_quantized`` and
                ``AutoTokenizer.from_pretrained``; the model weights are
                loaded as safetensors.
        """
        # device_map="auto" delegates layer placement to the loader, so the
        # model may live on one or more devices (see __call__ for how inputs
        # are placed).
        self.model = AutoGPTQForCausalLM.from_quantized(
            path, device_map="auto", use_safetensors=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def get_input_token_length(self, message: str) -> int:
        """Return the number of tokens in *message*, excluding special tokens."""
        input_ids = self.tokenizer(
            [message], return_tensors="np", add_special_tokens=False
        )["input_ids"]
        # Shape is (1, seq_len) for the single-element batch above.
        return input_ids.shape[-1]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for the prompt contained in *data*.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt (if the
                key is absent, *data* itself is used as the prompt, as in the
                original handler). ``data["parameters"]`` is an optional dict
                of keyword arguments forwarded to ``model.generate``. Both
                keys are popped from *data*.

        Returns:
            A one-element list: ``[{"generated_text": text}]`` on success, or
            ``[{"generated_text": None, "error": message}]`` when the request
            is rejected by validation.
        """
        inputs = data.pop("inputs", data)
        # Copy so the caller's dict is not mutated, and tolerate an explicit
        # ``"parameters": null`` in the payload.
        parameters = dict(data.pop("parameters", None) or {})
        max_new_tokens = parameters.setdefault(
            "max_new_tokens", DEFAULT_MAX_NEW_TOKENS
        )

        # bool is a subclass of int; reject it explicitly.
        if (
            not isinstance(max_new_tokens, int)
            or isinstance(max_new_tokens, bool)
            or max_new_tokens < 1
        ):
            return [
                {
                    "generated_text": None,
                    "error": "max_new_tokens must be a positive integer",
                }
            ]
        if max_new_tokens > MAX_MAX_NEW_TOKENS:
            return [
                {
                    "generated_text": None,
                    "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})",
                }
            ]

        input_token_length = self.get_input_token_length(inputs)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            return [
                {
                    "generated_text": None,
                    "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})",
                }
            ]

        # Keep the whole encoding (input_ids + attention_mask) so it can be
        # unpacked as keyword arguments; unpacking the bare input_ids tensor
        # with ** raises TypeError. Move it to the model's device so the
        # inputs match where the weights were placed.
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**encoded, **parameters)
        # For a causal LM, outputs[0] holds the prompt ids followed by the
        # generated ids, so the decoded text includes the prompt.
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]