from typing import Dict, List, Any, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModel


class EndpointHandler:
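    """
    Custom handler for a Hugging Face Inference Endpoint. It loads the
    tiiuae/falcon-rw-1b base model in 8-bit, attaches the fine-tuned PEFT
    adapter MichaelAI23/falcon-rw-1b_8bit_finetuned, and extracts the person's
    name, the location, the hotel name and the desired date from a free-form
    hotel request.
    """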
    def __init__(self, path=""):
        # load the 8-bit base model and attach the fine-tuned PEFT adapter
        self.model = AutoModelForCausalLM.from_pretrained(
            "tiiuae/falcon-rw-1b", device_map="auto", load_in_8bit=True
        )
        self.model = PeftModel.from_pretrained(
            self.model,
            "MichaelAI23/falcon-rw-1b_8bit_finetuned",
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
        # Alpaca-style prompt templates used to wrap the incoming request
        self.template = {
            "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
            "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
            "response_split": "### Response:"
        }
        self.instruction = """Extract the name of the person, the location, the hotel name and the desired date from the following hotel request"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_prompt(
        self,
        template: Dict[str, str],
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # Build the full prompt from the instruction and the optional input.
        # If a label (= response / output) is provided, it is appended as well.
        if input:
            res = template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = template["prompt_no_input"].format(instruction=instruction)
        if label:
            res = f"{res}{label}"
        return res
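    # For illustration, a hypothetical request such as
    #   "Hi, this is John Smith, I need a room at the Grand Hotel in Berlin on May 3rd"
    # is rendered by generate_prompt (via "prompt_input") roughly as:
    #
    #   Below is an instruction that describes a task, paired with an input that
    #   provides further context. Write a response that appropriately completes
    #   the request.
    #
    #   ### Instruction:
    #   Extract the name of the person, the location, the hotel name and the
    #   desired date from the following hotel request
    #
    #   ### Input:
    #   Hi, this is John Smith, I need a room at the Grand Hotel in Berlin on May 3rd
    #
    #   ### Response: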

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data (:obj:`dict`):
                The payload with the text prompt and optional generation parameters.
        Returns:
            A list with a single dict containing the generated text.
        """
        # unpack the request payload
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        inputs = self.generate_prompt(self.template, self.instruction, inputs)

        # preprocess
        self.tokenizer.pad_token_id = 0  # unk; we want this to be different from the eos token
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.device)

        # pass the prompt together with any generation kwargs from the payload
        if parameters is not None:
            outputs = self.model.generate(input_ids=input_ids, **parameters)
        else:
            outputs = self.model.generate(input_ids=input_ids, max_new_tokens=20)

        # postprocess: decode only the newly generated tokens and strip the eos marker
        prediction = self.tokenizer.decode(outputs[0][input_ids.shape[1]:])
        prediction = prediction.split("<|endoftext|>")[0]
        return [{"generated_text": prediction}]
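

# Minimal local smoke test (a sketch, not part of the endpoint contract): it
# only illustrates the expected payload shape. It assumes bitsandbytes and a
# GPU are available for the 8-bit base model, that the adapter weights can be
# fetched from the Hub, and the request text below is made up.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": "Hello, my name is John Smith. I would like to book a room "
                  "at the Grand Hotel in Berlin for the 3rd of May.",
        "parameters": {"max_new_tokens": 30},
    }
    print(handler(payload))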