cpi-connect commited on
Commit
44d9bc9
·
1 Parent(s): 621df19

Upload model

Browse files
event_arg_predict.py CHANGED
@@ -39,7 +39,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=Tru
39
 
40
  from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
  model_nugget = ArgumentModel(num_classes=43)
42
- model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
43
  model_nugget.eval()
44
 
45
  """
 
39
 
40
  from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
  model_nugget = ArgumentModel(num_classes=43)
42
+ model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/argument_model_state_dict.pth", map_location=device))
43
  model_nugget.eval()
44
 
45
  """
event_nugget_predict.py CHANGED
@@ -35,7 +35,7 @@ model_checkpoint = "ehsanaghaei/SecureBERT"
35
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
36
 
37
  model_nugget = NuggetModel(num_classes = 11)
38
- model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/nugget_model_state_dict.pth", map_location=device))
39
  model_nugget.eval()
40
 
41
  """
 
35
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
36
 
37
  model_nugget = NuggetModel(num_classes = 11)
38
+ model_nugget.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/nugget_model_state_dict.pth", map_location=device))
39
  model_nugget.eval()
40
 
41
  """
event_realis_predict.py CHANGED
@@ -51,7 +51,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=Tru
51
 
52
  from .realis_model_utils import CustomRobertaWithPOS as RealisModel
53
  model_realis = RealisModel(num_classes_realis=4)
54
- model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
55
  model_realis.eval()
56
 
57
  """
 
51
 
52
  from .realis_model_utils import CustomRobertaWithPOS as RealisModel
53
  model_realis = RealisModel(num_classes_realis=4)
54
+ model_realis.load_state_dict(torch.load(f"{os.path.dirname(os.path.abspath(__file__))}/realis_model_state_dict.pth", map_location=device))
55
  model_realis.eval()
56
 
57
  """