souzat19 commited on
Commit
2bf5eb9
·
verified ·
1 Parent(s): a705843

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -10
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Dict, Any
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  class EndpointHandler:
@@ -11,18 +11,26 @@ class EndpointHandler:
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {self.device}")
13
 
 
 
 
14
  # Inicialização do modelo e tokenizer
15
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
 
 
 
 
16
  self.model = AutoModelForCausalLM.from_pretrained(
17
  self.model_name_or_path,
 
 
 
18
  trust_remote_code=True,
19
- use_cache=True,
20
- low_cpu_mem_usage=True
 
21
  )
22
 
23
- # Move modelo para GPU se disponível
24
- self.model = self.model.to(self.device)
25
-
26
  # Template do prompt no formato Alpaca
27
  self.prompt_template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
28
  ### Instruction:
@@ -46,7 +54,15 @@ Você é um assistente especializado em planejamento de compras públicas de aco
46
  formatted_prompt = self.prompt_template.format(input=input_text)
47
 
48
  # Tokeniza o input
49
- inputs = self.tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=4096)
 
 
 
 
 
 
 
 
50
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
51
 
52
  # Gera a resposta
@@ -80,8 +96,6 @@ Você é um assistente especializado em planejamento de compras públicas de aco
80
  """
81
  Pré-processa o texto de entrada se necessário
82
  """
83
- # Remove espaços extras e normaliza quebras de linha
84
- text = " ".join(text.split())
85
  return text.strip()
86
 
87
  def validate_input(self, text: str) -> bool:
 
1
  from typing import Dict, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
3
  import torch
4
 
5
  class EndpointHandler:
 
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {self.device}")
13
 
14
+ # Configurações para evitar quantização automática
15
+ config = AutoConfig.from_pretrained(self.model_name_or_path)
16
+
17
  # Inicialização do modelo e tokenizer
18
+ self.tokenizer = AutoTokenizer.from_pretrained(
19
+ self.model_name_or_path,
20
+ trust_remote_code=True
21
+ )
22
+
23
  self.model = AutoModelForCausalLM.from_pretrained(
24
  self.model_name_or_path,
25
+ config=config,
26
+ torch_dtype=torch.float32, # Força o uso de float32
27
+ device_map="auto",
28
  trust_remote_code=True,
29
+ use_safetensors=True,
30
+ load_in_4bit=False,
31
+ load_in_8bit=False
32
  )
33
 
 
 
 
34
  # Template do prompt no formato Alpaca
35
  self.prompt_template = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
36
  ### Instruction:
 
54
  formatted_prompt = self.prompt_template.format(input=input_text)
55
 
56
  # Tokeniza o input
57
+ inputs = self.tokenizer(
58
+ formatted_prompt,
59
+ return_tensors="pt",
60
+ truncation=True,
61
+ max_length=4096,
62
+ add_special_tokens=True
63
+ )
64
+
65
+ # Move para o dispositivo apropriado
66
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
67
 
68
  # Gera a resposta
 
96
  """
97
  Pré-processa o texto de entrada se necessário
98
  """
 
 
99
  return text.strip()
100
 
101
  def validate_input(self, text: str) -> bool: