from typing import Dict, List, Any
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch
from loguru import logger

# Hard limits on prompt length and generation length
MAX_INPUT_TOKEN_LENGTH  = 4000
MAX_MAX_NEW_TOKENS      = 2048
DEFAULT_MAX_NEW_TOKENS  = 1024

class EndpointHandler:
    def __init__(self, path=""):
        # Load the GPTQ-quantized model and its tokenizer from the repository path
        self.model = AutoGPTQForCausalLM.from_quantized(path, device_map="auto", use_safetensors=True)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def get_input_token_length(self, message: str) -> int:
        # Count prompt tokens without special tokens, so the check reflects the raw input
        input_ids = self.tokenizer([message], return_tensors='np', add_special_tokens=False)['input_ids']
        return input_ids.shape[-1]

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Inference Endpoints send a JSON body with "inputs" and optional "parameters"
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        parameters.setdefault("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)

        # Reject requests that exceed the generation limit
        if parameters["max_new_tokens"] > MAX_MAX_NEW_TOKENS:
            logger.error(f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})")
            return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]

        # Reject prompts that exceed the input-length limit
        input_token_length = self.get_input_token_length(inputs)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
            return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]

        logger.info(f"inputs: {inputs}")
        # Tokenize the prompt and move the tensors to the model's device
        model_inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)

        outputs = self.model.generate(**model_inputs, **parameters)

        # Decode the full sequence (prompt plus completion) back into text
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        return [{"generated_text": prediction}]
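

if __name__ == "__main__":
    # A minimal local smoke test sketching how the handler is invoked.
    # The model path and prompt below are hypothetical placeholders, not
    # values prescribed by this handler; any GPTQ repository with
    # safetensors weights should work in their place.
    handler = EndpointHandler(path="TheBloke/Llama-2-7B-GPTQ")
    payload = {
        "inputs": "Explain GPTQ quantization in one sentence.",
        "parameters": {"max_new_tokens": 64},
    }
    print(handler(payload))  # -> [{"generated_text": "..."}]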