Michael Brunzel committed on
Commit
e48403c
1 Parent(s): d064821

Update generate

Browse files
Files changed (1) hide show
  1. handler.py +7 -1
handler.py CHANGED
@@ -23,6 +23,11 @@ class EndpointHandler:
23
  }
24
  self.instruction = """Extract the name of the person, the location, the hotel name and the desired date from the following hotel request"""
25
 
 
 
 
 
 
26
  def generate_prompt(
27
  self,
28
  template: str,
@@ -55,6 +60,7 @@ class EndpointHandler:
55
  parameters = data.pop("parameters", None)
56
 
57
  inputs = self.generate_prompt(self.template, self.instruction, inputs)
 
58
  # preprocess
59
  self.tokenizer.pad_token_id = (
60
  0 # unk. we want this to be different from the eos token
@@ -65,7 +71,7 @@ class EndpointHandler:
65
  if parameters is not None:
66
  outputs = self.model.generate(input_ids, **parameters)
67
  else:
68
- outputs = self.model.generate(input_ids)
69
 
70
  # postprocess the prediction
71
  prediction = self.tokenizer.decode(outputs[0]) #, skip_special_tokens=True)
 
23
  }
24
  self.instruction = """Extract the name of the person, the location, the hotel name and the desired date from the following hotel request"""
25
 
26
+ if torch.cuda.is_available():
27
+ self.device = "cuda"
28
+ else:
29
+ self.device = "cpu"
30
+
31
  def generate_prompt(
32
  self,
33
  template: str,
 
60
  parameters = data.pop("parameters", None)
61
 
62
  inputs = self.generate_prompt(self.template, self.instruction, inputs)
63
+ input_ids = input_ids.to(self.device)
64
  # preprocess
65
  self.tokenizer.pad_token_id = (
66
  0 # unk. we want this to be different from the eos token
 
71
  if parameters is not None:
72
  outputs = self.model.generate(input_ids, **parameters)
73
  else:
74
+ outputs = self.model.generate(input_ids, max_new_tokens=64)
75
 
76
  # postprocess the prediction
77
  prediction = self.tokenizer.decode(outputs[0]) #, skip_special_tokens=True)