|
from typing import Dict, List, Any |
|
|
|
|
|
|
|
|
|
from exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig |
|
from exllama.tokenizer import ExLlamaTokenizer |
|
from exllama.generator import ExLlamaGenerator |
|
|
|
|
|
|
|
# Source model repository on the Hugging Face Hub.
# NOTE(review): `repo` is not referenced anywhere in this file — presumably
# kept for provenance/metadata; confirm before removing.
repo = "kajdun/iubaris-13b-v3_GPTQ"

# Inference Endpoints mounts the repository snapshot at this path.
model_directory = "/repository/"

# SentencePiece tokenizer model shipped with the repository.
tokenizer_path = f"{model_directory}tokenizer.model"

# Llama-architecture config consumed by ExLlamaConfig.
model_config_path = f"{model_directory}config.json"

# 4-bit, 128-group-size GPTQ weights loaded by ExLlama.
model_path = f"{model_directory}gptq_model-4bit-128g.safetensors"
|
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler serving a GPTQ Llama model via exllama.

    Loads the model/tokenizer/cache once at startup and answers
    ``{"inputs": ..., "parameters": {...}}`` payloads with generated text.
    """

    # Default sampling settings, used when a request supplies no override.
    DEFAULT_PARAMETERS = {
        "token_repetition_penalty_max": 1.0,
        "temperature": 0.9,
        "top_p": 0.6,
        "top_k": 100,
        "typical": 0.5,
    }
    DEFAULT_MAX_NEW_TOKENS = 50

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract; the
        # model location is fixed by the module-level constants instead.
        config = ExLlamaConfig(model_config_path)
        config.model_path = model_path

        model = ExLlama(config)
        tokenizer = ExLlamaTokenizer(tokenizer_path)
        cache = ExLlamaCache(model)
        self.generator = ExLlamaGenerator(model, tokenizer, cache)
        # Forbid EOS so generation always runs to max_new_tokens; the model
        # is never allowed to stop early.
        self.generator.disallow_tokens([tokenizer.eos_token_id])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`): prompt text to continue
            parameters (:obj: `dict`, optional): sampling overrides —
                any of ``token_repetition_penalty_max``, ``temperature``,
                ``top_p``, ``top_k``, ``typical``, ``max_new_tokens``.
        Return:
            A :obj:`list` of one dict: ``[{"generated_text": <completion>}]``
            (the prompt is stripped from the front of the model output).
        """
        # Fall back to the whole payload when no "inputs" key is present
        # (matches the original handler's behavior).
        inputs = data.pop("inputs", data)
        # Previously `parameters` was popped and silently ignored; now it
        # overrides the defaults per request. A missing/None value keeps
        # the historical hard-coded settings.
        parameters = data.pop("parameters", None) or {}

        # Settings are (re)applied on every call so one request's overrides
        # never leak into the next.
        settings = self.generator.settings
        for name, default in self.DEFAULT_PARAMETERS.items():
            setattr(settings, name, parameters.get(name, default))

        max_new_tokens = parameters.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
        output = self.generator.generate_simple(inputs, max_new_tokens=max_new_tokens)
        # Return only the newly generated continuation, not the echoed prompt.
        return [{"generated_text": output[len(inputs):]}]