cpi-connect
commited on
Commit
·
621df19
1
Parent(s):
8241ba7
Upload model
Browse files- event_arg_predict.py +5 -5
- event_nugget_predict.py +3 -3
- event_realis_predict.py +5 -5
event_arg_predict.py
CHANGED
@@ -3,11 +3,11 @@ from annotated_text import annotated_text
|
|
3 |
import torch
|
4 |
from torch.utils.data import DataLoader
|
5 |
|
6 |
-
from
|
7 |
-
from
|
8 |
-
from
|
9 |
|
10 |
-
from
|
11 |
import spacy
|
12 |
from transformers import AutoTokenizer
|
13 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
@@ -37,7 +37,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
37 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
39 |
|
40 |
-
from
|
41 |
model_nugget = ArgumentModel(num_classes=43)
|
42 |
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
|
43 |
model_nugget.eval()
|
|
|
3 |
import torch
|
4 |
from torch.utils.data import DataLoader
|
5 |
|
6 |
+
from .args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
|
7 |
+
from .nugget_model_utils import CustomRobertaWithPOS
|
8 |
+
from .utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list
|
9 |
|
10 |
+
from .event_nugget_predict import get_event_nuggets
|
11 |
import spacy
|
12 |
from transformers import AutoTokenizer
|
13 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
|
|
37 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
39 |
|
40 |
+
from .args_model_utils import CustomRobertaWithPOS as ArgumentModel
|
41 |
model_nugget = ArgumentModel(num_classes=43)
|
42 |
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
|
43 |
model_nugget.eval()
|
event_nugget_predict.py
CHANGED
@@ -3,9 +3,9 @@ from annotated_text import annotated_text
|
|
3 |
import torch
|
4 |
from torch import nn
|
5 |
from torch.utils.data import DataLoader
|
6 |
-
from
|
7 |
-
from
|
8 |
-
from
|
9 |
import spacy
|
10 |
from transformers import AutoTokenizer
|
11 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
|
|
3 |
import torch
|
4 |
from torch import nn
|
5 |
from torch.utils.data import DataLoader
|
6 |
+
from .nugget_model_utils import CustomRobertaWithPOS as NuggetModel
|
7 |
+
from .nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
|
8 |
+
from .utils import get_idxs_from_text, event_nugget_list
|
9 |
import spacy
|
10 |
from transformers import AutoTokenizer
|
11 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
event_realis_predict.py
CHANGED
@@ -3,12 +3,12 @@ import spacy
|
|
3 |
import torch
|
4 |
from torch.utils.data import DataLoader
|
5 |
from transformers import AutoTokenizer
|
6 |
-
from
|
7 |
import streamlit as st
|
8 |
from annotated_text import annotated_text
|
9 |
-
from
|
10 |
-
from
|
11 |
-
from
|
12 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
13 |
|
14 |
event_nugget_list = ['B-Phishing',
|
@@ -49,7 +49,7 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
|
|
49 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
50 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
51 |
|
52 |
-
from
|
53 |
model_realis = RealisModel(num_classes_realis=4)
|
54 |
model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
|
55 |
model_realis.eval()
|
|
|
3 |
import torch
|
4 |
from torch.utils.data import DataLoader
|
5 |
from transformers import AutoTokenizer
|
6 |
+
from .utils import get_idxs_from_text
|
7 |
import streamlit as st
|
8 |
from annotated_text import annotated_text
|
9 |
+
from .nugget_model_utils import CustomRobertaWithPOS
|
10 |
+
from .event_nugget_predict import get_event_nuggets
|
11 |
+
from .realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
|
12 |
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
13 |
|
14 |
event_nugget_list = ['B-Phishing',
|
|
|
49 |
model_checkpoint = "ehsanaghaei/SecureBERT"
|
50 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
51 |
|
52 |
+
from .realis_model_utils import CustomRobertaWithPOS as RealisModel
|
53 |
model_realis = RealisModel(num_classes_realis=4)
|
54 |
model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
|
55 |
model_realis.eval()
|