from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_INPUT_SIZE = 10_000
MAX_NEW_TOKENS = 4_000


def clean_json_text(text):
    """
    Clean generated JSON text by stripping leading/trailing whitespace
    and un-escaping special characters.
    """
    text = text.strip()
    # Un-escape characters the model sometimes emits with a spurious backslash.
    text = text.replace("\\#", "#").replace("\\&", "&")
    return text


class EndpointHandler:
    def __init__(self, path=""):
        # Load the model and tokenizer from the given path.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> str:
        # The endpoint payload nests the actual inputs under "inputs".
        inputs = data.pop("inputs")
        template = inputs.pop("template")
        text = inputs.pop("text")

        # Build the extraction prompt; the trailing "{" primes the model
        # to start emitting the JSON object immediately.
        input_llm = f"<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>" + "{"

        # Tokenize, truncating overly long inputs, and move the tensors to
        # the device the model was placed on by device_map="auto".
        input_ids = self.tokenizer(
            input_llm,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_SIZE,
        ).to(self.model.device)

        # Generate, decode, and return only the JSON after the <|output|> marker.
        output = self.tokenizer.decode(
            self.model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS)[0],
            skip_special_tokens=True,
        )
        return clean_json_text(output.split("<|output|>")[1])
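

# --- Usage sketch (not part of the original handler) ---
# A minimal local smoke test, assuming the model weights live under a
# hypothetical "./model" directory. The payload shape mirrors what
# __call__ expects: the template and text nested under "inputs".
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # "./model" is an assumed path
    payload = {
        "inputs": {
            "template": '{"name": "", "email": ""}',
            "text": "Contact Jane Doe at jane@example.com for details.",
        }
    }
    print(handler(payload))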