Update handler.py
Browse files- handler.py +15 -24
handler.py
CHANGED
@@ -11,8 +11,9 @@ class EndpointHandler:
|
|
11 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
|
12 |
self.model = AutoModelForCausalLM.from_pretrained(
|
13 |
self.model_name_or_path,
|
14 |
-
|
15 |
-
|
|
|
16 |
)
|
17 |
|
18 |
# Template do prompt no formato Alpaca
|
@@ -25,16 +26,6 @@ Você é um assistente especializado em planejamento de compras públicas de aco
|
|
25 |
"""
|
26 |
|
27 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
28 |
-
"""
|
29 |
-
Processa a entrada e retorna a resposta do modelo.
|
30 |
-
|
31 |
-
Args:
|
32 |
-
data: Dicionário contendo a entrada do usuário
|
33 |
-
Formato esperado: {"text": "pergunta do usuário"}
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
Dict contendo a resposta do modelo
|
37 |
-
"""
|
38 |
try:
|
39 |
# Extrai o texto da entrada
|
40 |
input_text = data.get("text", "")
|
@@ -52,16 +43,17 @@ Você é um assistente especializado em planejamento de compras públicas de aco
|
|
52 |
inputs = inputs.to(self.model.device)
|
53 |
|
54 |
# Gera a resposta
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
65 |
|
66 |
# Decodifica a resposta
|
67 |
response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
@@ -91,5 +83,4 @@ Você é um assistente especializado em planejamento de compras públicas de aco
|
|
91 |
"""
|
92 |
if not text or len(text.strip()) == 0:
|
93 |
return False
|
94 |
-
return True
|
95 |
-
|
|
|
11 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
|
12 |
self.model = AutoModelForCausalLM.from_pretrained(
|
13 |
self.model_name_or_path,
|
14 |
+
device_map="auto",
|
15 |
+
trust_remote_code=True,
|
16 |
+
use_cache=True
|
17 |
)
|
18 |
|
19 |
# Template do prompt no formato Alpaca
|
|
|
26 |
"""
|
27 |
|
28 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
try:
|
30 |
# Extrai o texto da entrada
|
31 |
input_text = data.get("text", "")
|
|
|
43 |
inputs = inputs.to(self.model.device)
|
44 |
|
45 |
# Gera a resposta
|
46 |
+
with torch.no_grad():
|
47 |
+
outputs = self.model.generate(
|
48 |
+
**inputs,
|
49 |
+
max_new_tokens=2096,
|
50 |
+
temperature=0.5,
|
51 |
+
top_p=0.95,
|
52 |
+
top_k=50,
|
53 |
+
do_sample=True,
|
54 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
55 |
+
eos_token_id=self.tokenizer.eos_token_id
|
56 |
+
)
|
57 |
|
58 |
# Decodifica a resposta
|
59 |
response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
83 |
"""
|
84 |
if not text or len(text.strip()) == 0:
|
85 |
return False
|
86 |
+
return True
|
|