from typing import Dict, List, Any, Union
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
import torch
from peft import PeftModel


class MyStoppingCriteria(StoppingCriteria):
    """Stop generation once a target string appears in the newly generated text."""

    def __init__(self, target_sequence, prompt, tokenizer):
        self.target_sequence = target_sequence
        self.prompt = prompt
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode everything generated so far and strip the prompt, so only the
        # newly generated continuation is inspected.
        generated_text = self.tokenizer.decode(input_ids[0])
        generated_text = generated_text.replace(self.prompt, '')
        # Check if the target sequence appears in the generated text
        if self.target_sequence in generated_text:
            return True  # Stop generation
        return False  # Continue generation

    def __len__(self):
        return 1

    def __iter__(self):
        yield self
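
# model.generate() normally expects a StoppingCriteriaList; the __len__ and
# __iter__ methods above let a single MyStoppingCriteria instance duck-type as
# one. An equivalent, more explicit call (a sketch assuming the same prompt,
# tokenizer and input_ids variables) would be:
#
#   from transformers import StoppingCriteriaList
#   criteria = StoppingCriteriaList([MyStoppingCriteria("<|endoftext|>", prompt, tokenizer)])
#   model.generate(input_ids=input_ids, stopping_criteria=criteria)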


class EndpointHandler:
    def __init__(self, path=""):
        # Load the 8-bit base model and attach the fine-tuned PEFT adapter.
        self.model = AutoModelForCausalLM.from_pretrained(
            "tiiuae/falcon-rw-1b", device_map="auto", load_in_8bit=True)
        self.model = PeftModel.from_pretrained(
            self.model,
            "MichaelAI23/falcon-rw-1b_8bit_finetuned",
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b")
        # Alpaca-style prompt templates used to wrap every incoming request.
        self.template = {
            "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
            "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
            "response_split": "### Response:"
        }
        self.instruction = """Extract the name of the person, the location, the hotel name and the desired date from the following hotel request"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def generate_prompt(
        self,
        template: Dict[str, str],
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        # Build the full prompt from the instruction and the optional input;
        # if a label (= response, = output) is provided, it is appended as well.
        if input:
            res = template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        return res
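
    # Illustrative sketch of the rendered prompt (the request text below is an
    # assumption, not part of this handler):
    #
    #   prompt = handler.generate_prompt(
    #       handler.template,
    #       handler.instruction,
    #       "Hi, I am John, I need a room at the Hilton in Paris on May 12th",
    #   )
    #   # -> instruction block, an "### Input:" block with the request, and a
    #   #    trailing "### Response:\n" that the model then completes.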

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
        """
        # process input: wrap the incoming text in the extraction prompt
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        inputs = self.generate_prompt(self.template, self.instruction, inputs)

        # preprocess: tokenize the prompt and move it to the model device
        self.tokenizer.pad_token_id = (
            0  # unk. we want this to be different from the eos token
        )
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.device)

        # generate, forwarding any user-supplied parameters and stopping once
        # the model emits the end-of-text marker
        if parameters is not None:
            outputs = self.model.generate(
                input_ids=input_ids,
                stopping_criteria=MyStoppingCriteria("<|endoftext|>", inputs, self.tokenizer),
                **parameters)
        else:
            outputs = self.model.generate(
                input_ids=input_ids, max_new_tokens=32,
                stopping_criteria=MyStoppingCriteria("<|endoftext|>", inputs, self.tokenizer)
            )

        # postprocess: keep only the newly generated tokens and drop everything
        # after the end-of-text marker
        prediction = self.tokenizer.decode(outputs[0][input_ids.shape[1]:])
        prediction = prediction.split("<|endoftext|>")[0]
        return [{"generated_text": prediction}]
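

# Local smoke test (illustrative sketch: the request text and parameters are
# assumptions, but the payload shape matches what __call__ expects from the
# Inference Endpoints "inputs"/"parameters" convention).
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": "Hello, I am John Smith and I would like to book a room "
                  "at the Grand Hotel in Berlin for the 12th of May.",
        "parameters": {"max_new_tokens": 32},
    }
    print(handler(payload))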