import torch
from typing import Any, Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria


class MyStoppingCriteria(StoppingCriteria):
    """Stops generation once ``target_sequence`` appears in the decoded output."""

    def __init__(self, target_sequence, prompt, tokenizer):
        self.target_sequence = target_sequence
        self.prompt = prompt
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the tokens generated so far and strip the prompt so that
        # only the model's continuation is inspected.
        generated_text = self.tokenizer.decode(input_ids[0])
        generated_text = generated_text.replace(self.prompt, "")
        # Check if the target sequence appears in the generated text
        if self.target_sequence in generated_text:
            return True  # Stop generation
        return False  # Continue generation

    # __len__ and __iter__ let a single instance be passed where
    # ``generate`` expects a StoppingCriteriaList-like iterable.
    def __len__(self):
        return 1

    def __iter__(self):
        yield self


class EndpointHandler:
    def __init__(self, model_dir=""):
        # Load tokenizer and 4-bit quantized model from the given path.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, load_in_4bit=True, device_map="auto"
        )
        self.template = {
            "prompt_input": """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n""",
            "prompt_no_input": """Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n""",
            "response_split": """### Response:""",
        }
        self.instruction = """Extract the start and end sequences for the categories 'personal information', 'work experience', 'education' and 'skills' from the following text in dictionary form"""
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data (:dict:): The payload with the text prompt and generation parameters.
        """
        # Process input
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Build the Alpaca-style prompt from the instruction and the input text.
        prompt = self.template["prompt_input"].format(
            instruction=self.instruction, input=inputs
        )
        messages = [
            {"role": "user", "content": prompt},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            truncation=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        input_ids = input_ids.to(self.device)

        # The stop sequence was blank in the original listing (likely a stripped
        # special token); the tokenizer's EOS token is assumed here.
        stop_sequence = self.tokenizer.eos_token
        stopping_criteria = MyStoppingCriteria(stop_sequence, inputs, self.tokenizer)

        # Pass inputs with all generation kwargs from the payload.
        if parameters is not None:
            outputs = self.model.generate(
                input_ids=input_ids,
                stopping_criteria=stopping_criteria,
                **parameters,
            )
        else:
            outputs = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=32,
                stopping_criteria=stopping_criteria,
            )

        # Postprocess the prediction: decode only the newly generated tokens
        # and cut everything after the stop sequence.
        prediction = self.tokenizer.decode(
            outputs[0][input_ids.shape[1]:]
        )  # skip_special_tokens=True could be used instead
        prediction = prediction.split(stop_sequence)[0]
        # TODO: add processing of the LLM output

        return [{"generated_text": prediction}]
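

# The block below is a minimal local smoke test for the handler above; it is
# not part of the original file. "path/to/model" is a placeholder for a
# directory containing the fine-tuned weights, and the sample payload mirrors
# the {"inputs": ..., "parameters": ...} shape the handler expects.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="path/to/model")  # hypothetical local path
    payload = {
        "inputs": "John Doe\nSoftware engineer with five years of experience ...",
        "parameters": {"max_new_tokens": 128, "do_sample": False},
    }
    result = handler(payload)
    print(result[0]["generated_text"])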