update model
Files changed:
- evaluation.ipynb  +56 -12
- model.onnx  +2 -2
- weights.pb  +2 -2
evaluation.ipynb CHANGED
@@ -86,14 +86,18 @@
 "    input_ids = pad(input_ids, (0, pad_len), value=1)\n",
 "    ort_inputs = {\n",
 "        'input_ids': input_ids.detach().cpu().numpy(),\n",
-"        'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')\n",
+"        'attention_mask': torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1).detach().cpu().numpy().astype('int64')\n",
 "    }\n",
+"    for i in range(28):\n",
+"        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
+"        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
 "    predictions = session.run(None, ort_inputs)\n",
 "    outputs = torch.from_numpy(predictions[0])\n",
 "    last_token_logits = outputs[:, -2 - pad_len, :]\n",
 "    pred = last_token_logits.argmax(dim=-1)\n",
 "    total += label.size(0)\n",
 "    hit += (pred == label).sum().item()\n",
+"\n",
 "acc = hit / total\n",
 "print('acc: ', acc)"
 ]
@@ -132,19 +136,59 @@
 "\n",
 "print(\"prompt: \", prompt)\n",
 "\n",
+"total_time = 0.0\n",
+"num_iter = 10\n",
+"num_warmup = 3\n",
+"\n",
 "# start\n",
-"
-"
+"for idx in range(num_iter):\n",
+"    text = []\n",
+"    tic = time.time()\n",
+"\n",
+"    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
+"\n",
+"    attention_mask = torch.ones(input_ids.shape[1] + 1)\n",
+"    attention_mask[0] = 0\n",
+"    attention_mask = attention_mask.unsqueeze(0)\n",
+"\n",
 "    inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
-"           'attention_mask':
-"
-"
-"
-"
-"
-"
-"
-"
+"           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}\n",
+"    for i in range(28):\n",
+"        inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
+"        inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
+"\n",
+"    for step in range(32):\n",
+"\n",
+"        output = session.run(None, inp)\n",
+"        logits = output[0]\n",
+"        logits = torch.from_numpy(logits)\n",
+"        next_token_logits = logits[:, -1, :]\n",
+"        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
+"        next_tokens = torch.argmax(probs, dim=-1)\n",
+"        present_kv = output[1]\n",
+"        for i in range(28):\n",
+"\n",
+"            if step == 0:\n",
+"                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1][:, :, 1:, :]\n",
+"                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2][:, :, 1:, :]\n",
+"            else:\n",
+"                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1]\n",
+"                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2]\n",
+"\n",
+"        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
+"        if step == 0:\n",
+"            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)\n",
+"        else:\n",
+"            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)\n",
+"\n",
+"        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
+"        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
+"\n",
+"    print(tokenizer.decode(input_ids[0]))\n",
+"    toc = time.time()\n",
+"    if idx >= num_warmup:\n",
+"        total_time += (toc - tic)\n",
+"print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
 ]
 }
 ],
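The first hunk switches the evaluation cell to the KV-cache export: the attention mask gains one position for a length-1 dummy past, and zero-filled past_key_values.{i}.key / .value tensors are fed for every layer. The sketch below condenses that input construction; it is not part of the commit, and the layer count (28), head count (16), head size (256), and the model.onnx path are assumptions read off the diff above.

import numpy as np
import onnxruntime as ort
import torch

# Assumed from the diff: 28 layers, 16 heads, head size 256, model.onnx path.
session = ort.InferenceSession("model.onnx")

def build_ort_inputs(input_ids, num_layers=28, num_heads=16, head_dim=256):
    # The dummy past has sequence length 1, so the model sees one extra
    # position; extend the attention mask by one entry to match.
    ort_inputs = {
        'input_ids': input_ids.detach().cpu().numpy(),
        'attention_mask': torch.cat(
            [torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1
        ).detach().cpu().numpy().astype('int64'),
    }
    # Zero-filled cache per layer: (batch, heads, past_len, head_dim).
    for i in range(num_layers):
        zeros = np.zeros((1, num_heads, 1, head_dim), dtype='float32')
        ort_inputs['past_key_values.{}.key'.format(i)] = zeros
        ort_inputs['past_key_values.{}.value'.format(i)] = zeros
    return ort_inputs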
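The second hunk adds a timed greedy-decoding loop around the same session. The subtle part is the cache handoff: on the first step the dummy past position is sliced off both the returned present key/values and the attention mask, and from then on only the newest token is fed. A condensed sketch, assuming the output layout shown in the diff (output[0] is the logits; output[2*i+1] / output[2*i+2] are layer i's present key/value):

import numpy as np
import torch

def greedy_generate(session, tokenizer, prompt, max_new_tokens=32,
                    num_layers=28, num_heads=16, head_dim=256):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    # One extra mask position for the length-1 dummy past, masked out.
    attention_mask = torch.ones(input_ids.shape[1] + 1)
    attention_mask[0] = 0
    attention_mask = attention_mask.unsqueeze(0)

    inp = {'input_ids': input_ids.detach().cpu().numpy(),
           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}
    for i in range(num_layers):
        zeros = np.zeros((1, num_heads, 1, head_dim), dtype='float32')
        inp['past_key_values.{}.key'.format(i)] = zeros
        inp['past_key_values.{}.value'.format(i)] = zeros

    for step in range(max_new_tokens):
        output = session.run(None, inp)
        # Greedy pick from the last position's logits.
        next_tokens = torch.from_numpy(output[0])[:, -1, :].argmax(dim=-1)

        # The returned present KV becomes the next step's past; on step 0
        # drop the dummy position that padded the cache.
        for i in range(num_layers):
            key, value = output[2 * i + 1], output[2 * i + 2]
            if step == 0:
                key, value = key[:, :, 1:, :], value[:, :, 1:, :]
            inp['past_key_values.{}.key'.format(i)] = key
            inp['past_key_values.{}.value'.format(i)] = value

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if step == 0:
            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)
        else:
            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)
        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')
        # Only the newest token is fed; history lives in the cache.
        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()

    return tokenizer.decode(input_ids[0])

Feeding a single token per step keeps each forward pass cheap regardless of sequence length, apart from attention over the cache, which is the point of exporting the model with past_key_values inputs.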
model.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:99af1fc6a93e6b02902f3f4c3fe32bf3d7bb4441406bee3bf0cbceaa5b9f64e3
+size 6332176
weights.pb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e9641d64847996acc53c7093cf4ff9c02443b9c4fd61699cb9ac00b86861c528
+size 6057661312
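Both binary files are Git LFS pointers, so only their oid/size lines change here; the payloads are fetched separately (e.g. with git lfs pull). A quick integrity check against the digests recorded above, as a sketch assuming the fetched files sit in the working directory:

import hashlib

expected = {
    'model.onnx': '99af1fc6a93e6b02902f3f4c3fe32bf3d7bb4441406bee3bf0cbceaa5b9f64e3',
    'weights.pb': 'e9641d64847996acc53c7093cf4ff9c02443b9c4fd61699cb9ac00b86861c528',
}

for name, digest in expected.items():
    h = hashlib.sha256()
    with open(name, 'rb') as f:
        # Stream in 1 MiB chunks; weights.pb is ~6 GB per its pointer.
        for chunk in iter(lambda: f.read(1 << 20), b''):
            h.update(chunk)
    print(name, 'OK' if h.hexdigest() == digest else 'MISMATCH')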