kdevoe commited on
Commit
f848305
1 Parent(s): 12a9461

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +2 -1
inference.py CHANGED
@@ -31,7 +31,8 @@ def inference(input_text):
31
  input_ids = torch.Tensor(input['input_ids']).to(torch.device('cpu')).long()
32
  input_ids.resize_(1,len(input_ids))
33
  print(input_ids)
34
- mask = torch.Tensor(input['attention_mask'])
 
35
  output = model(input_ids, mask)
36
 
37
  return output.tolist()
 
31
  input_ids = torch.Tensor(input['input_ids']).to(torch.device('cpu')).long()
32
  input_ids.resize_(1,len(input_ids))
33
  print(input_ids)
34
+ mask = torch.Tensor(input['attention_mask']).to(torch.device('cpu'))
35
+ mask.resize_(1, len(mask))
36
  output = model(input_ids, mask)
37
 
38
  return output.tolist()