andrewrreed HF staff commited on
Commit
92f0779
1 Parent(s): ac70c4f

model to gpu

Browse files
Files changed (1) hide show
  1. handler.py +4 -0
handler.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from typing import Dict, List, Any
3
  from transformers import AutoTokenizer
4
  from gector import GECToR, predict, load_verb_dict
@@ -12,6 +13,9 @@ class EndpointHandler:
12
  os.path.join(path, "data/verb-form-vocab.txt")
13
  )
14
 
 
 
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """
17
  Process the input data and return the predicted results.
 
1
  import os
2
+ import torch
3
  from typing import Dict, List, Any
4
  from transformers import AutoTokenizer
5
  from gector import GECToR, predict, load_verb_dict
 
13
  os.path.join(path, "data/verb-form-vocab.txt")
14
  )
15
 
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ self.model.to(device)
18
+
19
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
20
  """
21
  Process the input data and return the predicted results.