File size: 1,884 Bytes
0012f15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Dict, List, Any
from transformers import AutoTokenizer
from translate import EncoderCT2fromHfHub
import os

class EndpointHandler:
    """Inference Endpoint handler that embeds text with a CTranslate2 encoder.

    The model files shipped in the endpoint repository are registered under a
    hand-built Hugging Face hub cache entry: a ``refs/main`` file plus a
    snapshot directory of symlinks. ``EncoderCT2fromHfHub`` is then loaded by
    repo id. Presumably this lets it resolve the files locally rather than
    downloading them — confirm against the ``EncoderCT2fromHfHub``
    implementation.
    """

    # Repo id under which the repository files are registered in the hub cache.
    MODEL_ID = "tmp/ct2fast-e5-small-v2-hfie"
    # Snapshot name written to refs/main and used as the snapshot directory name.
    SNAPSHOT_ID = "1"
    # Files expected in the repository directory and linked into the snapshot.
    MODEL_FILES = (
        "config.json",
        "model.bin",
        "tokenizer_config.json",
        "tokenizer.json",
        "vocabulary.txt",
    )

    def __init__(self, path="", local_test=False, *, device="cuda",
                 compute_type="int8_float16"):
        """Register the repository files in the hub cache and load the model.

        Args:
            path: Repository path supplied by the caller. Currently unused:
                files are linked from ``/repository/``, or from the current
                working directory when ``local_test`` is true. Kept so the
                handler signature stays unchanged.
            local_test: Link files from ``os.getcwd()`` instead of
                ``/repository/``.
            device: Device passed through to ``EncoderCT2fromHfHub``.
            compute_type: Compute type passed through to
                ``EncoderCT2fromHfHub``.
        """
        repo_dir = os.getcwd() if local_test else "/repository/"
        model_id = self.MODEL_ID
        # Hub cache layout: models--<org>--<name>/{refs/main, snapshots/<id>/}
        cache_dir = os.path.expanduser(
            "~/.cache/huggingface/hub/models--" + model_id.replace("/", "--")
        )
        refs_dir = os.path.join(cache_dir, "refs")
        snapshot_dir = os.path.join(cache_dir, "snapshots", self.SNAPSHOT_ID)
        os.makedirs(refs_dir, exist_ok=True)
        os.makedirs(snapshot_dir, exist_ok=True)
        # refs/main names the snapshot directory to use for this repo id.
        with open(os.path.join(refs_dir, "main"), "w", encoding="utf-8") as ref_file:
            ref_file.write(self.SNAPSHOT_ID)
        self._link_files(repo_dir, snapshot_dir, self.MODEL_FILES)
        self.model = EncoderCT2fromHfHub(
            model_name_or_path=model_id,
            device=device,
            compute_type=compute_type,
        )

    @staticmethod
    def _link_files(repo_dir, snapshot_dir, filenames):
        """Symlink each of *filenames* from *repo_dir* into *snapshot_dir*.

        A link that already resolves to an existing file is left alone. A
        dangling symlink (e.g. created by a run from a different working
        directory) is removed and recreated. ``os.path.exists`` reports such
        a link as missing, yet ``os.symlink`` would raise ``FileExistsError``
        because the link entry itself is still present.
        """
        for name in filenames:
            link = os.path.join(snapshot_dir, name)
            if os.path.exists(link):
                continue
            if os.path.lexists(link):
                os.unlink(link)
            os.symlink(os.path.join(repo_dir, name), link)

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Embed the request text.

        Args:
            data: Request payload. The ``"inputs"`` key is popped (the dict is
                mutated). If the key is absent, the payload itself is used as
                the input. ``inputs`` may be a single string or a list/tuple of
                strings.

        Returns:
            The model's ``pooler_output`` converted with ``.tolist()``.
            Presumably this is one embedding row per input text — confirm
            against ``EncoderCT2fromHfHub.generate``.
        """
        inputs = data.pop("inputs", data)
        # A single string becomes a batch of one; a list/tuple is already a
        # batch and must not be wrapped again.
        texts = list(inputs) if isinstance(inputs, (list, tuple)) else [inputs]
        outputs = self.model.generate(text=texts)
        return outputs["pooler_output"].tolist()

# Local smoke test. With local_test=True the model files are linked from the
# current working directory, so run this from the repository directory.
if __name__ == '__main__':
    # EndpointHandler is defined in this module; re-importing the file under
    # its module name would load it a second time and tie the script to the
    # filename.
    my_handler = EndpointHandler(path=".", local_test=True)
    sample_texts = ['The quick brown fox jumps over the lazy dog']
    for text in sample_texts:
        response = my_handler({"inputs": text})
        # Print every response, not just the last one.
        print(response)