# handler.py -- uploaded by sijieaaa ("Upload handler.py", commit a3dfe34, verified)
from typing import Dict, List, Any
# from transformers import pipeline
# import holidays
from transformers import AutoTokenizer, AutoModelForCausalLM
class EndpointHandler:
    """Custom inference handler that serves a causal-LM chat model.

    The model and tokenizer are loaded once when the handler is constructed.
    Each call wraps the incoming prompt in a two-message chat (system + user),
    generates a completion and returns it as a single assistant message.
    """

    # Hub repository the weights and tokenizer are loaded from.
    MODEL_ID = 'sijieaaa/CodeModel-V1-3B-2024-02-07'
    # System message placed in front of every user prompt.
    SYSTEM_PROMPT = "You are a helpful assistant."
    # Generation budget used when the request does not supply its own.
    DEFAULT_MAX_NEW_TOKENS = 512

    def __init__(self, path=None):
        """Load the model and tokenizer and switch the model to eval mode.

        Args:
            path: Accepted for interface compatibility with callers that pass
                a location; it is not used -- everything is loaded from
                ``MODEL_ID``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_ID,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate an assistant reply for the prompt in ``data``.

        Args:
            data: Request payload.
                ``inputs`` (str): the user prompt (required).
                ``parameters`` (dict, optional): extra keyword arguments
                forwarded to ``model.generate``. ``max_new_tokens`` falls
                back to ``DEFAULT_MAX_NEW_TOKENS`` when not given.

        Returns:
            A one-element list ``[{"role": "assistant", "content": <reply>}]``.

        Raises:
            KeyError: if ``data`` has no ``inputs`` entry.
        """
        prompt = data["inputs"]

        # Copy so the caller's payload is never mutated by setdefault below.
        generation_kwargs = dict(data.get("parameters") or {})
        generation_kwargs.setdefault(
            "max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS
        )

        messages = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer(
            [text], return_tensors="pt"
        ).to(self.model.device)

        generated_ids = self.model.generate(
            **model_inputs,
            **generation_kwargs,
        )
        # Each generated sequence begins with its prompt tokens; drop that
        # prefix so only the newly produced tokens are decoded.
        completion_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(
                model_inputs.input_ids, generated_ids
            )
        ]
        reply = self.tokenizer.batch_decode(
            completion_ids, skip_special_tokens=True
        )[0]
        return [{"role": "assistant", "content": reply}]
# def test():
# # init handler
# my_handler = EndpointHandler(path=".")
# # prepare sample payload
# non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
# holiday_payload = {"inputs": "Today is a tough day", "date": "2022-07-04"}
# # test the handler
# a = my_handler.__call__(non_holiday_payload)
# non_holiday_pred=my_handler(non_holiday_payload)
# holiday_payload=my_handler(holiday_payload)
# # show results
# print("non_holiday_pred", non_holiday_pred)
# print("holiday_payload", holiday_payload)
# a=1
# if __name__ == "__main__":
# test()