|
from typing import Dict, List, Any |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a chat-style model.

    Loads a tokenizer/model pair that ships custom code (hence
    ``trust_remote_code=True``) and serves conversational requests through
    the model's ``chat`` API (ChatGLM-style — TODO confirm against the
    target checkpoint).
    """

    def __init__(self, path: str = ""):
        # trust_remote_code is required: the checkpoint provides its own
        # modeling/tokenization classes, including the `.chat()` method
        # used in __call__.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModel.from_pretrained(path, trust_remote_code=True)
        if torch.cuda.is_available():
            # fp16 on GPU, as in the original; guarded so the handler can
            # still start (slowly, in fp32) on CUDA-less hosts instead of
            # crashing on a hard-coded .cuda() call.
            model = model.half().cuda()
        self.model = model

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one chat turn.

        Args:
            data: Request payload. Expected keys:
                ``inputs`` — the user prompt (falls back to the whole
                payload if absent, matching the standard HF handler idiom);
                ``history`` — optional prior conversation state, passed
                through to ``model.chat`` unchanged.

        Returns:
            A single-element list containing a dict with the generated
            ``generated_text`` and the updated ``history``.
        """
        inputs = data.pop("inputs", data)
        history = data.pop("history", None)

        # inference_mode: no autograd bookkeeping during generation.
        with torch.inference_mode():
            response, new_history = self.model.chat(self.tokenizer, inputs, history)

        return [{"generated_text": response, "history": new_history}]