hengyu committed
Commit
f030635
1 Parent(s): 17b85b7

update model

Files changed (3):
  1. evaluation.ipynb +56 -12
  2. model.onnx +2 -2
  3. weights.pb +2 -2
evaluation.ipynb CHANGED
@@ -86,14 +86,18 @@
     "    input_ids = pad(input_ids, (0, pad_len), value=1)\n",
     "    ort_inputs = {\n",
     "        'input_ids': input_ids.detach().cpu().numpy(),\n",
-    "        'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')\n",
+    "        'attention_mask': torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1).detach().cpu().numpy().astype('int64')\n",
     "    }\n",
+    "    for i in range(28):\n",
+    "        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
+    "        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
     "    predictions = session.run(None, ort_inputs)\n",
     "    outputs = torch.from_numpy(predictions[0])\n",
     "    last_token_logits = outputs[:, -2 - pad_len, :]\n",
     "    pred = last_token_logits.argmax(dim=-1)\n",
     "    total += label.size(0)\n",
     "    hit += (pred == label).sum().item()\n",
+    "\n",
     "acc = hit / total\n",
     "print('acc: ', acc)"
    ]
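
Note: the inputs added in this hunk can be exercised on their own. Below is a minimal sketch, assuming the cache geometry shown in the diff (28 decoder layers, 16 heads, head dimension 256); the session path and the random token batch are placeholders for illustration, not part of the commit:

import numpy as np
import onnxruntime as ort
import torch

# Placeholder session and batch; the notebook builds these from the model
# repo and a tokenized dataset (both assumptions, not shown in this diff).
session = ort.InferenceSession("model.onnx")
input_ids = torch.randint(0, 32000, (1, 32))

# The re-exported graph takes an explicit KV cache, so the attention mask
# gains one position to cover a dummy, length-1 past.
mask = torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1)
ort_inputs = {
    'input_ids': input_ids.detach().cpu().numpy(),
    'attention_mask': mask.detach().cpu().numpy().astype('int64'),
}
for i in range(28):  # 28 layers; cache shape (batch, heads, past_len, head_dim)
    ort_inputs['past_key_values.{}.key'.format(i)] = np.zeros((1, 16, 1, 256), dtype='float32')
    ort_inputs['past_key_values.{}.value'.format(i)] = np.zeros((1, 16, 1, 256), dtype='float32')

outputs = session.run(None, ort_inputs)  # outputs[0] = logits, outputs[1:] = 28 (key, value) pairs
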
@@ -132,19 +136,59 @@
     "\n",
     "print(\"prompt: \", prompt)\n",
     "\n",
+    "total_time = 0.0\n",
+    "num_iter = 10\n",
+    "num_warmup = 3\n",
+    "\n",
     "# start\n",
-    "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
-    "for i in range(32):\n",
+    "for idx in range(num_iter):\n",
+    "    text = []\n",
+    "    tic = time.time()\n",
+    "\n",
+    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
+    "\n",
+    "    attention_mask = torch.ones(input_ids.shape[1] + 1)\n",
+    "    attention_mask[0] = 0\n",
+    "    attention_mask = attention_mask.unsqueeze(0)\n",
+    "\n",
     "    inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
-    "           'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')}\n",
-    "    output = session.run(None, inp)\n",
-    "    logits = output[0]\n",
-    "    logits = torch.from_numpy(logits)\n",
-    "    next_token_logits = logits[:, -1, :]\n",
-    "    probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
-    "    next_tokens = torch.argmax(probs, dim=-1)\n",
-    "    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
-    "print(tokenizer.decode(input_ids[0]))"
+    "           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}\n",
+    "    for i in range(28):\n",
+    "        inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
+    "        inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
+    "\n",
+    "    for step in range(32):\n",
+    "\n",
+    "        output = session.run(None, inp)\n",
+    "        logits = output[0]\n",
+    "        logits = torch.from_numpy(logits)\n",
+    "        next_token_logits = logits[:, -1, :]\n",
+    "        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
+    "        next_tokens = torch.argmax(probs, dim=-1)\n",
+    "        present_kv = output[1]\n",
+    "        for i in range(28):\n",
+    "\n",
+    "            if step == 0:\n",
+    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1][:, :, 1:, :]\n",
+    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2][:, :, 1:, :]\n",
+    "            else:\n",
+    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1]\n",
+    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2]\n",
+    "\n",
+    "        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
+    "        if step == 0:\n",
+    "            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)\n",
+    "        else:\n",
+    "            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)\n",
+    "\n",
+    "        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
+    "        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
+    "\n",
+    "    print(tokenizer.decode(input_ids[0]))\n",
+    "    toc = time.time()\n",
+    "    if idx >= num_warmup:\n",
+    "        total_time += (toc - tic)\n",
+    "print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
    ]
   }
  ],
 
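
For readability, the new decode loop can be restated outside the notebook as a single function. This is a sketch, not part of the commit: `session` and `tokenizer` are assumed to be an onnxruntime.InferenceSession over this model.onnx and its matching tokenizer, and the 28-layer (1, 16, seq, 256) cache layout is taken from the hunk above:

import torch

def greedy_generate(session, tokenizer, prompt, max_new_tokens=32):
    """Greedy decoding against the KV-cache ONNX export (sketch of the cell above)."""
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    # One extra mask slot covers the dummy length-1 past; its leading zero
    # keeps the model from attending to it until real cache entries exist.
    attention_mask = torch.ones(input_ids.shape[1] + 1)
    attention_mask[0] = 0
    attention_mask = attention_mask.unsqueeze(0)

    inp = {'input_ids': input_ids.detach().cpu().numpy(),
           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}
    for i in range(28):  # zero-filled dummy cache of past length 1
        inp['past_key_values.{}.key'.format(i)] = torch.zeros([1, 16, 1, 256]).numpy()
        inp['past_key_values.{}.value'.format(i)] = torch.zeros([1, 16, 1, 256]).numpy()

    for step in range(max_new_tokens):
        output = session.run(None, inp)
        logits = torch.from_numpy(output[0])
        next_tokens = logits[:, -1, :].argmax(dim=-1)

        # output[1:] holds the 28 present (key, value) pairs; on the first
        # step the dummy cache position is sliced away.
        for i in range(28):
            key, value = output[2 * i + 1], output[2 * i + 2]
            if step == 0:
                key, value = key[:, :, 1:, :], value[:, :, 1:, :]
            inp['past_key_values.{}.key'.format(i)] = key
            inp['past_key_values.{}.value'.format(i)] = value

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if step == 0:
            # Drop the masked dummy slot, then extend for the new token.
            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)
        else:
            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)
        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')
        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()

    return tokenizer.decode(input_ids[0])

The timing harness in the cell then amounts to wrapping a call like greedy_generate(session, tokenizer, prompt) between time.time() readings and discarding the first num_warmup iterations.
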
model.onnx CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9c1e2627bdfc69469e0bb412d24acffd611a686be3fdf788f1c077040f5e0f92
-size 6127447
+oid sha256:99af1fc6a93e6b02902f3f4c3fe32bf3d7bb4441406bee3bf0cbceaa5b9f64e3
+size 6332176
weights.pb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:938fa97d7d469cb2f373c92916a55e5bcfab1cff40bd878f6f789ccae240c655
-size 6790222720
+oid sha256:e9641d64847996acc53c7093cf4ff9c02443b9c4fd61699cb9ac00b86861c528
+size 6057661312