ypesk commited on
Commit
bea7c94
·
verified ·
1 Parent(s): d1c0661

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +8 -3
tasks/text.py CHANGED
@@ -158,10 +158,15 @@ async def evaluate_text(request: TextEvaluationRequest):
158
  model.eval()
159
  predictions = []
160
  for batch in tqdm(test_dataloader):
161
-
162
  with torch.no_grad():
163
- logits = model(**batch)
164
-
 
 
 
 
 
165
  logits = logits.detach().cpu().numpy()
166
  predictions.extend(logits.argmax(1))
167
 
 
158
  model.eval()
159
  predictions = []
160
  for batch in tqdm(test_dataloader):
161
+
162
  with torch.no_grad():
163
+ if MODEL =="mlp":
164
+ b_texts = batch
165
+ logits = model(b_texts)
166
+ elif MODEL == "ct":
167
+ b_input_ids, b_input_mask, b_token_type_ids, b_labels = batch
168
+ logits = model(b_input_ids, b_token_type_ids, b_input_mask)
169
+
170
  logits = logits.detach().cpu().numpy()
171
  predictions.extend(logits.argmax(1))
172