kajdun committed
Commit 023c0f7 · 1 Parent(s): e56b9b2

Update handler.py

Files changed (1)
  1. handler.py +8 -6
handler.py CHANGED
@@ -1,19 +1,21 @@
 import torch
 from typing import Dict, List, Any
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, TextGenerationPipeline
+from transformers import pipeline, AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, TextGenerationPipeline
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 # check for GPU
 device = 0 if torch.cuda.is_available() else -1
 
+print(f"cuda: {device}")
+
 class EndpointHandler():
     def __init__(self, path=""):
         # load the optimized model
-        model = AutoGPTQForCausalLM.from_quantized(path, use_safetensors=True) #file_name="model-quantized.onnx")
+        model = AutoGPTQForCausalLM.from_quantized(path, use_safetensors=False, low_cpu_mem_usage=True) #file_name="model-quantized.onnx")
         tokenizer = AutoTokenizer.from_pretrained(path)
         # or you can also use pipeline
-        self.pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
-
+        #self.pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer)
+        self.generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         """
         Args:
@@ -27,8 +29,8 @@ class EndpointHandler():
 
         # pass inputs with all kwargs in data
         if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
+            prediction = self.generator(inputs, **parameters)
         else:
-            prediction = self.pipeline(inputs)
+            prediction = self.generator(inputs)
 
         return prediction
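
The second hunk references `inputs` and `parameters` without showing where they are set; those lines of `__call__` are unchanged and elided from the diff. In the stock Hugging Face custom-handler template they are popped from the request dict, roughly as in this sketch of the likely elided context (not code from this commit):

# Sketch only: how the elided lines of __call__ typically read in the
# standard Hugging Face custom-handler template; the actual file may differ.
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", None)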
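
For reference, a minimal local smoke test of the updated handler, assuming the endpoint's usual {"inputs": ..., "parameters": ...} payload; the model path and generation parameters below are placeholder values.

# Hypothetical local test (not part of the commit). Assumes the GPTQ weight
# files sit in the current directory and that handler.py is importable.
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # directory holding the quantized model
payload = {
    "inputs": "Hello, my name is",
    "parameters": {"max_new_tokens": 32},  # forwarded to the pipeline as **kwargs
}
prediction = handler(payload)
print(prediction)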