Michael Brunzel committed on
Commit: 61eac05
Parent(s): bb8956c

Add stopping criteria

Files changed (1)
  1. handler.py +32 -3
handler.py CHANGED
@@ -1,9 +1,32 @@
 from typing import Dict, List, Any, Union
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
 import torch
 from peft import PeftModel
 
 
+class MyStoppingCriteria(StoppingCriteria):
+    def __init__(self, target_sequence, prompt, tokenizer):
+        self.target_sequence = target_sequence
+        self.prompt = prompt
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # Get the generated text as a string
+        generated_text = self.tokenizer.decode(input_ids[0])
+        generated_text = generated_text.replace(self.prompt, '')
+        # Check if the target sequence appears in the generated text
+        if self.target_sequence in generated_text:
+            return True  # Stop generation
+
+        return False  # Continue generation
+
+    def __len__(self):
+        return 1
+
+    def __iter__(self):
+        yield self
+
+
 class EndpointHandler:
     def __init__(self, path=""):
         # load model and processor from path
@@ -69,9 +92,15 @@ class EndpointHandler:
 
         # pass inputs with all kwargs in data
         if parameters is not None:
-            outputs = self.model.generate(input_ids=input_ids, **parameters)
+            outputs = self.model.generate(
+                input_ids=input_ids,
+                stopping_criteria=MyStoppingCriteria("<|endoftext|>", inputs, self.tokenizer),
+                **parameters)
         else:
-            outputs = self.model.generate(input_ids=input_ids, max_new_tokens=20)
+            outputs = self.model.generate(
+                input_ids=input_ids, max_new_tokens=32,
+                stopping_criteria=MyStoppingCriteria("<|endoftext|>", inputs, self.tokenizer)
+            )
 
         # postprocess the prediction
         prediction = self.tokenizer.decode(outputs[0][input_ids.shape[1]:])  # , skip_special_tokens=True)
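
For context, the criteria added here ends generation as soon as the decoded continuation contains a target substring. Below is a minimal, self-contained sketch of the same pattern: StopOnSubstring is an illustrative rename of the class from this commit, the checkpoint name "sshleifer/tiny-gpt2" is only a placeholder (not this repo's actual model), and the criteria is wrapped in the StoppingCriteriaList that transformers documents for the stopping_criteria argument.

# Minimal sketch of the stopping-criteria pattern from this commit.
# Assumptions: "sshleifer/tiny-gpt2" is a placeholder checkpoint, and
# StopOnSubstring is a hypothetical rename of MyStoppingCriteria above.
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class StopOnSubstring(StoppingCriteria):
    """Stop once `target` appears in the text generated after the prompt."""
    def __init__(self, target, prompt, tokenizer):
        self.target = target
        self.prompt = prompt
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the whole sequence, drop the prompt, and test for the target.
        generated = self.tokenizer.decode(input_ids[0]).replace(self.prompt, "")
        return self.target in generated  # True halts generation

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

prompt = "Hello, world"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model.generate(
    input_ids=input_ids,
    max_new_tokens=32,
    # generate() documents stopping_criteria as a StoppingCriteriaList;
    # the commit instead gives the criteria __len__/__iter__ so the bare
    # object can stand in for the list.
    stopping_criteria=StoppingCriteriaList(
        [StopOnSubstring("<|endoftext|>", prompt, tokenizer)]),
)
print(tokenizer.decode(outputs[0][input_ids.shape[1]:]))

One trade-off worth noting: this check decodes the full sequence on every generation step, which is fine for short endpoint responses but scales poorly with output length. When the target is a single special token such as <|endoftext|>, comparing token ids (or simply passing eos_token_id to generate) avoids the repeated decode.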