Update handler.py
handler.py: CHANGED (+8 -6)
@@ -1,19 +1,21 @@
 import torch
 from typing import Dict, List, Any
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, TextGenerationPipeline
+from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, TextGenerationPipeline
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 # check for GPU
 device = 0 if torch.cuda.is_available() else -1
 
+print(f"cuda: {device}")
+
 class EndpointHandler():
     def __init__(self, path=""):
         # load the optimized model
-        model = AutoGPTQForCausalLM.from_quantized(path, use_safetensors=True) #file_name="model-quantized.onnx")
+        model = AutoGPTQForCausalLM.from_quantized(path, use_safetensors=False, low_cpu_mem_usage=True) #file_name="model-quantized.onnx")
         tokenizer = AutoTokenizer.from_pretrained(path)
         # or you can also use pipeline
-        self.pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
-
+        #self.pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
+        self.generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         """
         Args:
@@ -27,8 +29,8 @@ class EndpointHandler():
 
         # pass inputs with all kwargs in data
         if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
+            prediction = self.generator(inputs, **parameters)
         else:
-            prediction = self.generator(inputs)
+            prediction = self.generator(inputs)
 
         return prediction
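For reference, a handler like this is driven by a JSON-style dict. A minimal local smoke test might look like the sketch below; the ./model path, the {"inputs": ..., "parameters": ...} payload shape, and the inputs/parameters extraction in the elided part of __call__ are assumptions following the usual Hugging Face custom-handler convention, not part of this commit.

    # Sketch of exercising the updated handler locally (assumptions noted above).
    from handler import EndpointHandler

    # path should point at a directory containing the GPTQ-quantized weights
    handler = EndpointHandler(path="./model")

    # conventional custom-handler payload: raw text plus optional generation kwargs
    payload = {
        "inputs": "The quick brown fox",
        "parameters": {"max_new_tokens": 32, "do_sample": True},
    }

    # __call__ forwards inputs (and parameters, when present) to self.generator
    prediction = handler(payload)
    print(prediction)  # e.g. [{"generated_text": "..."}]

Since self.generator is a transformers text-generation pipeline, any keyword accepted during generation (max_new_tokens, temperature, top_p, ...) can be passed through "parameters".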