Copycats commited on
Commit
304ffb4
1 Parent(s): cb5f4b6

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -0
handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
+
5
+ # check for GPU
6
+ device = 0 if torch.cuda.is_available() else -1
7
+
8
+ # multi-model list
9
+ multi_model_list = [
10
+ {"model_id": "BAAI/bge-base-en-v1.5", "task":"sentence-embeddings"},
11
+ {"model_id": "BAAI/bge-reranker-base", "task":"sentence-ranking"},
12
+ ]
13
+
14
+ class EndpointHandler():
15
+ def __init__(self, path=""):
16
+ self.multi_model={}
17
+ # load all the models onto device
18
+ for model in multi_model_list:
19
+ self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_id"], device=device)
20
+
21
+ def __call__(self, data):
22
+ # deserialize incomin request
23
+ inputs = data.pop("inputs", data)
24
+ parameters = data.pop("parameters", None)
25
+ model_id = data.pop("model_id", None)
26
+
27
+ # check if model_id is in the list of models
28
+ if model_id is None or model_id not in self.multi_model:
29
+ raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
30
+
31
+ # pass inputs with all kwargs in data
32
+ if parameters is not None:
33
+ prediction = self.multi_model[model_id](inputs, **parameters)
34
+ else:
35
+ prediction = self.multi_model[model_id](inputs)
36
+ # postprocess the prediction
37
+ return prediction