Michael Brunzel
committed on
Commit
•
e48403c
1
Parent(s):
d064821
Update generate
Browse files- handler.py +7 -1
handler.py
CHANGED
@@ -23,6 +23,11 @@ class EndpointHandler:
|
|
23 |
}
|
24 |
self.instruction = """Extract the name of the person, the location, the hotel name and the desired date from the following hotel request"""
|
25 |
|
|
|
|
|
|
|
|
|
|
|
26 |
def generate_prompt(
|
27 |
self,
|
28 |
template: str,
|
@@ -55,6 +60,7 @@ class EndpointHandler:
|
|
55 |
parameters = data.pop("parameters", None)
|
56 |
|
57 |
inputs = self.generate_prompt(self.template, self.instruction, inputs)
|
|
|
58 |
# preprocess
|
59 |
self.tokenizer.pad_token_id = (
|
60 |
0 # unk. we want this to be different from the eos token
|
@@ -65,7 +71,7 @@ class EndpointHandler:
|
|
65 |
if parameters is not None:
|
66 |
outputs = self.model.generate(input_ids, **parameters)
|
67 |
else:
|
68 |
-
outputs = self.model.generate(input_ids)
|
69 |
|
70 |
# postprocess the prediction
|
71 |
prediction = self.tokenizer.decode(outputs[0]) #, skip_special_tokens=True)
|
|
|
23 |
}
|
24 |
self.instruction = """Extract the name of the person, the location, the hotel name and the desired date from the following hotel request"""
|
25 |
|
26 |
+
if torch.cuda.is_available():
|
27 |
+
self.device = "cuda"
|
28 |
+
else:
|
29 |
+
self.device = "cpu"
|
30 |
+
|
31 |
def generate_prompt(
|
32 |
self,
|
33 |
template: str,
|
|
|
60 |
parameters = data.pop("parameters", None)
|
61 |
|
62 |
inputs = self.generate_prompt(self.template, self.instruction, inputs)
|
63 |
+
input_ids = input_ids.to(self.device)
|
64 |
# preprocess
|
65 |
self.tokenizer.pad_token_id = (
|
66 |
0 # unk. we want this to be different from the eos token
|
|
|
71 |
if parameters is not None:
|
72 |
outputs = self.model.generate(input_ids, **parameters)
|
73 |
else:
|
74 |
+
outputs = self.model.generate(input_ids, max_new_tokens=64)
|
75 |
|
76 |
# postprocess the prediction
|
77 |
prediction = self.tokenizer.decode(outputs[0]) #, skip_special_tokens=True)
|