|
from typing import Dict, Any |
|
from transformers import LlamaForCausalLM, LlamaTokenizer |
|
import torch |
|
|
|
class EndpointHandler: |
|
def __init__(self, path=""): |
|
|
|
self.model_name_or_path = path or "souzat19/Llama3.1_fn14133.29122024" |
|
|
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
print(f"Using device: {self.device}") |
|
|
|
print("Initializing tokenizer...") |
|
self.tokenizer = LlamaTokenizer.from_pretrained( |
|
self.model_name_or_path, |
|
trust_remote_code=True |
|
) |
|
|
|
print("Initializing model...") |
|
self.model = LlamaForCausalLM.from_pretrained( |
|
self.model_name_or_path, |
|
torch_dtype=torch.float32, |
|
trust_remote_code=True, |
|
device_map="auto" if torch.cuda.is_available() else None, |
|
local_files_only=True if path else False |
|
) |
|
|
|
if not torch.cuda.is_available(): |
|
self.model = self.model.to("cpu") |
|
|
|
print("Model initialized successfully") |
|
|
|
|
|
self.prompt_template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. |
|
### Instruction: |
|
Você é um assistente especializado em planejamento de compras públicas de acordo com a Lei 14.133/2021 e regulamentos infralegais. Responda de forma clara, detalhada e didática e utilize exemplos práticos para explicar os conceitos. |
|
### Input: |
|
{input} |
|
### Response: |
|
""" |
|
|
|
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: |
|
try: |
|
|
|
input_text = data.get("text", "") |
|
if not input_text: |
|
return {"error": "Input text is required"} |
|
|
|
|
|
formatted_prompt = self.prompt_template.format(input=input_text) |
|
|
|
|
|
inputs = self.tokenizer( |
|
formatted_prompt, |
|
return_tensors="pt", |
|
truncation=True, |
|
max_length=4096, |
|
add_special_tokens=True |
|
) |
|
|
|
if torch.cuda.is_available(): |
|
inputs = {k: v.to(self.device) for k, v in inputs.items()} |
|
|
|
|
|
with torch.no_grad(): |
|
outputs = self.model.generate( |
|
**inputs, |
|
max_new_tokens=2096, |
|
temperature=0.5, |
|
top_p=0.95, |
|
top_k=50, |
|
do_sample=True, |
|
pad_token_id=self.tokenizer.pad_token_id, |
|
eos_token_id=self.tokenizer.eos_token_id |
|
) |
|
|
|
|
|
response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
if "### Response:" in response_text: |
|
answer = response_text.split("### Response:")[1].strip() |
|
else: |
|
answer = response_text.strip() |
|
|
|
return {"response": answer} |
|
|
|
except Exception as e: |
|
return {"error": f"Error during inference: {str(e)}"} |