from typing import Dict, Any

import torch
from transformers import AutoTokenizer, LlamaForCausalLM


class EndpointHandler:
    def __init__(self, path=""):
        # Model configuration
        self.model_name_or_path = path or "souzat19/Llama3.1_fn14133.29122024"

        # Detect whether a GPU is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        print("Initializing tokenizer...")
        # Llama 3.1 checkpoints do not ship a SentencePiece model, so the
        # legacy LlamaTokenizer cannot load them; AutoTokenizer resolves the
        # correct tokenizer class from the checkpoint config.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            trust_remote_code=True
        )
        # Llama tokenizers usually define no pad token; fall back to EOS so
        # generate() has a valid pad_token_id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        print("Initializing model...")
        self.model = LlamaForCausalLM.from_pretrained(
            self.model_name_or_path,
            # fp16 on GPU keeps an 8B model within typical endpoint memory;
            # fp32 remains the safe choice on CPU
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            trust_remote_code=True,
            device_map="auto" if self.device == "cuda" else None,
            local_files_only=bool(path)
        )
        if self.device == "cpu":
            self.model = self.model.to("cpu")
        self.model.eval()
        print("Model initialized successfully")

        # Prompt template (the instruction block is kept in Portuguese, since
        # it matches the format the model was fine-tuned on)
        self.prompt_template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Você é um assistente especializado em planejamento de compras públicas de acordo com a Lei 14.133/2021 e regulamentos infralegais. Responda de forma clara, detalhada e didática e utilize exemplos práticos para explicar os conceitos.

### Input:
{input}

### Response:
"""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            # Extract the input text
            input_text = data.get("text", "")
            if not input_text:
                return {"error": "Input text is required"}

            # Format the prompt
            formatted_prompt = self.prompt_template.format(input=input_text)

            # Tokenize the input
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096,
                add_special_tokens=True
            )
            # Move input tensors to the same device as the model
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Generate the response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    top_p=0.95,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode the response
            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the text after the "### Response:" marker, since the
            # decoded output echoes the full prompt
            if "### Response:" in response_text:
                answer = response_text.split("### Response:")[1].strip()
            else:
                answer = response_text.strip()

            return {"response": answer}

        except Exception as e:
            return {"error": f"Error during inference: {str(e)}"}
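
# --- Local smoke test (illustrative sketch) ---
# Hugging Face Inference Endpoints instantiates EndpointHandler once and then
# calls it per request with the parsed JSON payload. The block below is a
# minimal local approximation of that flow, not part of the endpoint
# contract; the sample question is an illustrative assumption (kept in
# Portuguese to match the model's fine-tuning domain).
if __name__ == "__main__":
    handler = EndpointHandler()  # downloads the checkpoint if not cached

    payload = {"text": "O que é o Estudo Técnico Preliminar segundo a Lei 14.133/2021?"}
    result = handler(payload)

    # On success the handler returns {"response": "..."}; on failure it
    # returns {"error": "..."} with the exception message.
    print(result.get("response") or result.get("error"))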