import os
from typing import Any, Dict, List

# EncoderCT2fromHfHub ships with the hf-hub-ctranslate2 package and wraps a
# CTranslate2-converted encoder behind a Hugging Face Hub-style interface.
from hf_hub_ctranslate2 import EncoderCT2fromHfHub


class EndpointHandler:
    def __init__(self, path="", local_test=False):
        # Pseudo repo id used to build the local huggingface_hub cache entry.
        path = "tmp/ct2fast-e5-small-v2-hfie"
        snapshot_id = "1"

        # On a Hugging Face Inference Endpoint the repository is mounted at
        # /repository/; for local testing, use the current working directory.
        if local_test:
            repo_dir = os.getcwd()
        else:
            repo_dir = "/repository/"

        # Fake a huggingface_hub cache layout that points at the mounted
        # repository, so EncoderCT2fromHfHub resolves the files offline
        # instead of downloading them.
        cache_dir = os.path.expanduser(
            "~/.cache/huggingface/hub/models--" + path.replace("/", "--")
        )
        snapshot_dir = os.path.join(cache_dir, "snapshots", snapshot_id)
        os.makedirs(os.path.join(cache_dir, "refs"), exist_ok=True)
        os.makedirs(snapshot_dir, exist_ok=True)

        # The "main" ref must resolve to the snapshot created above.
        with open(os.path.join(cache_dir, "refs", "main"), "w") as ref_file:
            ref_file.write(snapshot_id)

        # Make symbolic links from the repository files into the snapshot.
        for filename in (
            "config.json",
            "model.bin",
            "tokenizer_config.json",
            "tokenizer.json",
            "vocabulary.txt",
        ):
            link = os.path.join(snapshot_dir, filename)
            if not os.path.exists(link):
                os.symlink(os.path.join(repo_dir, filename), link)

        self.model = EncoderCT2fromHfHub(
            model_name_or_path=path,
            device="cuda",
            compute_type="int8_float16",
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[float]]:
        """
        data args:
            inputs (:obj:`str`)
        Return:
            A :obj:`list` of pooled embedding vectors (one per input);
            it will be serialized and returned by the endpoint.
        """
        inputs = data.pop("inputs", data)
        outputs = self.model.generate(text=[inputs])
        return outputs["pooler_output"].tolist()


# Test code
if __name__ == "__main__":
    my_handler = EndpointHandler(path=".", local_test=True)
    inputs = ["The quick brown fox jumps over the lazy dog"]
    for text in inputs:
        response = my_handler({"inputs": text})
        print(response)
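
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the handler): once this handler is deployed
# as a Hugging Face Inference Endpoint, it can be queried over HTTP with the
# same {"inputs": ...} payload that __call__ consumes. ENDPOINT_URL and
# HF_TOKEN below are placeholders, not values from this repository, and
# `requests` is an extra dependency used only for this sketch.
#
#   import requests
#
#   ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
#   HF_TOKEN = "hf_..."  # access token with permission to call the endpoint
#
#   r = requests.post(
#       ENDPOINT_URL,
#       headers={"Authorization": f"Bearer {HF_TOKEN}"},
#       json={"inputs": "The quick brown fox jumps over the lazy dog"},
#   )
#   embedding = r.json()  # pooled embedding vector(s), as returned by __call__
# ---------------------------------------------------------------------------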