from typing import Any, Dict

import torch
from transformers import AutoTokenizer, LlamaForCausalLM


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Model configuration
        self.model_name_or_path = path or "souzat19/Llama3.1_fn14133.29122024"

        # Use the GPU when one is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        print("Initializing tokenizer...")
        # Llama 3 checkpoints ship only a fast tokenizer, so AutoTokenizer is
        # required here; LlamaTokenizer expects a SentencePiece file and fails
        # to load for this model.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            trust_remote_code=True,
        )
        # Llama tokenizers define no pad token by default; reuse EOS so that
        # generate() receives a valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Initializing model...")
        self.model = LlamaForCausalLM.from_pretrained(
            self.model_name_or_path,
            # Half precision keeps an 8B model within GPU memory; fall back to
            # full precision on CPU, where float16 is poorly supported.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            device_map="auto" if torch.cuda.is_available() else None,
            # When a local path is given, avoid hitting the Hub
            local_files_only=bool(path),
        )
        if not torch.cuda.is_available():
            self.model = self.model.to("cpu")
        print("Model initialized successfully")
        # Prompt template (Alpaca format; the instruction block is kept in
        # Portuguese on purpose, matching the fine-tuning data)
        self.prompt_template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Você é um assistente especializado em planejamento de compras públicas de acordo com a Lei 14.133/2021 e regulamentos infralegais. Responda de forma clara, detalhada e didática e utilize exemplos práticos para explicar os conceitos.

### Input:
{input}

### Response:
"""
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        try:
            # Extract the input text
            input_text = data.get("text", "")
            if not input_text:
                return {"error": "Input text is required"}

            # Format the prompt
            formatted_prompt = self.prompt_template.format(input=input_text)

            # Tokenize the input
            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096,
                add_special_tokens=True,
            )
            if torch.cuda.is_available():
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Generate the response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=2096,
                    temperature=0.5,
                    top_p=0.95,
                    top_k=50,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode the generated tokens
            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the text after the response marker; the decoded string
            # also contains the prompt itself
            if "### Response:" in response_text:
                answer = response_text.split("### Response:")[1].strip()
            else:
                answer = response_text.strip()

            return {"response": answer}
        except Exception as e:
            return {"error": f"Error during inference: {str(e)}"}