anttip's picture
Create handler.py
0012f15
raw
history blame
No virus
1.88 kB
from typing import Dict, List, Any
from transformers import AutoTokenizer
from translate import EncoderCT2fromHfHub
import os
class EndpointHandler():
    """Inference Endpoints handler that returns CTranslate2 encoder embeddings.

    The model files shipped in the repository are registered in a local
    Hugging Face hub cache entry (``refs/main`` plus a snapshot directory of
    symlinks) under ``MODEL_ID``. ``EncoderCT2fromHfHub`` is then loaded by
    that id — presumably so it resolves from the local cache instead of
    downloading; confirm against hf_hub_ctranslate2.
    """

    # Model id under which the repository files are registered in the hub cache.
    MODEL_ID = "tmp/ct2fast-e5-small-v2-hfie"
    # Name written to refs/main and used as the snapshot directory name.
    SNAPSHOT_ID = "1"
    # Files linked from the repository into the snapshot directory.
    MODEL_FILES = (
        "config.json",
        "model.bin",
        "tokenizer_config.json",
        "tokenizer.json",
        "vocabulary.txt",
    )

    def __init__(self, path="", local_test=False, *, device="cuda", compute_type="int8_float16"):
        """Register the repository files in the local hub cache and load the encoder.

        Args:
            path: Accepted for the Inference Endpoints handler signature but not
                used; the repository directory is chosen by ``local_test``.
            local_test: If true, link model files from the current working
                directory instead of ``/repository/``.
            device: Device passed to ``EncoderCT2fromHfHub`` (default ``"cuda"``).
            compute_type: Compute type passed to ``EncoderCT2fromHfHub``
                (default ``"int8_float16"``).
        """
        repo_dir = os.getcwd() if local_test else "/repository/"
        # NOTE(review): hard-codes the default hub cache location; HF_HOME /
        # HF_HUB_CACHE overrides are not honored — confirm the endpoint image
        # uses the default cache directory.
        cache_dir = os.path.expanduser(
            "~/.cache/huggingface/hub/models--" + self.MODEL_ID.replace("/", "--")
        )
        self._link_snapshot(cache_dir, self.SNAPSHOT_ID, repo_dir, self.MODEL_FILES)
        self.model = EncoderCT2fromHfHub(
            model_name_or_path=self.MODEL_ID,
            device=device,
            compute_type=compute_type,
        )

    @staticmethod
    def _link_snapshot(cache_dir, snapshot_id, repo_dir, filenames):
        """Create a hub-cache snapshot whose files are symlinks into ``repo_dir``.

        Writes ``<cache_dir>/refs/main`` containing ``snapshot_id``. Then, for
        each name in ``filenames``, links
        ``<cache_dir>/snapshots/<snapshot_id>/<name>`` to ``<repo_dir>/<name>``.
        Existing entries that resolve are kept as-is. Dangling symlinks, such as
        those left by an earlier run from another directory, are replaced
        instead of making ``os.symlink`` fail.

        Returns:
            The snapshot directory path.
        """
        refs_dir = os.path.join(cache_dir, "refs")
        snapshot_dir = os.path.join(cache_dir, "snapshots", snapshot_id)
        os.makedirs(refs_dir, exist_ok=True)
        os.makedirs(snapshot_dir, exist_ok=True)
        with open(os.path.join(refs_dir, "main"), "w", encoding="utf-8") as ref_file:
            ref_file.write(snapshot_id)
        for name in filenames:
            link = os.path.join(snapshot_dir, name)
            # os.path.exists() follows symlinks, so a dangling link reports False
            # and os.symlink() would then raise FileExistsError; drop it first.
            if os.path.islink(link) and not os.path.exists(link):
                os.remove(link)
            if not os.path.lexists(link):
                os.symlink(os.path.join(repo_dir, name), link)
        return snapshot_dir

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Embed the text(s) in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is a single string or a
                list/tuple of strings. If the key is missing, the whole payload
                is forwarded as a single input. The payload itself is not
                modified.

        Returns:
            ``outputs["pooler_output"]`` converted with ``.tolist()`` —
            presumably one embedding row per input text; confirm against
            ``EncoderCT2fromHfHub.generate``.
        """
        inputs = data.get("inputs", data)
        # A single value becomes a batch of one; a list/tuple already is a batch.
        texts = list(inputs) if isinstance(inputs, (list, tuple)) else [inputs]
        outputs = self.model.generate(text=texts)
        return outputs["pooler_output"].tolist()
# Local smoke test: run `python handler.py` from a directory containing the model files.
if __name__ == '__main__':
    # EndpointHandler is defined above in this module; importing it again via
    # `from handler import EndpointHandler` would execute the module a second time.
    my_handler = EndpointHandler(path=".", local_test=True)
    inputs = ['The quick brown fox jumps over the lazy dog']
    for text in inputs:
        response = my_handler({"inputs": text})
        print(response)