import json
import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria

class MyStoppingCriteria(StoppingCriteria):
    """Stops generation as soon as a target sequence appears in the decoded output."""

    def __init__(self, target_sequence, prompt, tokenizer):
        self.target_sequence = target_sequence
        self.prompt = prompt
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode everything generated so far and strip the original prompt text
        generated_text = self.tokenizer.decode(input_ids[0])
        generated_text = generated_text.replace(self.prompt, '')
        # Check if the target sequence appears in the generated text
        if self.target_sequence in generated_text:
            return True  # Stop generation
        return False  # Continue generation

    # __len__ and __iter__ let a single instance be passed to model.generate(),
    # where a StoppingCriteriaList is normally expected.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self
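
# Note: a more conventional alternative to the __len__/__iter__ shims above is to
# wrap the criterion in transformers.StoppingCriteriaList, e.g. (illustrative only):
#
#     from transformers import StoppingCriteriaList
#     criteria = StoppingCriteriaList([MyStoppingCriteria("</s>", prompt_text, tokenizer)])
#     model.generate(input_ids=input_ids, stopping_criteria=criteria)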


class EndpointHandler:
    def __init__(self, model_dir=""):
        # load model and tokenizer from the model directory
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # load in 4-bit with automatic device placement
        self.model = AutoModelForCausalLM.from_pretrained(model_dir, load_in_4bit=True, device_map="auto")
        # Alpaca-style prompt templates
        self.template = {
            "prompt_input": """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n""",
            "prompt_no_input": """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n""",
            "response_split": """### Response:"""
        }
        self.instruction = """Extract the start and end sequences for the categories 'personal information', 'work experience', 'education' and 'skills' from the following text in dictionary form"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data (:obj:`dict`):
                The payload with the text prompt under "inputs" and optional
                generation parameters under "parameters".
        Returns:
            A list with a single dict holding the generated text.
        """
        # process input
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # build the Alpaca-style prompt from the fixed instruction and the user text
        res = self.template["prompt_input"].format(
            instruction=self.instruction, input=inputs
        )
        messages = [
            {"role": "user", "content": res},
        ]
        # apply_chat_template with return_tensors="pt" already returns the token id tensor
        input_ids = self.tokenizer.apply_chat_template(
            messages, truncation=True, add_generation_prompt=True, return_tensors="pt"
        )
        input_ids = input_ids.to(self.device)
        # pass inputs with all generation kwargs from the payload
        if parameters is not None:
            outputs = self.model.generate(
                input_ids=input_ids,
                stopping_criteria=MyStoppingCriteria("</s>", inputs, self.tokenizer),
                **parameters,
            )
        else:
            outputs = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=32,
                stopping_criteria=MyStoppingCriteria("</s>", inputs, self.tokenizer),
            )
        # postprocess: decode only the newly generated tokens
        prediction = self.tokenizer.decode(outputs[0][input_ids.shape[1]:])  # skip_special_tokens=True would also strip the "</s>" used below
        prediction = prediction.split("</s>")[0]
        # TODO: add processing of the LLM output
        return [{"generated_text": prediction}]
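

# --- Local usage sketch (illustrative only, not part of the endpoint contract) ---
# Assumes a local directory "./model" containing the fine-tuned weights and tokenizer;
# adjust the path and generation parameters to your setup.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./model")
    payload = {
        "inputs": "John Doe\nSoftware engineer with 5 years of experience in ...",
        "parameters": {"max_new_tokens": 64},
    }
    result = handler(payload)
    print(result[0]["generated_text"])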