kajdun committed
Commit 3382124 · 1 Parent(s): a77b076

Update handler.py

Files changed (1)
  1. handler.py +11 -7
handler.py CHANGED
@@ -7,6 +7,8 @@ import torch
 #device = 0 if torch.cuda.is_available() else -1
 
 MAX_INPUT_TOKEN_LENGTH = 4000
+MAX_MAX_NEW_TOKENS=2048
+DEFAULT_MAX_NEW_TOKENS = 1024
 
 class EndpointHandler():
     def __init__(self, path=""):
@@ -15,23 +17,25 @@ class EndpointHandler():
 
     def get_input_token_length(message: str) -> int:
         input_ids = self.tokenizer([message], return_tensors='np', add_special_tokens=False)['input_ids']
-        return input_ids.shape[-1]
+        return input_ids.shape[-1]
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
+        parameters = data.pop("parameters", {})
+
+        parameters["max_new_tokens"] = parameters.pop("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
+
+        if parameters["max_new_tokens"] > MAX_MAX_NEW_TOKENS:
+            return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]
 
         input_token_length = get_input_token_length(inputs)
         if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-            [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
+            return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
 
         #input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
         input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
 
-        if parameters is not None:
-            outputs = self.model.generate(**input_ids, **parameters)
-        else:
-            outputs = self.model.generate(**input_ids)
+        outputs = self.model.generate(**input_ids, **parameters)
 
         prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
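For context, a minimal sketch of how such a custom endpoint handler is typically invoked and how the new limits apply; the payloads and the path "." below are illustrative assumptions, not part of the commit, and the quoted error string follows the caps defined above (MAX_MAX_NEW_TOKENS = 2048, DEFAULT_MAX_NEW_TOKENS = 1024).

# Minimal usage sketch: assumed request payload shape; path "." is a placeholder.
handler = EndpointHandler(path=".")

# Without "parameters", max_new_tokens falls back to DEFAULT_MAX_NEW_TOKENS (1024)
# before the handler proceeds to generation.
result = handler({"inputs": "Hello, world!"})

# A max_new_tokens above the cap is rejected before any tokenization or generation.
error = handler({"inputs": "Hello", "parameters": {"max_new_tokens": 4096}})
# -> [{"generated_text": None, "error": "requested max_new_tokens too high (> 2048)"}]

Both guard clauses return early with generated_text set to None, so a caller can tell a rejected request apart from an empty generation.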