pygmalion7b / handler.py
scepter's picture
Create handler.py
40e34b6
raw
history blame contribute delete
No virus
2.51 kB
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
class EndpointHandler:
    """Text-generation handler around an 8-bit causal LM loaded from ``path``.

    NOTE(review): the class name and ``__call__(data) -> list[dict]`` shape
    follow the Hugging Face Inference Endpoints custom-handler template;
    presumably deployed that way -- confirm against the endpoint config.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer found at *path*.

        Args:
            path: Model directory or Hub repo id handed to ``from_pretrained``.
        """
        # device_map="auto" lets accelerate place the weights on the available
        # devices; load_in_8bit loads them 8-bit quantized (bitsandbytes).
        # NOTE(review): recent transformers releases deprecate ``load_in_8bit``
        # in favour of ``quantization_config=BitsAndBytesConfig(load_in_8bit=True)``.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            load_in_8bit=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for the prompt in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` holds the prompt (if absent,
                the payload itself is used, as in the handler template).
                Optional ``"parameters"`` is a dict of keyword arguments
                forwarded to ``model.generate`` (e.g. ``max_new_tokens``).
                Both keys are popped, so the caller's dict is modified.

        Returns:
            ``[{"generated_text": text}]`` where ``text`` is the decoded first
            output sequence (prompt included, special tokens skipped).
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        encoded = self.tokenizer(inputs, return_tensors="pt")

        # Tokenizer tensors are created on the CPU; move them to the model's
        # device so generate() does not hit a cross-device mismatch.
        device = self.model.device
        model_inputs = {"input_ids": encoded["input_ids"].to(device)}
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            # Forward the mask explicitly: without it generate() has to guess
            # which tokens are padding (and warns about it).
            model_inputs["attention_mask"] = attention_mask.to(device)

        outputs = self.model.generate(**model_inputs, **(parameters or {}))

        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]