cpi-connect
commited on
Commit
·
4e38daf
1
Parent(s):
d60b7fd
Upload 18 files
Browse files- .gitattributes +7 -0
- args_model_utils.py +210 -0
- argument_model_state_dict.pth +3 -0
- configuration.py +1 -0
- event_arg_predict.py +280 -0
- event_arg_role_dataloader.py +100 -0
- event_arg_role_predict.py +113 -0
- event_nugget_predict.py +250 -0
- event_realis_predict.py +270 -0
- model_59.pt +3 -0
- model_64_pos_ner.pt +3 -0
- model_66.pt +3 -0
- model_97.pt +3 -0
- nugget_model_state_dict.pth +3 -0
- nugget_model_utils.py +151 -0
- realis_model_state_dict.pth +3 -0
- realis_model_utils.py +146 -0
- utils.py +196 -0
.gitattributes
CHANGED
@@ -1 +1,8 @@
|
|
1 |
pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
2 |
+
argument_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
|
3 |
+
model_59.pt filter=lfs diff=lfs merge=lfs -text
|
4 |
+
model_64_pos_ner.pt filter=lfs diff=lfs merge=lfs -text
|
5 |
+
model_66.pt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
model_97.pt filter=lfs diff=lfs merge=lfs -text
|
7 |
+
nugget_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
|
8 |
+
realis_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
|
args_model_utils.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import spacy
|
3 |
+
import en_core_web_sm
|
4 |
+
from torch import nn
|
5 |
+
import math
|
6 |
+
|
7 |
+
|
8 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
9 |
+
|
10 |
+
from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
|
13 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
14 |
+
|
15 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
16 |
+
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
|
17 |
+
|
18 |
+
nlp = en_core_web_sm.load()
|
19 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
20 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
21 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
22 |
+
event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
|
23 |
+
arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]
|
24 |
+
|
25 |
+
class CustomRobertaWithPOS(nn.Module):
|
26 |
+
def __init__(self, num_classes):
|
27 |
+
super(CustomRobertaWithPOS, self).__init__()
|
28 |
+
self.num_classes = num_classes
|
29 |
+
|
30 |
+
self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
|
31 |
+
self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
|
32 |
+
self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
|
33 |
+
self.depth_embed = nn.Embedding(17, 8)
|
34 |
+
self.subtype_embed = nn.Embedding(len(event_nugget_tag_list), 2)
|
35 |
+
self.dist_embed = nn.Embedding(11, 6)
|
36 |
+
self.relative_pos_embed = nn.Embedding(len(arg_nugget_relative_pos_tag_list), 2)
|
37 |
+
|
38 |
+
self.roberta = roberta_model
|
39 |
+
self.dropout1 = nn.Dropout(0.2)
|
40 |
+
self.fc1 = nn.Linear(self.roberta.config.hidden_size + 50, num_classes)
|
41 |
+
|
42 |
+
def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, nearest_nugget_subtype, nearest_nugget_dist, arg_nugget_relative_pos):
|
43 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
44 |
+
last_hidden_output = outputs.last_hidden_state
|
45 |
+
|
46 |
+
pooler_output = outputs.pooler_output
|
47 |
+
pooler_output_unsqz = pooler_output.unsqueeze(1)
|
48 |
+
pooler_output_fin = pooler_output_unsqz.expand(-1, last_hidden_output.shape[1], -1)
|
49 |
+
|
50 |
+
|
51 |
+
pos_mask = pos_spacy != -100
|
52 |
+
pos_embed_masked = self.pos_embed(pos_spacy[pos_mask])
|
53 |
+
pos_embed = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], 16), dtype=torch.float).to(device)
|
54 |
+
pos_embed[pos_mask] = pos_embed_masked
|
55 |
+
|
56 |
+
ner_mask = ner_spacy != -100
|
57 |
+
ner_embed_masked = self.ner_embed(ner_spacy[ner_mask])
|
58 |
+
ner_embed = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], 8), dtype=torch.float).to(device)
|
59 |
+
ner_embed[ner_mask] = ner_embed_masked
|
60 |
+
|
61 |
+
dep_mask = dep_spacy != -100
|
62 |
+
dep_embed_masked = self.dep_embed(dep_spacy[dep_mask])
|
63 |
+
dep_embed = torch.zeros((dep_spacy.shape[0], dep_spacy.shape[1], 8), dtype=torch.float).to(device)
|
64 |
+
dep_embed[dep_mask] = dep_embed_masked
|
65 |
+
|
66 |
+
depth_mask = depth_spacy != -100
|
67 |
+
depth_embed_masked = self.depth_embed(depth_spacy[depth_mask])
|
68 |
+
depth_embed = torch.zeros((depth_spacy.shape[0], depth_spacy.shape[1], 8), dtype=torch.float).to(device)
|
69 |
+
depth_embed[dep_mask] = depth_embed_masked
|
70 |
+
|
71 |
+
nearest_nugget_subtype_mask = nearest_nugget_subtype != -100
|
72 |
+
nearest_nugget_subtype_embed_masked = self.subtype_embed(nearest_nugget_subtype[nearest_nugget_subtype_mask])
|
73 |
+
nearest_nugget_subtype_embed = torch.zeros((nearest_nugget_subtype.shape[0], nearest_nugget_subtype.shape[1], 2), dtype=torch.float).to(device)
|
74 |
+
nearest_nugget_subtype_embed[dep_mask] = nearest_nugget_subtype_embed_masked
|
75 |
+
|
76 |
+
nearest_nugget_dist_mask = nearest_nugget_dist != -100
|
77 |
+
nearest_nugget_dist_embed_masked = self.dist_embed(nearest_nugget_dist[nearest_nugget_dist_mask])
|
78 |
+
nearest_nugget_dist_embed = torch.zeros((nearest_nugget_dist.shape[0], nearest_nugget_dist.shape[1], 6), dtype=torch.float).to(device)
|
79 |
+
nearest_nugget_dist_embed[dep_mask] = nearest_nugget_dist_embed_masked
|
80 |
+
|
81 |
+
arg_nugget_relative_pos_mask = arg_nugget_relative_pos != -100
|
82 |
+
arg_nugget_relative_pos_embed_masked = self.relative_pos_embed(arg_nugget_relative_pos[arg_nugget_relative_pos_mask])
|
83 |
+
arg_nugget_relative_pos_embed = torch.zeros((arg_nugget_relative_pos.shape[0], arg_nugget_relative_pos.shape[1], 2), dtype=torch.float).to(device)
|
84 |
+
arg_nugget_relative_pos_embed[dep_mask] = arg_nugget_relative_pos_embed_masked
|
85 |
+
|
86 |
+
features_concat = torch.cat((last_hidden_output, pos_embed, ner_embed, dep_embed, depth_embed, nearest_nugget_subtype_embed, nearest_nugget_dist_embed, arg_nugget_relative_pos_embed), 2).to(device)
|
87 |
+
features_concat = self.dropout1(features_concat)
|
88 |
+
|
89 |
+
logits = self.fc1(features_concat)
|
90 |
+
|
91 |
+
return logits
|
92 |
+
|
93 |
+
|
94 |
+
def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens = True):
|
95 |
+
tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
|
96 |
+
#tokenized_inputs.pop('input_ids')
|
97 |
+
ner_spacy = []
|
98 |
+
pos_spacy = []
|
99 |
+
dep_spacy = []
|
100 |
+
depth_spacy = []
|
101 |
+
nearest_nugget_subtype = []
|
102 |
+
nearest_nugget_dist = []
|
103 |
+
arg_nugget_relative_pos = []
|
104 |
+
|
105 |
+
for i, (pos, ner, dep, depth, subtype, dist, relative_pos) in enumerate(zip(examples["pos_spacy"],
|
106 |
+
examples["ner_spacy"],
|
107 |
+
examples["dep_spacy"],
|
108 |
+
examples["depth_spacy"],
|
109 |
+
examples["nearest_nugget_subtype"],
|
110 |
+
examples["nearest_nugget_dist"],
|
111 |
+
examples["arg_nugget_relative_pos"])):
|
112 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
113 |
+
previous_word_idx = None
|
114 |
+
ner_spacy_ids = []
|
115 |
+
pos_spacy_ids = []
|
116 |
+
dep_spacy_ids = []
|
117 |
+
depth_spacy_ids = []
|
118 |
+
nearest_nugget_subtype_ids = []
|
119 |
+
nearest_nugget_dist_ids = []
|
120 |
+
arg_nugget_relative_pos_ids = []
|
121 |
+
|
122 |
+
for word_idx in word_ids:
|
123 |
+
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
|
124 |
+
# ignored in the loss function.
|
125 |
+
if word_idx is None:
|
126 |
+
ner_spacy_ids.append(-100)
|
127 |
+
pos_spacy_ids.append(-100)
|
128 |
+
dep_spacy_ids.append(-100)
|
129 |
+
depth_spacy_ids.append(-100)
|
130 |
+
nearest_nugget_subtype_ids.append(-100)
|
131 |
+
nearest_nugget_dist_ids.append(-100)
|
132 |
+
arg_nugget_relative_pos_ids.append(-100)
|
133 |
+
# We set the label for the first token of each word.
|
134 |
+
elif word_idx != previous_word_idx:
|
135 |
+
ner_spacy_ids.append(ner[word_idx])
|
136 |
+
pos_spacy_ids.append(pos[word_idx])
|
137 |
+
dep_spacy_ids.append(dep[word_idx])
|
138 |
+
depth_spacy_ids.append(depth[word_idx])
|
139 |
+
nearest_nugget_subtype_ids.append(subtype[word_idx])
|
140 |
+
nearest_nugget_dist_ids.append(dist[word_idx])
|
141 |
+
arg_nugget_relative_pos_ids.append(relative_pos[word_idx])
|
142 |
+
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
143 |
+
# the label_all_tokens flag.
|
144 |
+
else:
|
145 |
+
ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
|
146 |
+
pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
|
147 |
+
dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
|
148 |
+
depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
|
149 |
+
nearest_nugget_subtype_ids.append(subtype[word_idx] if label_all_tokens else -100)
|
150 |
+
nearest_nugget_dist_ids.append(dist[word_idx] if label_all_tokens else -100)
|
151 |
+
arg_nugget_relative_pos_ids.append(relative_pos[word_idx] if label_all_tokens else -100)
|
152 |
+
previous_word_idx = word_idx
|
153 |
+
|
154 |
+
ner_spacy.append(ner_spacy_ids)
|
155 |
+
pos_spacy.append(pos_spacy_ids)
|
156 |
+
dep_spacy.append(dep_spacy_ids)
|
157 |
+
depth_spacy.append(depth_spacy_ids)
|
158 |
+
nearest_nugget_subtype.append(nearest_nugget_subtype_ids)
|
159 |
+
nearest_nugget_dist.append(nearest_nugget_dist_ids)
|
160 |
+
arg_nugget_relative_pos.append(arg_nugget_relative_pos_ids)
|
161 |
+
|
162 |
+
tokenized_inputs["pos_spacy"] = pos_spacy
|
163 |
+
tokenized_inputs["ner_spacy"] = ner_spacy
|
164 |
+
tokenized_inputs["dep_spacy"] = dep_spacy
|
165 |
+
tokenized_inputs["depth_spacy"] = depth_spacy
|
166 |
+
tokenized_inputs["nearest_nugget_subtype"] = nearest_nugget_subtype
|
167 |
+
tokenized_inputs["nearest_nugget_dist"] = nearest_nugget_dist
|
168 |
+
tokenized_inputs["arg_nugget_relative_pos"] = arg_nugget_relative_pos
|
169 |
+
return tokenized_inputs
|
170 |
+
|
171 |
+
def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
|
172 |
+
nearest_subtype = None
|
173 |
+
nearest_dist = math.inf
|
174 |
+
relative_pos = None
|
175 |
+
|
176 |
+
mid_idx = (end_idx + start_idx) / 2
|
177 |
+
for nugget in event_nuggets:
|
178 |
+
mid_nugget_idx = (nugget["startOffset"] + nugget["endOffset"]) / 2
|
179 |
+
dist = abs(mid_nugget_idx - mid_idx)
|
180 |
+
|
181 |
+
if dist < nearest_dist:
|
182 |
+
nearest_dist = dist
|
183 |
+
nearest_subtype = nugget["subtype"]
|
184 |
+
for sent in doc.sents:
|
185 |
+
if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
|
186 |
+
if mid_idx < mid_nugget_idx:
|
187 |
+
relative_pos = "before-same-sentence"
|
188 |
+
else:
|
189 |
+
relative_pos = "after-same-sentence"
|
190 |
+
break
|
191 |
+
elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
|
192 |
+
relative_pos = "after-differ-sentence"
|
193 |
+
break
|
194 |
+
elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
|
195 |
+
relative_pos = "before-differ-sentence"
|
196 |
+
break
|
197 |
+
|
198 |
+
nearest_dist = int(min(10, nearest_dist // 20))
|
199 |
+
return nearest_subtype, nearest_dist, relative_pos
|
200 |
+
|
201 |
+
def find_dep_depth(token):
|
202 |
+
depth = 0
|
203 |
+
current_token = token
|
204 |
+
while current_token.head != current_token:
|
205 |
+
depth += 1
|
206 |
+
current_token = current_token.head
|
207 |
+
return min(depth, 16)
|
208 |
+
|
209 |
+
def between_idxs(idx, start_idx, end_idx):
|
210 |
+
return idx >= start_idx and idx <= end_idx
|
argument_model_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:185e22992430c80ec1eb1fca7f3ba4ebe801163c3ba13bed00abc6dc24072712
|
3 |
+
size 498813605
|
configuration.py
CHANGED
@@ -5,6 +5,7 @@ from cybersecurity_knowledge_graph.utils import event_args_list, event_nugget_li
|
|
5 |
|
6 |
|
7 |
class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
|
|
|
8 |
|
9 |
def __init__(
|
10 |
self,
|
|
|
5 |
|
6 |
|
7 |
class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
|
8 |
+
model_type = "cybersecurity_knowledge_graph"
|
9 |
|
10 |
def __init__(
|
11 |
self,
|
event_arg_predict.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from annotated_text import annotated_text
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
from cybersecurity_knowledge_graph.args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
|
7 |
+
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
|
8 |
+
from cybersecurity_knowledge_graph.utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list
|
9 |
+
|
10 |
+
from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
|
11 |
+
import spacy
|
12 |
+
from transformers import AutoTokenizer
|
13 |
+
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
14 |
+
import os
|
15 |
+
|
16 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
17 |
+
|
18 |
+
def find_dep_depth(token):
|
19 |
+
depth = 0
|
20 |
+
current_token = token
|
21 |
+
while current_token.head != current_token:
|
22 |
+
depth += 1
|
23 |
+
current_token = current_token.head
|
24 |
+
return min(depth, 16)
|
25 |
+
|
26 |
+
|
27 |
+
nlp = spacy.load('en_core_web_sm')
|
28 |
+
|
29 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
30 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
31 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
32 |
+
event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
|
33 |
+
arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]
|
34 |
+
|
35 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
36 |
+
|
37 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
39 |
+
|
40 |
+
from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
|
41 |
+
model_nugget = ArgumentModel(num_classes=43)
|
42 |
+
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
|
43 |
+
model_nugget.eval()
|
44 |
+
|
45 |
+
"""
|
46 |
+
Function: create_dataloader(text_input)
|
47 |
+
Description: This function creates a DataLoader for processing text data, tokenizes it, and organizes it into batches.
|
48 |
+
Inputs:
|
49 |
+
- text_input: The input text to be processed.
|
50 |
+
Output:
|
51 |
+
- dataloader: A DataLoader for the tokenized and batched text data.
|
52 |
+
- tokenized_dataset_ner: The tokenized dataset used for training.
|
53 |
+
"""
|
54 |
+
def create_dataloader(text_input):
|
55 |
+
|
56 |
+
event_nuggets = get_event_nuggets(text_input)
|
57 |
+
doc = nlp(text_input)
|
58 |
+
|
59 |
+
content_as_words_emdash = [tok.text for tok in doc]
|
60 |
+
content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
|
61 |
+
content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
|
62 |
+
|
63 |
+
data = []
|
64 |
+
|
65 |
+
words = []
|
66 |
+
arg_nugget_nearest_subtype = []
|
67 |
+
arg_nugget_nearest_dist = []
|
68 |
+
arg_nugget_relative_pos = []
|
69 |
+
|
70 |
+
pos_spacy = [tok.pos_ for tok in doc]
|
71 |
+
ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
|
72 |
+
dep_spacy = [tok.dep_ for tok in doc]
|
73 |
+
depth_spacy = [find_dep_depth(tok) for tok in doc]
|
74 |
+
|
75 |
+
for content_dict in content_idx_dict:
|
76 |
+
start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
|
77 |
+
nearest_subtype, nearest_dist, relative_pos = find_nearest_nugget_features(doc, content_dict["start_idx"], content_dict["end_idx"], event_nuggets)
|
78 |
+
words.append(content_dict["word"])
|
79 |
+
|
80 |
+
arg_nugget_nearest_subtype.append(nearest_subtype)
|
81 |
+
arg_nugget_nearest_dist.append(nearest_dist)
|
82 |
+
arg_nugget_relative_pos.append(relative_pos)
|
83 |
+
|
84 |
+
|
85 |
+
content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
|
86 |
+
if content_token_len > tokenizer.model_max_length:
|
87 |
+
no_split = (content_token_len // tokenizer.model_max_length) + 2
|
88 |
+
split_len = (len(words) // no_split) + 1
|
89 |
+
|
90 |
+
last_id = 0
|
91 |
+
threshold = split_len
|
92 |
+
|
93 |
+
for id, token in enumerate(words):
|
94 |
+
if token == "." and id > threshold:
|
95 |
+
data.append(
|
96 |
+
{
|
97 |
+
"tokens" : words[last_id : id + 1],
|
98 |
+
"pos_spacy" : pos_spacy[last_id : id + 1],
|
99 |
+
"ner_spacy" : ner_spacy[last_id : id + 1],
|
100 |
+
"dep_spacy" : dep_spacy[last_id : id + 1],
|
101 |
+
"depth_spacy" : depth_spacy[last_id : id + 1],
|
102 |
+
"nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : id + 1],
|
103 |
+
"nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : id + 1],
|
104 |
+
"arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : id + 1]
|
105 |
+
}
|
106 |
+
)
|
107 |
+
last_id = id + 1
|
108 |
+
threshold += split_len
|
109 |
+
data.append({"tokens" : words[last_id : ],
|
110 |
+
"pos_spacy" : pos_spacy[last_id : ],
|
111 |
+
"ner_spacy" : ner_spacy[last_id : ],
|
112 |
+
"dep_spacy" : dep_spacy[last_id : ],
|
113 |
+
"depth_spacy" : depth_spacy[last_id : ],
|
114 |
+
"nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : ],
|
115 |
+
"nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : ],
|
116 |
+
"arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : ]})
|
117 |
+
else:
|
118 |
+
data.append(
|
119 |
+
{
|
120 |
+
"tokens" : words,
|
121 |
+
"pos_spacy" : pos_spacy,
|
122 |
+
"ner_spacy" : ner_spacy,
|
123 |
+
"dep_spacy" : dep_spacy,
|
124 |
+
"depth_spacy" : depth_spacy,
|
125 |
+
"nearest_nugget_subtype" : arg_nugget_nearest_subtype,
|
126 |
+
"nearest_nugget_dist" : arg_nugget_nearest_dist,
|
127 |
+
"arg_nugget_relative_pos" : arg_nugget_relative_pos
|
128 |
+
}
|
129 |
+
)
|
130 |
+
|
131 |
+
|
132 |
+
ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
133 |
+
'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
134 |
+
'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
135 |
+
'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
136 |
+
'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None),
|
137 |
+
'nearest_nugget_subtype' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_tag_list), names=event_nugget_tag_list, names_file=None, id=None), length=-1, id=None),
|
138 |
+
'nearest_nugget_dist' : Sequence(feature=ClassLabel(num_classes=11, names=list(range(11)), names_file=None, id=None), length=-1, id=None),
|
139 |
+
'arg_nugget_relative_pos' : Sequence(feature=ClassLabel(num_classes=len(arg_nugget_relative_pos_tag_list), names=arg_nugget_relative_pos_tag_list, names_file=None, id=None), length=-1, id=None),
|
140 |
+
})
|
141 |
+
|
142 |
+
dataset = Dataset.from_list(data, features=ner_features)
|
143 |
+
tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
|
144 |
+
tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
|
145 |
+
|
146 |
+
tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
|
147 |
+
|
148 |
+
batch_size = 4 # Number of input texts
|
149 |
+
dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
|
150 |
+
return dataloader, tokenized_dataset_ner
|
151 |
+
|
152 |
+
"""
|
153 |
+
Function: predict(dataloader)
|
154 |
+
Description: This function performs prediction on a given dataloader using a trained model for label classification.
|
155 |
+
Inputs:
|
156 |
+
- dataloader: A DataLoader containing the input data for prediction.
|
157 |
+
Output:
|
158 |
+
- predicted_label: A tensor containing the predicted labels for each input in the dataloader.
|
159 |
+
"""
|
160 |
+
def predict(dataloader):
|
161 |
+
predicted_label = []
|
162 |
+
for batch in dataloader:
|
163 |
+
with torch.no_grad():
|
164 |
+
logits = model_nugget(**batch)
|
165 |
+
|
166 |
+
batch_predicted_label = logits.argmax(-1)
|
167 |
+
predicted_label.append(batch_predicted_label)
|
168 |
+
return torch.cat(predicted_label, dim=-1)
|
169 |
+
|
170 |
+
"""
|
171 |
+
Function: show_annotations(text_input)
|
172 |
+
Description: This function displays annotated event arguments in the provided input text.
|
173 |
+
Inputs:
|
174 |
+
- text_input: The input text containing event arguments to be annotated and displayed.
|
175 |
+
Output:
|
176 |
+
- An interactive display of annotated event arguments within the input text.
|
177 |
+
"""
|
178 |
+
def show_annotations(text_input):
|
179 |
+
st.title("Event Arguments")
|
180 |
+
|
181 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
182 |
+
predicted_label = predict(dataloader)
|
183 |
+
|
184 |
+
for idx, labels in enumerate(predicted_label):
|
185 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
186 |
+
|
187 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
188 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
189 |
+
|
190 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
191 |
+
idxs = get_idxs_from_text(text, tokens)
|
192 |
+
|
193 |
+
labels = labels[token_mask]
|
194 |
+
|
195 |
+
annotated_text_list = []
|
196 |
+
last_label = ""
|
197 |
+
cumulative_tokens = ""
|
198 |
+
last_id = 0
|
199 |
+
|
200 |
+
for idx, label in zip(idxs, labels):
|
201 |
+
to_label = event_args_list[label]
|
202 |
+
label_short = to_label.split("-")[1] if "-" in to_label else to_label
|
203 |
+
if last_label == label_short:
|
204 |
+
cumulative_tokens += text[last_id : idx["end_idx"]]
|
205 |
+
last_id = idx["end_idx"]
|
206 |
+
else:
|
207 |
+
if last_label != "":
|
208 |
+
if last_label == "O":
|
209 |
+
annotated_text_list.append(cumulative_tokens)
|
210 |
+
else:
|
211 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
212 |
+
last_label = label_short
|
213 |
+
cumulative_tokens = idx["word"]
|
214 |
+
last_id = idx["end_idx"]
|
215 |
+
if last_label == "O":
|
216 |
+
annotated_text_list.append(cumulative_tokens)
|
217 |
+
else:
|
218 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
219 |
+
|
220 |
+
annotated_text(annotated_text_list)
|
221 |
+
|
222 |
+
"""
|
223 |
+
Function: get_event_args(text_input)
|
224 |
+
Description: This function extracts predicted event arguments (event nuggets) from the provided input text.
|
225 |
+
Inputs:
|
226 |
+
- text_input: The input text containing event nuggets to be extracted.
|
227 |
+
Output:
|
228 |
+
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
|
229 |
+
subtype, and text content.
|
230 |
+
"""
|
231 |
+
def get_event_args(text_input):
|
232 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
233 |
+
predicted_label = predict(dataloader)
|
234 |
+
|
235 |
+
predicted_event_nuggets = []
|
236 |
+
text_length = 0
|
237 |
+
for idx, labels in enumerate(predicted_label):
|
238 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
239 |
+
|
240 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
241 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
242 |
+
|
243 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
244 |
+
idxs = get_idxs_from_text(text_input[text_length : ], tokens)
|
245 |
+
|
246 |
+
labels = labels[token_mask]
|
247 |
+
|
248 |
+
start_idx = 0
|
249 |
+
end_idx = 0
|
250 |
+
last_label = ""
|
251 |
+
|
252 |
+
for idx, label in zip(idxs, labels):
|
253 |
+
to_label = event_args_list[label]
|
254 |
+
if "-" in to_label:
|
255 |
+
label_split = to_label.split("-")[1]
|
256 |
+
else:
|
257 |
+
label_split = to_label
|
258 |
+
|
259 |
+
if label_split == last_label:
|
260 |
+
end_idx = idx["end_idx"]
|
261 |
+
else:
|
262 |
+
if text_input[start_idx : end_idx] != "" and last_label != "O":
|
263 |
+
predicted_event_nuggets.append(
|
264 |
+
{
|
265 |
+
"startOffset" : text_length + start_idx,
|
266 |
+
"endOffset" : text_length + end_idx,
|
267 |
+
"subtype" : last_label,
|
268 |
+
"text" : text_input[text_length + start_idx : text_length + end_idx]
|
269 |
+
}
|
270 |
+
)
|
271 |
+
start_idx = idx["start_idx"]
|
272 |
+
end_idx = idx["start_idx"] + len(idx["word"])
|
273 |
+
last_label = label_split
|
274 |
+
text_length += idx["end_idx"]
|
275 |
+
return predicted_event_nuggets
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
|
event_arg_role_dataloader.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, json
|
2 |
+
from cybersecurity_knowledge_graph.utils import get_content, get_event_args, get_event_nugget, get_idxs_from_text, get_args_entity_from_idx, find_dict_by_overlap
|
3 |
+
from tqdm import tqdm
|
4 |
+
import spacy
|
5 |
+
import jsonlines
|
6 |
+
from sklearn.model_selection import train_test_split
|
7 |
+
import math
|
8 |
+
from transformers import pipeline
|
9 |
+
from sentence_transformers import SentenceTransformer
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
|
13 |
+
|
14 |
+
pipe = pipeline("token-classification", model="CyberPeace-Institute/SecureBERT-NER")
|
15 |
+
|
16 |
+
nlp = spacy.load('en_core_web_sm')
|
17 |
+
|
18 |
+
"""
|
19 |
+
Class: EventArgumentRoleDataset
|
20 |
+
Description: This class represents a dataset for training and evaluating event argument role classifiers.
|
21 |
+
Attributes:
|
22 |
+
- path: The path to the folder containing JSON files with event data.
|
23 |
+
- tokenizer: A tokenizer for encoding text data.
|
24 |
+
- arg: The specific argument type (subtype) for which the dataset is created.
|
25 |
+
- data: A list to store data samples, each consisting of an embedding and a label.
|
26 |
+
- train_data, val_data, test_data: Lists to store the split training, validation, and test data samples.
|
27 |
+
- datapoint_id: An identifier for tracking data samples.
|
28 |
+
Methods:
|
29 |
+
- __len__(): Returns the total number of data samples in the dataset.
|
30 |
+
- __getitem__(index): Retrieves a data sample at a specified index.
|
31 |
+
- to_jsonlines(train_path, val_path, test_path): Writes the dataset to JSON files for train, validation, and test sets.
|
32 |
+
- train_val_test_split(): Splits the data into training and test sets.
|
33 |
+
- load_data(): Loads and preprocesses event data from JSON files, creating embeddings for argument-role classification.
|
34 |
+
"""
|
35 |
+
class EventArgumentRoleDataset():
|
36 |
+
def __init__(self, path, tokenizer, arg):
|
37 |
+
self.path = path
|
38 |
+
self.tokenizer = tokenizer
|
39 |
+
self.arg = arg
|
40 |
+
self.data = []
|
41 |
+
self.train_data, self.val_data, self.test_data = None, None, None
|
42 |
+
self.datapoint_id = 0
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self.data)
|
46 |
+
|
47 |
+
def __getitem__(self, index):
|
48 |
+
sample = self.data[index]
|
49 |
+
return sample
|
50 |
+
|
51 |
+
def to_jsonlines(self, train_path, val_path, test_path):
|
52 |
+
if self.train_data is None or self.test_data is None:
|
53 |
+
raise ValueError("Do the train-val-test split")
|
54 |
+
with jsonlines.open(train_path, "w") as f:
|
55 |
+
f.write_all(self.train_data)
|
56 |
+
# with jsonlines.open(val_path, "w") as f:
|
57 |
+
# f.write_all(self.val_data)
|
58 |
+
with jsonlines.open(test_path, "w") as f:
|
59 |
+
f.write_all(self.test_data)
|
60 |
+
|
61 |
+
def train_val_test_split(self):
|
62 |
+
self.train_data, self.test_data = train_test_split(self.data, test_size=0.1, random_state=42, shuffle=True)
|
63 |
+
# self.val_data, self.test_data = train_test_split(test_val, test_size=0.5, random_state=42, shuffle=True)
|
64 |
+
|
65 |
+
def load_data(self):
|
66 |
+
folder_path = self.path
|
67 |
+
json_files = [file for file in os.listdir(folder_path) if file.endswith('.json')]
|
68 |
+
|
69 |
+
# Load the nuggets
|
70 |
+
for idx, file_path in enumerate(tqdm(json_files)):
|
71 |
+
try:
|
72 |
+
with open(self.path + file_path, "r") as f:
|
73 |
+
file_json = json.load(f)
|
74 |
+
except:
|
75 |
+
print("Error in ", file_path)
|
76 |
+
content = get_content(file_json)
|
77 |
+
content = content.replace("\xa0", " ")
|
78 |
+
|
79 |
+
event_args = get_event_args(file_json)
|
80 |
+
doc = nlp(content)
|
81 |
+
|
82 |
+
sentence_indexes = []
|
83 |
+
for sent in doc.sents:
|
84 |
+
start_index = sent[0].idx
|
85 |
+
end_index = sent[-1].idx + len(sent[-1].text)
|
86 |
+
sentence_indexes.append((start_index, end_index))
|
87 |
+
|
88 |
+
for idx, (start, end) in enumerate(sentence_indexes):
|
89 |
+
sentence = content[start:end]
|
90 |
+
is_arg_sentence = [event_arg["startOffset"] >= start and event_arg["endOffset"] <= end for event_arg in event_args]
|
91 |
+
args = [event_args[idx] for idx, boolean in enumerate(is_arg_sentence) if boolean]
|
92 |
+
if args != []:
|
93 |
+
sentence_doc = nlp(sentence)
|
94 |
+
sentence_embed = embed_model.encode(sentence)
|
95 |
+
for arg in args:
|
96 |
+
if arg["type"] == self.arg:
|
97 |
+
arg_embed = embed_model.encode(arg["text"])
|
98 |
+
embedding = np.concatenate((sentence_embed, arg_embed))
|
99 |
+
|
100 |
+
self.data.append({"embedding" : embedding, "label" : arg["role"]["type"]})
|
event_arg_role_predict.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cybersecurity_knowledge_graph.event_arg_role_dataloader import EventArgumentRoleDataset
|
2 |
+
from cybersecurity_knowledge_graph.utils import arg_2_role
|
3 |
+
|
4 |
+
import os
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
import optuna
|
7 |
+
from sklearn.model_selection import StratifiedKFold
|
8 |
+
from sklearn.model_selection import cross_val_score
|
9 |
+
from sklearn.metrics import make_scorer, f1_score
|
10 |
+
from sklearn.ensemble import VotingClassifier
|
11 |
+
from sklearn.linear_model import LogisticRegression
|
12 |
+
from sklearn.neural_network import MLPClassifier
|
13 |
+
from sklearn.svm import SVC
|
14 |
+
from joblib import dump, load
|
15 |
+
from sentence_transformers import SentenceTransformer
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
embed_model = SentenceTransformer('all-MiniLM-L6-v2')
|
19 |
+
|
20 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
21 |
+
|
22 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
23 |
+
|
24 |
+
classifiers = {}
|
25 |
+
folder_path = '/cybersecurity_knowledge_graph/arg_role_models'
|
26 |
+
|
27 |
+
for filename in os.listdir(os.getcwd() + folder_path):
|
28 |
+
if filename.endswith('.joblib'):
|
29 |
+
file_path = os.getcwd() + os.path.join(folder_path, filename)
|
30 |
+
clf = load(file_path)
|
31 |
+
arg = filename.split(".")[0]
|
32 |
+
classifiers[arg] = clf
|
33 |
+
|
34 |
+
"""
|
35 |
+
Function: fit()
|
36 |
+
Description: This function performs a machine learning task to train and evaluate classifiers for multiple argument roles.
|
37 |
+
It utilizes Optuna for hyperparameter optimization and creates a Voting Classifier.
|
38 |
+
The trained classifiers are saved as joblib files.
|
39 |
+
"""
|
40 |
+
def fit():
|
41 |
+
for arg, roles in arg_2_role.items():
|
42 |
+
if len(roles) > 1:
|
43 |
+
|
44 |
+
dataset = EventArgumentRoleDataset(path="./data/annotation/", tokenizer=tokenizer, arg=arg)
|
45 |
+
dataset.load_data()
|
46 |
+
dataset.train_val_test_split()
|
47 |
+
|
48 |
+
|
49 |
+
X = [datapoint["embedding"] for datapoint in dataset.data]
|
50 |
+
y = [roles.index(datapoint["label"]) for datapoint in dataset.data]
|
51 |
+
|
52 |
+
|
53 |
+
# FYI: Objective functions can take additional arguments
|
54 |
+
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
|
55 |
+
def objective(trial):
|
56 |
+
|
57 |
+
classifier_name = trial.suggest_categorical("classifier", ["voting"])
|
58 |
+
if classifier_name == "voting":
|
59 |
+
svc_c = trial.suggest_float("svc_c", 1e-3, 1e3, log=True)
|
60 |
+
svc_kernel = trial.suggest_categorical("kernel", ['rbf'])
|
61 |
+
classifier_obj = VotingClassifier(estimators=[
|
62 |
+
('Logistic Regression', LogisticRegression()),
|
63 |
+
('Neural Network', MLPClassifier(max_iter=500)),
|
64 |
+
('Support Vector Machine', SVC(C=svc_c, kernel=svc_kernel))
|
65 |
+
], voting='hard')
|
66 |
+
|
67 |
+
f1_scorer = make_scorer(f1_score, average = "weighted")
|
68 |
+
stratified_kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
|
69 |
+
cv_scores = cross_val_score(classifier_obj, X, y, cv=stratified_kfold, scoring=f1_scorer)
|
70 |
+
return cv_scores.mean()
|
71 |
+
|
72 |
+
|
73 |
+
study = optuna.create_study(direction="maximize")
|
74 |
+
study.optimize(objective, n_trials=20)
|
75 |
+
print(f"{arg} : {study.best_trial.values[0]}")
|
76 |
+
|
77 |
+
best_clf = VotingClassifier(estimators=[
|
78 |
+
('Logistic Regression', LogisticRegression()),
|
79 |
+
('Neural Network', MLPClassifier(max_iter=500)),
|
80 |
+
('Support Vector Machine', SVC(C=study.best_trial.params["svc_c"], kernel=study.best_trial.params["kernel"]))
|
81 |
+
], voting='hard')
|
82 |
+
|
83 |
+
best_clf.fit(X, y)
|
84 |
+
dump(best_clf, f'{arg}.joblib')
|
85 |
+
|
86 |
+
"""
|
87 |
+
Function: get_arg_roles(event_args, doc)
|
88 |
+
Description: This function assigns argument roles to a list of event arguments within a document.
|
89 |
+
Inputs:
|
90 |
+
- event_args: A list of event argument dictionaries, each containing information about an argument.
|
91 |
+
- doc: A spaCy document representing the analyzed text.
|
92 |
+
Output:
|
93 |
+
- The input 'event_args' list with updated 'role' values assigned to each argument.
|
94 |
+
"""
|
95 |
+
def get_arg_roles(event_args, doc):
|
96 |
+
for arg in event_args:
|
97 |
+
if len(arg_2_role[arg["subtype"]]) > 1:
|
98 |
+
sent = next(filter(lambda x : arg["startOffset"] >= x.start_char and arg["endOffset"] <= x.end_char, doc.sents))
|
99 |
+
|
100 |
+
sent_embed = embed_model.encode(sent.text)
|
101 |
+
arg_embed = embed_model.encode(arg["text"])
|
102 |
+
embed = np.concatenate((sent_embed, arg_embed))
|
103 |
+
|
104 |
+
arg_clf = classifiers[arg["subtype"]]
|
105 |
+
role_id = arg_clf.predict(embed.reshape(1, -1))
|
106 |
+
role = arg_2_role[arg["subtype"]][role_id[0]]
|
107 |
+
|
108 |
+
arg["role"] = role
|
109 |
+
else:
|
110 |
+
arg["role"] = arg_2_role[arg["subtype"]][0]
|
111 |
+
return event_args
|
112 |
+
|
113 |
+
|
event_nugget_predict.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from annotated_text import annotated_text
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
|
7 |
+
from cybersecurity_knowledge_graph.nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
|
8 |
+
from cybersecurity_knowledge_graph.utils import get_idxs_from_text, event_nugget_list
|
9 |
+
import spacy
|
10 |
+
from transformers import AutoTokenizer
|
11 |
+
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
12 |
+
import os
|
13 |
+
|
14 |
+
|
15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
16 |
+
|
17 |
+
def find_dep_depth(token):
|
18 |
+
depth = 0
|
19 |
+
current_token = token
|
20 |
+
while current_token.head != current_token:
|
21 |
+
depth += 1
|
22 |
+
current_token = current_token.head
|
23 |
+
return min(depth, 16)
|
24 |
+
|
25 |
+
|
26 |
+
nlp = spacy.load('en_core_web_sm')
|
27 |
+
|
28 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
29 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
30 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
31 |
+
|
32 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
33 |
+
|
34 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
36 |
+
|
37 |
+
model_nugget = NuggetModel(num_classes = 11)
|
38 |
+
model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/nugget_model_state_dict.pth", map_location=device))
|
39 |
+
model_nugget.eval()
|
40 |
+
|
41 |
+
"""
|
42 |
+
Function: create_dataloader(text_input)
|
43 |
+
Description: This function prepares a DataLoader for processing text input, including tokenization and alignment of labels.
|
44 |
+
Inputs:
|
45 |
+
- text_input: The input text to be processed.
|
46 |
+
Output:
|
47 |
+
- dataloader: A DataLoader for the tokenized and batched text data.
|
48 |
+
- tokenized_dataset_ner: The tokenized dataset used for training.
|
49 |
+
"""
|
50 |
+
def create_dataloader(text_input):
|
51 |
+
|
52 |
+
doc = nlp(text_input)
|
53 |
+
|
54 |
+
content_as_words_emdash = [tok.text for tok in doc]
|
55 |
+
content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
|
56 |
+
content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
|
57 |
+
|
58 |
+
data = []
|
59 |
+
|
60 |
+
words = []
|
61 |
+
|
62 |
+
pos_spacy = [tok.pos_ for tok in doc]
|
63 |
+
ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
|
64 |
+
dep_spacy = [tok.dep_ for tok in doc]
|
65 |
+
depth_spacy = [find_dep_depth(tok) for tok in doc]
|
66 |
+
|
67 |
+
for content_dict in content_idx_dict:
|
68 |
+
start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
|
69 |
+
words.append(content_dict["word"])
|
70 |
+
|
71 |
+
|
72 |
+
content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
|
73 |
+
if content_token_len > tokenizer.model_max_length:
|
74 |
+
no_split = (content_token_len // tokenizer.model_max_length) + 2
|
75 |
+
split_len = (len(words) // no_split) + 1
|
76 |
+
|
77 |
+
last_id = 0
|
78 |
+
threshold = split_len
|
79 |
+
|
80 |
+
for id, token in enumerate(words):
|
81 |
+
if token == "." and id > threshold:
|
82 |
+
data.append(
|
83 |
+
{
|
84 |
+
"tokens" : words[last_id : id + 1],
|
85 |
+
"pos_spacy" : pos_spacy[last_id : id + 1],
|
86 |
+
"ner_spacy" : ner_spacy[last_id : id + 1],
|
87 |
+
"dep_spacy" : dep_spacy[last_id : id + 1],
|
88 |
+
"depth_spacy" : depth_spacy[last_id : id + 1],
|
89 |
+
}
|
90 |
+
)
|
91 |
+
last_id = id + 1
|
92 |
+
threshold += split_len
|
93 |
+
data.append({"tokens" : words[last_id : ],
|
94 |
+
"pos_spacy" : pos_spacy[last_id : ],
|
95 |
+
"ner_spacy" : ner_spacy[last_id : ],
|
96 |
+
"dep_spacy" : dep_spacy[last_id : ],
|
97 |
+
"depth_spacy" : depth_spacy[last_id : ]})
|
98 |
+
else:
|
99 |
+
data.append(
|
100 |
+
{
|
101 |
+
"tokens" : words,
|
102 |
+
"pos_spacy" : pos_spacy,
|
103 |
+
"ner_spacy" : ner_spacy,
|
104 |
+
"dep_spacy" : dep_spacy,
|
105 |
+
"depth_spacy" : depth_spacy
|
106 |
+
}
|
107 |
+
)
|
108 |
+
|
109 |
+
|
110 |
+
ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
111 |
+
'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
112 |
+
'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
113 |
+
'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
114 |
+
'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None)
|
115 |
+
})
|
116 |
+
|
117 |
+
dataset = Dataset.from_list(data, features=ner_features)
|
118 |
+
tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
|
119 |
+
tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
|
120 |
+
|
121 |
+
tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
|
122 |
+
|
123 |
+
batch_size = 4 # Number of input texts
|
124 |
+
dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
|
125 |
+
# TODO : context_idx_dict should be used to index the words
|
126 |
+
return dataloader, tokenized_dataset_ner
|
127 |
+
|
128 |
+
"""
|
129 |
+
Function: predict(dataloader)
|
130 |
+
Description: This function performs inference on a given DataLoader using a trained model and returns the predicted labels.
|
131 |
+
Inputs:
|
132 |
+
- dataloader: A DataLoader containing input data for prediction.
|
133 |
+
Output:
|
134 |
+
- predicted_label: A tensor containing the predicted labels for the input data.
|
135 |
+
"""
|
136 |
+
def predict(dataloader):
|
137 |
+
predicted_label = []
|
138 |
+
for batch in dataloader:
|
139 |
+
with torch.no_grad():
|
140 |
+
logits = model_nugget(**batch)
|
141 |
+
batch_predicted_label = logits.argmax(-1)
|
142 |
+
predicted_label.append(batch_predicted_label)
|
143 |
+
return torch.cat(predicted_label, dim=-1)
|
144 |
+
|
145 |
+
"""
|
146 |
+
Function: show_annotations(text_input)
|
147 |
+
Description: This function displays annotated event nuggets in the provided input text using the Streamlit library.
|
148 |
+
Inputs:
|
149 |
+
- text_input: The input text containing event nuggets to be annotated and displayed.
|
150 |
+
Output:
|
151 |
+
- An interactive display of annotated event nuggets within the input text.
|
152 |
+
"""
|
153 |
+
def show_annotations(text_input):
|
154 |
+
st.title("Event Nuggets")
|
155 |
+
|
156 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
157 |
+
predicted_label = predict(dataloader)
|
158 |
+
|
159 |
+
for idx, labels in enumerate(predicted_label):
|
160 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
161 |
+
|
162 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
163 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
164 |
+
|
165 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
166 |
+
idxs = get_idxs_from_text(text, tokens)
|
167 |
+
|
168 |
+
labels = labels[token_mask]
|
169 |
+
|
170 |
+
annotated_text_list = []
|
171 |
+
last_label = ""
|
172 |
+
cumulative_tokens = ""
|
173 |
+
last_id = 0
|
174 |
+
|
175 |
+
for idx, label in zip(idxs, labels):
|
176 |
+
to_label = event_nugget_list[label]
|
177 |
+
label_short = to_label.split("-")[1] if "-" in to_label else to_label
|
178 |
+
if last_label == label_short:
|
179 |
+
cumulative_tokens += text[last_id : idx["end_idx"]]
|
180 |
+
last_id = idx["end_idx"]
|
181 |
+
else:
|
182 |
+
if last_label != "":
|
183 |
+
if last_label == "O":
|
184 |
+
annotated_text_list.append(cumulative_tokens)
|
185 |
+
else:
|
186 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
187 |
+
last_label = label_short
|
188 |
+
cumulative_tokens = idx["word"]
|
189 |
+
last_id = idx["end_idx"]
|
190 |
+
if last_label == "O":
|
191 |
+
annotated_text_list.append(cumulative_tokens)
|
192 |
+
else:
|
193 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
194 |
+
annotated_text(annotated_text_list)
|
195 |
+
|
196 |
+
"""
|
197 |
+
Function: get_event_nuggets(text_input)
|
198 |
+
Description: This function extracts predicted event nuggets (event entities) from the provided input text.
|
199 |
+
Inputs:
|
200 |
+
- text_input: The input text containing event nuggets to be extracted.
|
201 |
+
Output:
|
202 |
+
- predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
|
203 |
+
subtype, and text content.
|
204 |
+
"""
|
205 |
+
def get_event_nuggets(text_input):
|
206 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
207 |
+
predicted_label = predict(dataloader)
|
208 |
+
|
209 |
+
predicted_event_nuggets = []
|
210 |
+
text_length = 0
|
211 |
+
for idx, labels in enumerate(predicted_label):
|
212 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
213 |
+
|
214 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
215 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
216 |
+
|
217 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
218 |
+
idxs = get_idxs_from_text(text_input[text_length : ], tokens)
|
219 |
+
|
220 |
+
labels = labels[token_mask]
|
221 |
+
|
222 |
+
start_idx = 0
|
223 |
+
end_idx = 0
|
224 |
+
last_label = ""
|
225 |
+
|
226 |
+
for idx, label in zip(idxs, labels):
|
227 |
+
to_label = event_nugget_list[label]
|
228 |
+
label_short = to_label.split("-")[1] if "-" in to_label else to_label
|
229 |
+
|
230 |
+
if label_short == last_label:
|
231 |
+
end_idx = idx["end_idx"]
|
232 |
+
else:
|
233 |
+
if text_input[start_idx : end_idx] != "" and last_label != "O":
|
234 |
+
predicted_event_nuggets.append(
|
235 |
+
{
|
236 |
+
"startOffset" : text_length + start_idx,
|
237 |
+
"endOffset" : text_length + end_idx,
|
238 |
+
"subtype" : last_label,
|
239 |
+
"text" : text_input[text_length + start_idx : text_length + end_idx]
|
240 |
+
}
|
241 |
+
)
|
242 |
+
start_idx = idx["start_idx"]
|
243 |
+
end_idx = idx["start_idx"] + len(idx["word"])
|
244 |
+
last_label = label_short
|
245 |
+
|
246 |
+
text_length += idx["end_idx"]
|
247 |
+
return predicted_event_nuggets
|
248 |
+
|
249 |
+
|
250 |
+
|
event_realis_predict.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import spacy
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
from cybersecurity_knowledge_graph.utils import get_idxs_from_text
|
7 |
+
import streamlit as st
|
8 |
+
from annotated_text import annotated_text
|
9 |
+
from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
|
10 |
+
from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
|
11 |
+
from cybersecurity_knowledge_graph.realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
|
12 |
+
from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
|
13 |
+
|
14 |
+
event_nugget_list = ['B-Phishing',
|
15 |
+
'I-Phishing',
|
16 |
+
'O',
|
17 |
+
'B-DiscoverVulnerability',
|
18 |
+
'B-Ransom',
|
19 |
+
'I-Ransom',
|
20 |
+
'B-Databreach',
|
21 |
+
'I-DiscoverVulnerability',
|
22 |
+
'B-PatchVulnerability',
|
23 |
+
'I-PatchVulnerability',
|
24 |
+
'I-Databreach']
|
25 |
+
|
26 |
+
realis_list = ["O", "Generic", "Other", "Actual"]
|
27 |
+
|
28 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
def find_dep_depth(token):
|
33 |
+
depth = 0
|
34 |
+
current_token = token
|
35 |
+
while current_token.head != current_token:
|
36 |
+
depth += 1
|
37 |
+
current_token = current_token.head
|
38 |
+
return min(depth, 16)
|
39 |
+
|
40 |
+
|
41 |
+
nlp = spacy.load('en_core_web_sm')
|
42 |
+
|
43 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
44 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
45 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
46 |
+
|
47 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
48 |
+
|
49 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
51 |
+
|
52 |
+
from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
|
53 |
+
model_realis = RealisModel(num_classes_realis=4)
|
54 |
+
model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
|
55 |
+
model_realis.eval()
|
56 |
+
|
57 |
+
"""
|
58 |
+
Function: create_dataloader(text_input)
|
59 |
+
Description: This function prepares a DataLoader for processing text input, including tokenization and alignment of labels.
|
60 |
+
Inputs:
|
61 |
+
- text_input: The input text to be processed.
|
62 |
+
Output:
|
63 |
+
- dataloader: A DataLoader for the tokenized and batched text data.
|
64 |
+
- tokenized_dataset_ner: The tokenized dataset used for training.
|
65 |
+
"""
|
66 |
+
def create_dataloader(text_input):
|
67 |
+
|
68 |
+
event_nuggets = get_event_nuggets(text_input)
|
69 |
+
doc = nlp(text_input)
|
70 |
+
|
71 |
+
content_as_words_emdash = [tok.text for tok in doc]
|
72 |
+
content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
|
73 |
+
content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
|
74 |
+
|
75 |
+
data = []
|
76 |
+
|
77 |
+
words = []
|
78 |
+
nugget_ner_tags = []
|
79 |
+
|
80 |
+
pos_spacy = [tok.pos_ for tok in doc]
|
81 |
+
ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
|
82 |
+
dep_spacy = [tok.dep_ for tok in doc]
|
83 |
+
depth_spacy = [find_dep_depth(tok) for tok in doc]
|
84 |
+
|
85 |
+
for content_dict in content_idx_dict:
|
86 |
+
start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
|
87 |
+
entity = get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets)
|
88 |
+
words.append(content_dict["word"])
|
89 |
+
nugget_ner_tags.append(entity)
|
90 |
+
|
91 |
+
|
92 |
+
content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
|
93 |
+
if content_token_len > tokenizer.model_max_length:
|
94 |
+
no_split = (content_token_len // tokenizer.model_max_length) + 2
|
95 |
+
split_len = (len(words) // no_split) + 1
|
96 |
+
|
97 |
+
last_id = 0
|
98 |
+
threshold = split_len
|
99 |
+
|
100 |
+
for id, token in enumerate(words):
|
101 |
+
if token == "." and id > threshold:
|
102 |
+
data.append(
|
103 |
+
{
|
104 |
+
"tokens" : words[last_id : id + 1],
|
105 |
+
"ner_tags" : nugget_ner_tags[last_id : id + 1],
|
106 |
+
"pos_spacy" : pos_spacy[last_id : id + 1],
|
107 |
+
"ner_spacy" : ner_spacy[last_id : id + 1],
|
108 |
+
"dep_spacy" : dep_spacy[last_id : id + 1],
|
109 |
+
"depth_spacy" : depth_spacy[last_id : id + 1],
|
110 |
+
}
|
111 |
+
)
|
112 |
+
last_id = id + 1
|
113 |
+
threshold += split_len
|
114 |
+
data.append({"tokens" : words[last_id : ],
|
115 |
+
"ner_tags" : nugget_ner_tags[last_id : ],
|
116 |
+
"pos_spacy" : pos_spacy[last_id : ],
|
117 |
+
"ner_spacy" : ner_spacy[last_id : ],
|
118 |
+
"dep_spacy" : dep_spacy[last_id : ],
|
119 |
+
"depth_spacy" : depth_spacy[last_id : ]})
|
120 |
+
else:
|
121 |
+
data.append(
|
122 |
+
{
|
123 |
+
"tokens" : words,
|
124 |
+
"ner_tags" : nugget_ner_tags,
|
125 |
+
"pos_spacy" : pos_spacy,
|
126 |
+
"ner_spacy" : ner_spacy,
|
127 |
+
"dep_spacy" : dep_spacy,
|
128 |
+
"depth_spacy" : depth_spacy
|
129 |
+
}
|
130 |
+
)
|
131 |
+
|
132 |
+
|
133 |
+
ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
134 |
+
'ner_tags' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_list), names=event_nugget_list, names_file=None, id=None), length=-1, id=None),
|
135 |
+
'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
136 |
+
'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
137 |
+
'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
|
138 |
+
'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None)
|
139 |
+
})
|
140 |
+
|
141 |
+
dataset = Dataset.from_list(data, features=ner_features)
|
142 |
+
tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_realis, fn_kwargs={'tokenizer' : tokenizer, 'ner_names' : event_nugget_list}, batched=True, load_from_cache_file=False)
|
143 |
+
tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
|
144 |
+
|
145 |
+
tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
|
146 |
+
|
147 |
+
batch_size = 4 # Number of input texts
|
148 |
+
dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
|
149 |
+
return dataloader, tokenized_dataset_ner
|
150 |
+
|
151 |
+
"""
|
152 |
+
Function: predict(dataloader)
|
153 |
+
Description: This function performs inference on a given DataLoader using a trained model and returns the predicted labels.
|
154 |
+
Inputs:
|
155 |
+
- dataloader: A DataLoader containing input data for prediction.
|
156 |
+
Output:
|
157 |
+
- predicted_label: A tensor containing the predicted labels for the input data.
|
158 |
+
"""
|
159 |
+
def predict(dataloader):
|
160 |
+
predicted_label = []
|
161 |
+
for batch in dataloader:
|
162 |
+
with torch.no_grad():
|
163 |
+
logits = model_realis(**batch)
|
164 |
+
|
165 |
+
batch_predicted_label = logits.argmax(-1)
|
166 |
+
predicted_label.append(batch_predicted_label)
|
167 |
+
return torch.cat(predicted_label, dim=-1)
|
168 |
+
|
169 |
+
"""
|
170 |
+
Function: show_annotations(text_input)
|
171 |
+
Description: This function displays annotated event nuggets in the provided input text using the Streamlit library.
|
172 |
+
Inputs:
|
173 |
+
- text_input: The input text containing event nuggets to be annotated and displayed.
|
174 |
+
Output:
|
175 |
+
- An interactive display of annotated event nuggets within the input text.
|
176 |
+
"""
|
177 |
+
def show_annotations(text_input):
|
178 |
+
st.title("Event Realis")
|
179 |
+
|
180 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
181 |
+
predicted_label = predict(dataloader)
|
182 |
+
|
183 |
+
for idx, labels in enumerate(predicted_label):
|
184 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
185 |
+
|
186 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
187 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
188 |
+
|
189 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
190 |
+
idxs = get_idxs_from_text(text, tokens)
|
191 |
+
|
192 |
+
labels = labels[token_mask]
|
193 |
+
|
194 |
+
annotated_text_list = []
|
195 |
+
last_label = ""
|
196 |
+
cumulative_tokens = ""
|
197 |
+
last_id = 0
|
198 |
+
|
199 |
+
for idx, label in zip(idxs, labels):
|
200 |
+
to_label = realis_list[label]
|
201 |
+
label_short = to_label.split("-")[1] if "-" in to_label else to_label
|
202 |
+
if last_label == label_short:
|
203 |
+
cumulative_tokens += text[last_id : idx["end_idx"]]
|
204 |
+
last_id = idx["end_idx"]
|
205 |
+
else:
|
206 |
+
if last_label != "":
|
207 |
+
if last_label == "O":
|
208 |
+
annotated_text_list.append(cumulative_tokens)
|
209 |
+
else:
|
210 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
211 |
+
last_label = label_short
|
212 |
+
cumulative_tokens = idx["word"]
|
213 |
+
last_id = idx["end_idx"]
|
214 |
+
if last_label == "O":
|
215 |
+
annotated_text_list.append(cumulative_tokens)
|
216 |
+
else:
|
217 |
+
annotated_text_list.append((cumulative_tokens, last_label))
|
218 |
+
annotated_text(annotated_text_list)
|
219 |
+
|
220 |
+
"""
|
221 |
+
Function: get_event_realis(text_input)
|
222 |
+
Description: This function extracts predicted event realis (event modality) from the provided input text.
|
223 |
+
Inputs:
|
224 |
+
- text_input: The input text containing event realis to be extracted.
|
225 |
+
Output:
|
226 |
+
- predicted_event_realis: A list of dictionaries, each representing an extracted event realis with start and end offsets,
|
227 |
+
realis type, and text content.
|
228 |
+
"""
|
229 |
+
def get_event_realis(text_input):
|
230 |
+
dataloader, tokenized_dataset_ner = create_dataloader(text_input)
|
231 |
+
predicted_label = predict(dataloader)
|
232 |
+
|
233 |
+
predicted_event_realis = []
|
234 |
+
text_length = 0
|
235 |
+
for idx, labels in enumerate(predicted_label):
|
236 |
+
token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
|
237 |
+
|
238 |
+
tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
|
239 |
+
tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
|
240 |
+
|
241 |
+
text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
|
242 |
+
idxs = get_idxs_from_text(text_input[text_length : ], tokens)
|
243 |
+
|
244 |
+
labels = labels[token_mask]
|
245 |
+
|
246 |
+
start_idx = 0
|
247 |
+
end_idx = 0
|
248 |
+
last_label = ""
|
249 |
+
|
250 |
+
for idx, label in zip(idxs, labels):
|
251 |
+
to_label = realis_list[label]
|
252 |
+
label_split = to_label
|
253 |
+
|
254 |
+
if label_split == last_label:
|
255 |
+
end_idx = idx["end_idx"]
|
256 |
+
else:
|
257 |
+
if text_input[start_idx : end_idx] != "" and last_label != "O":
|
258 |
+
predicted_event_realis.append(
|
259 |
+
{
|
260 |
+
"startOffset" : text_length + start_idx,
|
261 |
+
"endOffset" : text_length + end_idx,
|
262 |
+
"realis" : last_label,
|
263 |
+
"text" : text_input[text_length + start_idx : text_length + end_idx]
|
264 |
+
}
|
265 |
+
)
|
266 |
+
start_idx = idx["start_idx"]
|
267 |
+
end_idx = idx["start_idx"] + len(idx["word"])
|
268 |
+
last_label = label_split
|
269 |
+
text_length += idx["end_idx"]
|
270 |
+
return predicted_event_realis
|
model_59.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:09bc24b422adbe6c4c6ca1333a3a8c33146e6152e00a7ad6376cab616b51e53f
|
3 |
+
size 498858353
|
model_64_pos_ner.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:76125c5bbce2c32e536fe74d24dc51fb1fce3ba076104b459ee290102ce4bd5d
|
3 |
+
size 498746934
|
model_66.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46531e8ccf92661a025b15c829be791f72416d1b458ae1aa82cc66e069193bf5
|
3 |
+
size 498751092
|
model_97.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6147a98aa2baaa545903103e9e2f0e55fc249ec638cfe27e273ffdd247479c4
|
3 |
+
size 498729523
|
nugget_model_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8d04c7ccd654b3af96c1c8e0f391a20d79ae1b5970d5419680f379c6a09e78bf
|
3 |
+
size 498703483
|
nugget_model_utils.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import spacy
|
3 |
+
import en_core_web_sm
|
4 |
+
from torch import nn
|
5 |
+
import math
|
6 |
+
|
7 |
+
|
8 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
9 |
+
|
10 |
+
from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
|
13 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
14 |
+
|
15 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
16 |
+
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
|
17 |
+
|
18 |
+
nlp = en_core_web_sm.load()
|
19 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
20 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
21 |
+
|
22 |
+
|
23 |
+
class CustomRobertaWithPOS(nn.Module):
|
24 |
+
def __init__(self, num_classes):
|
25 |
+
super(CustomRobertaWithPOS, self).__init__()
|
26 |
+
self.num_classes = num_classes
|
27 |
+
self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
|
28 |
+
self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 16)
|
29 |
+
self.roberta = roberta_model
|
30 |
+
self.dropout1 = nn.Dropout(0.2)
|
31 |
+
self.fc1 = nn.Linear(self.roberta.config.hidden_size, num_classes)
|
32 |
+
|
33 |
+
def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy):
|
34 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
35 |
+
last_hidden_output = outputs.last_hidden_state
|
36 |
+
|
37 |
+
pos_mask = pos_spacy != -100
|
38 |
+
|
39 |
+
pos_one_hot = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], len(pos_spacy_tag_list)), dtype=torch.long)
|
40 |
+
pos_one_hot[pos_mask, pos_spacy[pos_mask]] = 1
|
41 |
+
pos_one_hot = pos_one_hot.to(device)
|
42 |
+
|
43 |
+
ner_mask = ner_spacy != -100
|
44 |
+
|
45 |
+
ner_one_hot = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], len(ner_spacy_tag_list)), dtype=torch.long)
|
46 |
+
ner_one_hot[ner_mask, ner_spacy[ner_mask]] = 1
|
47 |
+
ner_one_hot = ner_one_hot.to(device)
|
48 |
+
|
49 |
+
features_concat = last_hidden_output
|
50 |
+
features_concat = self.dropout1(features_concat)
|
51 |
+
|
52 |
+
logits = self.fc1(features_concat)
|
53 |
+
|
54 |
+
return logits
|
55 |
+
|
56 |
+
|
57 |
+
def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens = True):
|
58 |
+
tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
|
59 |
+
#tokenized_inputs.pop('input_ids')
|
60 |
+
ner_spacy = []
|
61 |
+
pos_spacy = []
|
62 |
+
dep_spacy = []
|
63 |
+
depth_spacy = []
|
64 |
+
|
65 |
+
for i, (pos, ner, dep, depth) in enumerate(zip(examples["pos_spacy"],
|
66 |
+
examples["ner_spacy"],
|
67 |
+
examples["dep_spacy"],
|
68 |
+
examples["depth_spacy"])):
|
69 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
70 |
+
previous_word_idx = None
|
71 |
+
ner_spacy_ids = []
|
72 |
+
pos_spacy_ids = []
|
73 |
+
dep_spacy_ids = []
|
74 |
+
depth_spacy_ids = []
|
75 |
+
|
76 |
+
for word_idx in word_ids:
|
77 |
+
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
|
78 |
+
# ignored in the loss function.
|
79 |
+
if word_idx is None:
|
80 |
+
ner_spacy_ids.append(-100)
|
81 |
+
pos_spacy_ids.append(-100)
|
82 |
+
dep_spacy_ids.append(-100)
|
83 |
+
depth_spacy_ids.append(-100)
|
84 |
+
# We set the label for the first token of each word.
|
85 |
+
elif word_idx != previous_word_idx:
|
86 |
+
ner_spacy_ids.append(ner[word_idx])
|
87 |
+
pos_spacy_ids.append(pos[word_idx])
|
88 |
+
dep_spacy_ids.append(dep[word_idx])
|
89 |
+
depth_spacy_ids.append(depth[word_idx])
|
90 |
+
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
91 |
+
# the label_all_tokens flag.
|
92 |
+
else:
|
93 |
+
ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
|
94 |
+
pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
|
95 |
+
dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
|
96 |
+
depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
|
97 |
+
previous_word_idx = word_idx
|
98 |
+
|
99 |
+
ner_spacy.append(ner_spacy_ids)
|
100 |
+
pos_spacy.append(pos_spacy_ids)
|
101 |
+
dep_spacy.append(dep_spacy_ids)
|
102 |
+
depth_spacy.append(depth_spacy_ids)
|
103 |
+
|
104 |
+
tokenized_inputs["pos_spacy"] = pos_spacy
|
105 |
+
tokenized_inputs["ner_spacy"] = ner_spacy
|
106 |
+
tokenized_inputs["dep_spacy"] = dep_spacy
|
107 |
+
tokenized_inputs["depth_spacy"] = depth_spacy
|
108 |
+
|
109 |
+
return tokenized_inputs
|
110 |
+
|
111 |
+
|
112 |
+
def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
|
113 |
+
nearest_subtype = None
|
114 |
+
nearest_dist = math.inf
|
115 |
+
relative_pos = None
|
116 |
+
|
117 |
+
mid_idx = (end_idx + start_idx) / 2
|
118 |
+
for nugget in event_nuggets:
|
119 |
+
mid_nugget_idx = (nugget["nugget"]["startOffset"] + nugget["nugget"]["endOffset"]) / 2
|
120 |
+
dist = abs(mid_nugget_idx - mid_idx)
|
121 |
+
|
122 |
+
if dist < nearest_dist:
|
123 |
+
nearest_dist = dist
|
124 |
+
nearest_subtype = nugget["subtype"]
|
125 |
+
for sent in doc.sents:
|
126 |
+
if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
|
127 |
+
if mid_idx < mid_nugget_idx:
|
128 |
+
relative_pos = "before-same-sentence"
|
129 |
+
else:
|
130 |
+
relative_pos = "after-same-sentence"
|
131 |
+
break
|
132 |
+
elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
|
133 |
+
relative_pos = "after-differ-sentence"
|
134 |
+
break
|
135 |
+
elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
|
136 |
+
relative_pos = "before-differ-sentence"
|
137 |
+
break
|
138 |
+
|
139 |
+
nearest_dist = int(min(10, nearest_dist // 20))
|
140 |
+
return nearest_subtype, nearest_dist, relative_pos
|
141 |
+
|
142 |
+
def find_dep_depth(token):
|
143 |
+
depth = 0
|
144 |
+
current_token = token
|
145 |
+
while current_token.head != current_token:
|
146 |
+
depth += 1
|
147 |
+
current_token = current_token.head
|
148 |
+
return min(depth, 16)
|
149 |
+
|
150 |
+
def between_idxs(idx, start_idx, end_idx):
|
151 |
+
return idx >= start_idx and idx <= end_idx
|
realis_model_state_dict.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ad63eeee95888dc6f22e94e0a8425a99912f7d727cd255881e8630218a3b7f0
|
3 |
+
size 498684837
|
realis_model_utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import en_core_web_sm
|
4 |
+
from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
|
7 |
+
model_checkpoint = "ehsanaghaei/SecureBERT"
|
8 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
9 |
+
|
10 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
|
11 |
+
roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
|
12 |
+
|
13 |
+
event_nugget_list = ['B-Phishing',
|
14 |
+
'I-Phishing',
|
15 |
+
'O',
|
16 |
+
'B-DiscoverVulnerability',
|
17 |
+
'B-Ransom',
|
18 |
+
'I-Ransom',
|
19 |
+
'B-Databreach',
|
20 |
+
'I-DiscoverVulnerability',
|
21 |
+
'B-PatchVulnerability',
|
22 |
+
'I-PatchVulnerability',
|
23 |
+
'I-Databreach']
|
24 |
+
|
25 |
+
nlp = en_core_web_sm.load()
|
26 |
+
pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
|
27 |
+
ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
|
28 |
+
dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
|
29 |
+
|
30 |
+
class CustomRobertaWithPOS(nn.Module):
|
31 |
+
def __init__(self, num_classes_realis):
|
32 |
+
super(CustomRobertaWithPOS, self).__init__()
|
33 |
+
self.num_classes_realis = num_classes_realis
|
34 |
+
self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
|
35 |
+
self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
|
36 |
+
self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
|
37 |
+
self.depth_embed = nn.Embedding(17, 8)
|
38 |
+
self.nugget_embed = nn.Embedding(len(event_nugget_list), 8)
|
39 |
+
self.roberta = roberta_model
|
40 |
+
self.dropout1 = nn.Dropout(0.2)
|
41 |
+
self.fc1 = nn.Linear(self.roberta.config.hidden_size + 48, self.num_classes_realis)
|
42 |
+
|
43 |
+
def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, ner_tags):
|
44 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
45 |
+
last_hidden_output = outputs.last_hidden_state
|
46 |
+
|
47 |
+
pos_mask = pos_spacy != -100
|
48 |
+
pos_embed_masked = self.pos_embed(pos_spacy[pos_mask])
|
49 |
+
pos_embed = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], 16), dtype=torch.float).to(device)
|
50 |
+
pos_embed[pos_mask] = pos_embed_masked
|
51 |
+
|
52 |
+
ner_mask = ner_spacy != -100
|
53 |
+
ner_embed_masked = self.ner_embed(ner_spacy[ner_mask])
|
54 |
+
ner_embed = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], 8), dtype=torch.float).to(device)
|
55 |
+
ner_embed[ner_mask] = ner_embed_masked
|
56 |
+
|
57 |
+
dep_mask = dep_spacy != -100
|
58 |
+
dep_embed_masked = self.dep_embed(dep_spacy[dep_mask])
|
59 |
+
dep_embed = torch.zeros((dep_spacy.shape[0], dep_spacy.shape[1], 8), dtype=torch.float).to(device)
|
60 |
+
dep_embed[dep_mask] = dep_embed_masked
|
61 |
+
|
62 |
+
depth_mask = depth_spacy != -100
|
63 |
+
depth_embed_masked = self.depth_embed(depth_spacy[depth_mask])
|
64 |
+
depth_embed = torch.zeros((depth_spacy.shape[0], depth_spacy.shape[1], 8), dtype=torch.float).to(device)
|
65 |
+
depth_embed[dep_mask] = depth_embed_masked
|
66 |
+
|
67 |
+
nugget_mask = ner_tags != -100
|
68 |
+
nugget_embed_masked = self.nugget_embed(ner_tags[nugget_mask])
|
69 |
+
nugget_embed = torch.zeros((ner_tags.shape[0], ner_tags.shape[1], 8), dtype=torch.float).to(device)
|
70 |
+
nugget_embed[dep_mask] = nugget_embed_masked
|
71 |
+
|
72 |
+
features_concat = torch.cat((last_hidden_output, pos_embed, ner_embed, dep_embed, depth_embed, nugget_embed), 2).to(device)
|
73 |
+
features_concat = self.dropout1(features_concat)
|
74 |
+
features_concat = self.dropout1(features_concat)
|
75 |
+
|
76 |
+
logits = self.fc1(features_concat)
|
77 |
+
|
78 |
+
return logits
|
79 |
+
|
80 |
+
|
81 |
+
def get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets):
|
82 |
+
event_nuggets_idxs = [(nugget["startOffset"], nugget["endOffset"]) for nugget in event_nuggets]
|
83 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
84 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
85 |
+
return "B-" + event_nuggets[idx]["subtype"]
|
86 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
87 |
+
return "I-" + event_nuggets[idx]["subtype"]
|
88 |
+
return "O"
|
89 |
+
|
90 |
+
def tokenize_and_align_labels_with_pos_ner_realis(examples, tokenizer, ner_names, label_all_tokens = True):
|
91 |
+
tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
|
92 |
+
#tokenized_inputs.pop('input_ids')
|
93 |
+
labels = []
|
94 |
+
nuggets = []
|
95 |
+
ner_spacy = []
|
96 |
+
pos_spacy = []
|
97 |
+
dep_spacy = []
|
98 |
+
depth_spacy = []
|
99 |
+
|
100 |
+
for i, (nugget, pos, ner, dep, depth) in enumerate(zip(examples["ner_tags"], examples["pos_spacy"], examples["ner_spacy"], examples["dep_spacy"], examples["depth_spacy"])):
|
101 |
+
word_ids = tokenized_inputs.word_ids(batch_index=i)
|
102 |
+
previous_word_idx = None
|
103 |
+
nugget_ids = []
|
104 |
+
ner_spacy_ids = []
|
105 |
+
pos_spacy_ids = []
|
106 |
+
dep_spacy_ids = []
|
107 |
+
depth_spacy_ids = []
|
108 |
+
|
109 |
+
for word_idx in word_ids:
|
110 |
+
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
|
111 |
+
# ignored in the loss function.
|
112 |
+
if word_idx is None:
|
113 |
+
nugget_ids.append(-100)
|
114 |
+
ner_spacy_ids.append(-100)
|
115 |
+
pos_spacy_ids.append(-100)
|
116 |
+
dep_spacy_ids.append(-100)
|
117 |
+
depth_spacy_ids.append(-100)
|
118 |
+
# We set the label for the first token of each word.
|
119 |
+
elif word_idx != previous_word_idx:
|
120 |
+
nugget_ids.append(nugget[word_idx])
|
121 |
+
ner_spacy_ids.append(ner[word_idx])
|
122 |
+
pos_spacy_ids.append(pos[word_idx])
|
123 |
+
dep_spacy_ids.append(dep[word_idx])
|
124 |
+
depth_spacy_ids.append(depth[word_idx])
|
125 |
+
# For the other tokens in a word, we set the label to either the current label or -100, depending on
|
126 |
+
# the label_all_tokens flag.
|
127 |
+
else:
|
128 |
+
nugget_ids.append(nugget[word_idx] if label_all_tokens else -100)
|
129 |
+
ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
|
130 |
+
pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
|
131 |
+
dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
|
132 |
+
depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
|
133 |
+
previous_word_idx = word_idx
|
134 |
+
|
135 |
+
nuggets.append(nugget_ids)
|
136 |
+
ner_spacy.append(ner_spacy_ids)
|
137 |
+
pos_spacy.append(pos_spacy_ids)
|
138 |
+
dep_spacy.append(dep_spacy_ids)
|
139 |
+
depth_spacy.append(depth_spacy_ids)
|
140 |
+
|
141 |
+
tokenized_inputs["ner_tags"] = nuggets
|
142 |
+
tokenized_inputs["pos_spacy"] = pos_spacy
|
143 |
+
tokenized_inputs["ner_spacy"] = ner_spacy
|
144 |
+
tokenized_inputs["dep_spacy"] = dep_spacy
|
145 |
+
tokenized_inputs["depth_spacy"] = depth_spacy
|
146 |
+
return tokenized_inputs
|
utils.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
list_of_pos_tags = [
|
2 |
+
"ADJ",
|
3 |
+
"ADP",
|
4 |
+
"ADV",
|
5 |
+
"AUX",
|
6 |
+
"CCONJ",
|
7 |
+
"DET",
|
8 |
+
"INTJ",
|
9 |
+
"NOUN",
|
10 |
+
"NUM",
|
11 |
+
"PART",
|
12 |
+
"PRON",
|
13 |
+
"PROPN",
|
14 |
+
"PUNCT",
|
15 |
+
"SCONJ",
|
16 |
+
"SYM",
|
17 |
+
"VERB",
|
18 |
+
"X"
|
19 |
+
]
|
20 |
+
|
21 |
+
realis_list = ["O",
|
22 |
+
"Generic",
|
23 |
+
"Other",
|
24 |
+
"Actual"
|
25 |
+
]
|
26 |
+
|
27 |
+
|
28 |
+
event_args_list = ['O',
|
29 |
+
'B-System',
|
30 |
+
'I-System',
|
31 |
+
'B-Organization',
|
32 |
+
'B-Money',
|
33 |
+
'I-Money',
|
34 |
+
'B-Device',
|
35 |
+
'B-Person',
|
36 |
+
'I-Person',
|
37 |
+
'B-Vulnerability',
|
38 |
+
'I-Vulnerability',
|
39 |
+
'B-Capabilities',
|
40 |
+
'I-Capabilities',
|
41 |
+
'I-Organization',
|
42 |
+
'B-PaymentMethod',
|
43 |
+
'I-PaymentMethod',
|
44 |
+
'B-Data',
|
45 |
+
'I-Data',
|
46 |
+
'B-Number',
|
47 |
+
'I-Number',
|
48 |
+
'B-Malware',
|
49 |
+
'I-Malware',
|
50 |
+
'B-PII',
|
51 |
+
'I-PII',
|
52 |
+
'B-CVE',
|
53 |
+
'I-CVE',
|
54 |
+
'B-Purpose',
|
55 |
+
'I-Purpose',
|
56 |
+
'B-File',
|
57 |
+
'I-File',
|
58 |
+
'I-Device',
|
59 |
+
'B-Time',
|
60 |
+
'I-Time',
|
61 |
+
'B-Software',
|
62 |
+
'I-Software',
|
63 |
+
'B-Patch',
|
64 |
+
'I-Patch',
|
65 |
+
'B-Version',
|
66 |
+
'I-Version',
|
67 |
+
'B-Website',
|
68 |
+
'I-Website',
|
69 |
+
'B-GPE',
|
70 |
+
'I-GPE'
|
71 |
+
]
|
72 |
+
|
73 |
+
event_nugget_list = ['O',
|
74 |
+
'B-Ransom',
|
75 |
+
'I-Ransom',
|
76 |
+
'B-DiscoverVulnerability',
|
77 |
+
'I-DiscoverVulnerability',
|
78 |
+
'B-PatchVulnerability',
|
79 |
+
'I-PatchVulnerability',
|
80 |
+
'B-Databreach',
|
81 |
+
'I-Databreach',
|
82 |
+
'B-Phishing',
|
83 |
+
'I-Phishing'
|
84 |
+
]
|
85 |
+
|
86 |
+
arg_2_role = {
|
87 |
+
"File" : ['Tool', 'Trusted-Entity'],
|
88 |
+
"Person" : ['Victim', 'Attacker', 'Discoverer', 'Releaser', 'Trusted-Entity', 'Vulnerable_System_Owner'],
|
89 |
+
"Capabilities" : ['Attack-Pattern', 'Capabilities', 'Issues-Addressed'],
|
90 |
+
"Purpose" : ['Purpose'],
|
91 |
+
"Time" : ['Time'],
|
92 |
+
"PII" : ['Compromised-Data', 'Trusted-Entity'],
|
93 |
+
"Data" : ['Compromised-Data', 'Trusted-Entity'],
|
94 |
+
"Organization" : ['Victim', 'Releaser', 'Discoverer', 'Attacker', 'Vulnerable_System_Owner', 'Trusted-Entity'],
|
95 |
+
"Patch" : ['Patch'],
|
96 |
+
"Software" : ['Vulnerable_System', 'Victim', 'Trusted-Entity', 'Supported_Platform'],
|
97 |
+
"Vulnerability" : ['Vulnerability'],
|
98 |
+
"Version" : ['Patch-Number', 'Vulnerable_System_Version'],
|
99 |
+
"Device" : ['Vulnerable_System', 'Victim', 'Supported_Platform'],
|
100 |
+
"CVE" : ['CVE'],
|
101 |
+
"Number" : ['Number-of-Data', 'Number-of-Victim'],
|
102 |
+
"System" : ['Victim', 'Supported_Platform', 'Vulnerable_System', 'Trusted-Entity'],
|
103 |
+
"Malware" : ['Tool'],
|
104 |
+
"Money" : ['Price', 'Damage-Amount'],
|
105 |
+
"PaymentMethod" : ['Payment-Method'],
|
106 |
+
"GPE" : ['Place'],
|
107 |
+
"Website" : ['Trusted-Entity', 'Tool', 'Vulnerable_System', 'Victim', 'Supported_Platform'],
|
108 |
+
}
|
109 |
+
|
110 |
+
def get_content(data):
|
111 |
+
return data["content"]
|
112 |
+
|
113 |
+
def get_event_nugget(data):
|
114 |
+
return [
|
115 |
+
{"nugget" : event["nugget"], "type" : event["type"], "subtype" : event["subtype"], "realis" : event["realis"]}
|
116 |
+
for hopper in data["cyberevent"]["hopper"] for event in hopper["events"]
|
117 |
+
]
|
118 |
+
def get_event_args(data):
|
119 |
+
events = [event for hopper in data["cyberevent"]["hopper"] for event in hopper["events"]]
|
120 |
+
args = []
|
121 |
+
for event in events:
|
122 |
+
if "argument" in event.keys():
|
123 |
+
args.extend(event["argument"])
|
124 |
+
return args
|
125 |
+
|
126 |
+
def get_idxs_from_text(text, text_tokenized):
|
127 |
+
rest_text = text
|
128 |
+
last_idx = 0
|
129 |
+
result_dict = []
|
130 |
+
|
131 |
+
for substring in text_tokenized:
|
132 |
+
index = rest_text.find(substring)
|
133 |
+
result_dict.append(
|
134 |
+
{
|
135 |
+
"word" : substring,
|
136 |
+
"start_idx" : last_idx + index,
|
137 |
+
"end_idx" : last_idx + index + len(substring)
|
138 |
+
}
|
139 |
+
)
|
140 |
+
rest_text = rest_text[index + len(substring) : ]
|
141 |
+
last_idx += index + len(substring)
|
142 |
+
return result_dict
|
143 |
+
|
144 |
+
def get_entity_from_idx(start_idx, end_idx, event_nuggets):
|
145 |
+
event_nuggets_idxs = [(nugget["nugget"]["startOffset"], nugget["nugget"]["endOffset"]) for nugget in event_nuggets]
|
146 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
147 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
148 |
+
return "B-" + event_nuggets[idx]["subtype"]
|
149 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
150 |
+
return "I-" + event_nuggets[idx]["subtype"]
|
151 |
+
return "O"
|
152 |
+
|
153 |
+
def get_entity_and_realis_from_idx(start_idx, end_idx, event_nuggets):
|
154 |
+
event_nuggets_idxs = [(nugget["nugget"]["startOffset"], nugget["nugget"]["endOffset"]) for nugget in event_nuggets]
|
155 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
156 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
157 |
+
return "B-" + event_nuggets[idx]["subtype"], "B-" + event_nuggets[idx]["realis"]
|
158 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
159 |
+
return "I-" + event_nuggets[idx]["subtype"], "I-" + event_nuggets[idx]["realis"]
|
160 |
+
return "O", "O"
|
161 |
+
|
162 |
+
def get_args_entity_from_idx(start_idx, end_idx, event_args):
|
163 |
+
event_nuggets_idxs = [(nugget["startOffset"], nugget["endOffset"]) for nugget in event_args]
|
164 |
+
for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
|
165 |
+
if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
|
166 |
+
return "B-" + event_args[idx]["type"]
|
167 |
+
elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
|
168 |
+
return "I-" + event_args[idx]["type"]
|
169 |
+
return "O"
|
170 |
+
|
171 |
+
def split_with_character(string, char):
|
172 |
+
result = []
|
173 |
+
start = 0
|
174 |
+
for i, c in enumerate(string):
|
175 |
+
if c == char:
|
176 |
+
result.append(string[start:i])
|
177 |
+
result.append(char)
|
178 |
+
start = i + 1
|
179 |
+
result.append(string[start:])
|
180 |
+
return [x for x in result if x != '']
|
181 |
+
|
182 |
+
def extend_list_with_character(content_list, character):
|
183 |
+
content_as_words = []
|
184 |
+
for word in content_list:
|
185 |
+
if character in word:
|
186 |
+
split_list = split_with_character(word, character)
|
187 |
+
content_as_words.extend(split_list)
|
188 |
+
else:
|
189 |
+
content_as_words.append(word)
|
190 |
+
return content_as_words
|
191 |
+
|
192 |
+
def find_dict_by_overlap(list_of_dicts, key_value_pairs):
|
193 |
+
for dictionary in list_of_dicts:
|
194 |
+
if max(dictionary["start"], dictionary["end"]) >= min(key_value_pairs["start"], key_value_pairs["end"]) and max(key_value_pairs["start"], key_value_pairs["end"]) >= min(dictionary["start"], dictionary["end"]):
|
195 |
+
return dictionary
|
196 |
+
return None
|