cpi-connect commited on
Commit
4e38daf
·
1 Parent(s): d60b7fd

Upload 18 files

Browse files
.gitattributes CHANGED
@@ -1 +1,8 @@
1
  pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
1
  pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
2
+ argument_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
3
+ model_59.pt filter=lfs diff=lfs merge=lfs -text
4
+ model_64_pos_ner.pt filter=lfs diff=lfs merge=lfs -text
5
+ model_66.pt filter=lfs diff=lfs merge=lfs -text
6
+ model_97.pt filter=lfs diff=lfs merge=lfs -text
7
+ nugget_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
8
+ realis_model_state_dict.pth filter=lfs diff=lfs merge=lfs -text
args_model_utils.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import spacy
3
+ import en_core_web_sm
4
+ from torch import nn
5
+ import math
6
+
7
+
8
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
9
+
10
+ from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
11
+ from transformers import AutoTokenizer
12
+
13
+ model_checkpoint = "ehsanaghaei/SecureBERT"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
16
+ roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
17
+
18
+ nlp = en_core_web_sm.load()
19
+ pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
20
+ ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
21
+ dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
22
+ event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
23
+ arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]
24
+
25
+ class CustomRobertaWithPOS(nn.Module):
26
+ def __init__(self, num_classes):
27
+ super(CustomRobertaWithPOS, self).__init__()
28
+ self.num_classes = num_classes
29
+
30
+ self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
31
+ self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
32
+ self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
33
+ self.depth_embed = nn.Embedding(17, 8)
34
+ self.subtype_embed = nn.Embedding(len(event_nugget_tag_list), 2)
35
+ self.dist_embed = nn.Embedding(11, 6)
36
+ self.relative_pos_embed = nn.Embedding(len(arg_nugget_relative_pos_tag_list), 2)
37
+
38
+ self.roberta = roberta_model
39
+ self.dropout1 = nn.Dropout(0.2)
40
+ self.fc1 = nn.Linear(self.roberta.config.hidden_size + 50, num_classes)
41
+
42
+ def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, nearest_nugget_subtype, nearest_nugget_dist, arg_nugget_relative_pos):
43
+ outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
44
+ last_hidden_output = outputs.last_hidden_state
45
+
46
+ pooler_output = outputs.pooler_output
47
+ pooler_output_unsqz = pooler_output.unsqueeze(1)
48
+ pooler_output_fin = pooler_output_unsqz.expand(-1, last_hidden_output.shape[1], -1)
49
+
50
+
51
+ pos_mask = pos_spacy != -100
52
+ pos_embed_masked = self.pos_embed(pos_spacy[pos_mask])
53
+ pos_embed = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], 16), dtype=torch.float).to(device)
54
+ pos_embed[pos_mask] = pos_embed_masked
55
+
56
+ ner_mask = ner_spacy != -100
57
+ ner_embed_masked = self.ner_embed(ner_spacy[ner_mask])
58
+ ner_embed = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], 8), dtype=torch.float).to(device)
59
+ ner_embed[ner_mask] = ner_embed_masked
60
+
61
+ dep_mask = dep_spacy != -100
62
+ dep_embed_masked = self.dep_embed(dep_spacy[dep_mask])
63
+ dep_embed = torch.zeros((dep_spacy.shape[0], dep_spacy.shape[1], 8), dtype=torch.float).to(device)
64
+ dep_embed[dep_mask] = dep_embed_masked
65
+
66
+ depth_mask = depth_spacy != -100
67
+ depth_embed_masked = self.depth_embed(depth_spacy[depth_mask])
68
+ depth_embed = torch.zeros((depth_spacy.shape[0], depth_spacy.shape[1], 8), dtype=torch.float).to(device)
69
+ depth_embed[dep_mask] = depth_embed_masked
70
+
71
+ nearest_nugget_subtype_mask = nearest_nugget_subtype != -100
72
+ nearest_nugget_subtype_embed_masked = self.subtype_embed(nearest_nugget_subtype[nearest_nugget_subtype_mask])
73
+ nearest_nugget_subtype_embed = torch.zeros((nearest_nugget_subtype.shape[0], nearest_nugget_subtype.shape[1], 2), dtype=torch.float).to(device)
74
+ nearest_nugget_subtype_embed[dep_mask] = nearest_nugget_subtype_embed_masked
75
+
76
+ nearest_nugget_dist_mask = nearest_nugget_dist != -100
77
+ nearest_nugget_dist_embed_masked = self.dist_embed(nearest_nugget_dist[nearest_nugget_dist_mask])
78
+ nearest_nugget_dist_embed = torch.zeros((nearest_nugget_dist.shape[0], nearest_nugget_dist.shape[1], 6), dtype=torch.float).to(device)
79
+ nearest_nugget_dist_embed[dep_mask] = nearest_nugget_dist_embed_masked
80
+
81
+ arg_nugget_relative_pos_mask = arg_nugget_relative_pos != -100
82
+ arg_nugget_relative_pos_embed_masked = self.relative_pos_embed(arg_nugget_relative_pos[arg_nugget_relative_pos_mask])
83
+ arg_nugget_relative_pos_embed = torch.zeros((arg_nugget_relative_pos.shape[0], arg_nugget_relative_pos.shape[1], 2), dtype=torch.float).to(device)
84
+ arg_nugget_relative_pos_embed[dep_mask] = arg_nugget_relative_pos_embed_masked
85
+
86
+ features_concat = torch.cat((last_hidden_output, pos_embed, ner_embed, dep_embed, depth_embed, nearest_nugget_subtype_embed, nearest_nugget_dist_embed, arg_nugget_relative_pos_embed), 2).to(device)
87
+ features_concat = self.dropout1(features_concat)
88
+
89
+ logits = self.fc1(features_concat)
90
+
91
+ return logits
92
+
93
+
94
+ def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens = True):
95
+ tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
96
+ #tokenized_inputs.pop('input_ids')
97
+ ner_spacy = []
98
+ pos_spacy = []
99
+ dep_spacy = []
100
+ depth_spacy = []
101
+ nearest_nugget_subtype = []
102
+ nearest_nugget_dist = []
103
+ arg_nugget_relative_pos = []
104
+
105
+ for i, (pos, ner, dep, depth, subtype, dist, relative_pos) in enumerate(zip(examples["pos_spacy"],
106
+ examples["ner_spacy"],
107
+ examples["dep_spacy"],
108
+ examples["depth_spacy"],
109
+ examples["nearest_nugget_subtype"],
110
+ examples["nearest_nugget_dist"],
111
+ examples["arg_nugget_relative_pos"])):
112
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
113
+ previous_word_idx = None
114
+ ner_spacy_ids = []
115
+ pos_spacy_ids = []
116
+ dep_spacy_ids = []
117
+ depth_spacy_ids = []
118
+ nearest_nugget_subtype_ids = []
119
+ nearest_nugget_dist_ids = []
120
+ arg_nugget_relative_pos_ids = []
121
+
122
+ for word_idx in word_ids:
123
+ # Special tokens have a word id that is None. We set the label to -100 so they are automatically
124
+ # ignored in the loss function.
125
+ if word_idx is None:
126
+ ner_spacy_ids.append(-100)
127
+ pos_spacy_ids.append(-100)
128
+ dep_spacy_ids.append(-100)
129
+ depth_spacy_ids.append(-100)
130
+ nearest_nugget_subtype_ids.append(-100)
131
+ nearest_nugget_dist_ids.append(-100)
132
+ arg_nugget_relative_pos_ids.append(-100)
133
+ # We set the label for the first token of each word.
134
+ elif word_idx != previous_word_idx:
135
+ ner_spacy_ids.append(ner[word_idx])
136
+ pos_spacy_ids.append(pos[word_idx])
137
+ dep_spacy_ids.append(dep[word_idx])
138
+ depth_spacy_ids.append(depth[word_idx])
139
+ nearest_nugget_subtype_ids.append(subtype[word_idx])
140
+ nearest_nugget_dist_ids.append(dist[word_idx])
141
+ arg_nugget_relative_pos_ids.append(relative_pos[word_idx])
142
+ # For the other tokens in a word, we set the label to either the current label or -100, depending on
143
+ # the label_all_tokens flag.
144
+ else:
145
+ ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
146
+ pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
147
+ dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
148
+ depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
149
+ nearest_nugget_subtype_ids.append(subtype[word_idx] if label_all_tokens else -100)
150
+ nearest_nugget_dist_ids.append(dist[word_idx] if label_all_tokens else -100)
151
+ arg_nugget_relative_pos_ids.append(relative_pos[word_idx] if label_all_tokens else -100)
152
+ previous_word_idx = word_idx
153
+
154
+ ner_spacy.append(ner_spacy_ids)
155
+ pos_spacy.append(pos_spacy_ids)
156
+ dep_spacy.append(dep_spacy_ids)
157
+ depth_spacy.append(depth_spacy_ids)
158
+ nearest_nugget_subtype.append(nearest_nugget_subtype_ids)
159
+ nearest_nugget_dist.append(nearest_nugget_dist_ids)
160
+ arg_nugget_relative_pos.append(arg_nugget_relative_pos_ids)
161
+
162
+ tokenized_inputs["pos_spacy"] = pos_spacy
163
+ tokenized_inputs["ner_spacy"] = ner_spacy
164
+ tokenized_inputs["dep_spacy"] = dep_spacy
165
+ tokenized_inputs["depth_spacy"] = depth_spacy
166
+ tokenized_inputs["nearest_nugget_subtype"] = nearest_nugget_subtype
167
+ tokenized_inputs["nearest_nugget_dist"] = nearest_nugget_dist
168
+ tokenized_inputs["arg_nugget_relative_pos"] = arg_nugget_relative_pos
169
+ return tokenized_inputs
170
+
171
+ def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
172
+ nearest_subtype = None
173
+ nearest_dist = math.inf
174
+ relative_pos = None
175
+
176
+ mid_idx = (end_idx + start_idx) / 2
177
+ for nugget in event_nuggets:
178
+ mid_nugget_idx = (nugget["startOffset"] + nugget["endOffset"]) / 2
179
+ dist = abs(mid_nugget_idx - mid_idx)
180
+
181
+ if dist < nearest_dist:
182
+ nearest_dist = dist
183
+ nearest_subtype = nugget["subtype"]
184
+ for sent in doc.sents:
185
+ if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
186
+ if mid_idx < mid_nugget_idx:
187
+ relative_pos = "before-same-sentence"
188
+ else:
189
+ relative_pos = "after-same-sentence"
190
+ break
191
+ elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
192
+ relative_pos = "after-differ-sentence"
193
+ break
194
+ elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
195
+ relative_pos = "before-differ-sentence"
196
+ break
197
+
198
+ nearest_dist = int(min(10, nearest_dist // 20))
199
+ return nearest_subtype, nearest_dist, relative_pos
200
+
201
+ def find_dep_depth(token):
202
+ depth = 0
203
+ current_token = token
204
+ while current_token.head != current_token:
205
+ depth += 1
206
+ current_token = current_token.head
207
+ return min(depth, 16)
208
+
209
+ def between_idxs(idx, start_idx, end_idx):
210
+ return idx >= start_idx and idx <= end_idx
argument_model_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:185e22992430c80ec1eb1fca7f3ba4ebe801163c3ba13bed00abc6dc24072712
3
+ size 498813605
configuration.py CHANGED
@@ -5,6 +5,7 @@ from cybersecurity_knowledge_graph.utils import event_args_list, event_nugget_li
5
 
6
 
7
  class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
 
8
 
9
  def __init__(
10
  self,
 
5
 
6
 
7
  class CybersecurityKnowledgeGraphConfig(PretrainedConfig):
8
+ model_type = "cybersecurity_knowledge_graph"
9
 
10
  def __init__(
11
  self,
event_arg_predict.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from annotated_text import annotated_text
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+ from cybersecurity_knowledge_graph.args_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
7
+ from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
8
+ from cybersecurity_knowledge_graph.utils import get_content, get_event_nugget, get_idxs_from_text, get_entity_from_idx, list_of_pos_tags, event_args_list
9
+
10
+ from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
11
+ import spacy
12
+ from transformers import AutoTokenizer
13
+ from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
14
+ import os
15
+
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
17
+
18
+ def find_dep_depth(token):
19
+ depth = 0
20
+ current_token = token
21
+ while current_token.head != current_token:
22
+ depth += 1
23
+ current_token = current_token.head
24
+ return min(depth, 16)
25
+
26
+
27
+ nlp = spacy.load('en_core_web_sm')
28
+
29
+ pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
30
+ ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
31
+ dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
32
+ event_nugget_tag_list = ["Databreach", "Ransom", "PatchVulnerability", "Phishing", "DiscoverVulnerability"]
33
+ arg_nugget_relative_pos_tag_list = ["before-same-sentence", "before-differ-sentence", "after-same-sentence", "after-differ-sentence"]
34
+
35
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
36
+
37
+ model_checkpoint = "ehsanaghaei/SecureBERT"
38
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
39
+
40
+ from cybersecurity_knowledge_graph.args_model_utils import CustomRobertaWithPOS as ArgumentModel
41
+ model_nugget = ArgumentModel(num_classes=43)
42
+ model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/argument_model_state_dict.pth", map_location=device))
43
+ model_nugget.eval()
44
+
45
+ """
46
+ Function: create_dataloader(text_input)
47
+ Description: This function creates a DataLoader for processing text data, tokenizes it, and organizes it into batches.
48
+ Inputs:
49
+ - text_input: The input text to be processed.
50
+ Output:
51
+ - dataloader: A DataLoader for the tokenized and batched text data.
52
+ - tokenized_dataset_ner: The tokenized dataset used for training.
53
+ """
54
+ def create_dataloader(text_input):
55
+
56
+ event_nuggets = get_event_nuggets(text_input)
57
+ doc = nlp(text_input)
58
+
59
+ content_as_words_emdash = [tok.text for tok in doc]
60
+ content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
61
+ content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
62
+
63
+ data = []
64
+
65
+ words = []
66
+ arg_nugget_nearest_subtype = []
67
+ arg_nugget_nearest_dist = []
68
+ arg_nugget_relative_pos = []
69
+
70
+ pos_spacy = [tok.pos_ for tok in doc]
71
+ ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
72
+ dep_spacy = [tok.dep_ for tok in doc]
73
+ depth_spacy = [find_dep_depth(tok) for tok in doc]
74
+
75
+ for content_dict in content_idx_dict:
76
+ start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
77
+ nearest_subtype, nearest_dist, relative_pos = find_nearest_nugget_features(doc, content_dict["start_idx"], content_dict["end_idx"], event_nuggets)
78
+ words.append(content_dict["word"])
79
+
80
+ arg_nugget_nearest_subtype.append(nearest_subtype)
81
+ arg_nugget_nearest_dist.append(nearest_dist)
82
+ arg_nugget_relative_pos.append(relative_pos)
83
+
84
+
85
+ content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
86
+ if content_token_len > tokenizer.model_max_length:
87
+ no_split = (content_token_len // tokenizer.model_max_length) + 2
88
+ split_len = (len(words) // no_split) + 1
89
+
90
+ last_id = 0
91
+ threshold = split_len
92
+
93
+ for id, token in enumerate(words):
94
+ if token == "." and id > threshold:
95
+ data.append(
96
+ {
97
+ "tokens" : words[last_id : id + 1],
98
+ "pos_spacy" : pos_spacy[last_id : id + 1],
99
+ "ner_spacy" : ner_spacy[last_id : id + 1],
100
+ "dep_spacy" : dep_spacy[last_id : id + 1],
101
+ "depth_spacy" : depth_spacy[last_id : id + 1],
102
+ "nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : id + 1],
103
+ "nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : id + 1],
104
+ "arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : id + 1]
105
+ }
106
+ )
107
+ last_id = id + 1
108
+ threshold += split_len
109
+ data.append({"tokens" : words[last_id : ],
110
+ "pos_spacy" : pos_spacy[last_id : ],
111
+ "ner_spacy" : ner_spacy[last_id : ],
112
+ "dep_spacy" : dep_spacy[last_id : ],
113
+ "depth_spacy" : depth_spacy[last_id : ],
114
+ "nearest_nugget_subtype" : arg_nugget_nearest_subtype[last_id : ],
115
+ "nearest_nugget_dist" : arg_nugget_nearest_dist[last_id : ],
116
+ "arg_nugget_relative_pos" : arg_nugget_relative_pos[last_id : ]})
117
+ else:
118
+ data.append(
119
+ {
120
+ "tokens" : words,
121
+ "pos_spacy" : pos_spacy,
122
+ "ner_spacy" : ner_spacy,
123
+ "dep_spacy" : dep_spacy,
124
+ "depth_spacy" : depth_spacy,
125
+ "nearest_nugget_subtype" : arg_nugget_nearest_subtype,
126
+ "nearest_nugget_dist" : arg_nugget_nearest_dist,
127
+ "arg_nugget_relative_pos" : arg_nugget_relative_pos
128
+ }
129
+ )
130
+
131
+
132
+ ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
133
+ 'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
134
+ 'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
135
+ 'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
136
+ 'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None),
137
+ 'nearest_nugget_subtype' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_tag_list), names=event_nugget_tag_list, names_file=None, id=None), length=-1, id=None),
138
+ 'nearest_nugget_dist' : Sequence(feature=ClassLabel(num_classes=11, names=list(range(11)), names_file=None, id=None), length=-1, id=None),
139
+ 'arg_nugget_relative_pos' : Sequence(feature=ClassLabel(num_classes=len(arg_nugget_relative_pos_tag_list), names=arg_nugget_relative_pos_tag_list, names_file=None, id=None), length=-1, id=None),
140
+ })
141
+
142
+ dataset = Dataset.from_list(data, features=ner_features)
143
+ tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
144
+ tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
145
+
146
+ tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
147
+
148
+ batch_size = 4 # Number of input texts
149
+ dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
150
+ return dataloader, tokenized_dataset_ner
151
+
152
+ """
153
+ Function: predict(dataloader)
154
+ Description: This function performs prediction on a given dataloader using a trained model for label classification.
155
+ Inputs:
156
+ - dataloader: A DataLoader containing the input data for prediction.
157
+ Output:
158
+ - predicted_label: A tensor containing the predicted labels for each input in the dataloader.
159
+ """
160
+ def predict(dataloader):
161
+ predicted_label = []
162
+ for batch in dataloader:
163
+ with torch.no_grad():
164
+ logits = model_nugget(**batch)
165
+
166
+ batch_predicted_label = logits.argmax(-1)
167
+ predicted_label.append(batch_predicted_label)
168
+ return torch.cat(predicted_label, dim=-1)
169
+
170
+ """
171
+ Function: show_annotations(text_input)
172
+ Description: This function displays annotated event arguments in the provided input text.
173
+ Inputs:
174
+ - text_input: The input text containing event arguments to be annotated and displayed.
175
+ Output:
176
+ - An interactive display of annotated event arguments within the input text.
177
+ """
178
+ def show_annotations(text_input):
179
+ st.title("Event Arguments")
180
+
181
+ dataloader, tokenized_dataset_ner = create_dataloader(text_input)
182
+ predicted_label = predict(dataloader)
183
+
184
+ for idx, labels in enumerate(predicted_label):
185
+ token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
186
+
187
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
188
+ tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
189
+
190
+ text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
191
+ idxs = get_idxs_from_text(text, tokens)
192
+
193
+ labels = labels[token_mask]
194
+
195
+ annotated_text_list = []
196
+ last_label = ""
197
+ cumulative_tokens = ""
198
+ last_id = 0
199
+
200
+ for idx, label in zip(idxs, labels):
201
+ to_label = event_args_list[label]
202
+ label_short = to_label.split("-")[1] if "-" in to_label else to_label
203
+ if last_label == label_short:
204
+ cumulative_tokens += text[last_id : idx["end_idx"]]
205
+ last_id = idx["end_idx"]
206
+ else:
207
+ if last_label != "":
208
+ if last_label == "O":
209
+ annotated_text_list.append(cumulative_tokens)
210
+ else:
211
+ annotated_text_list.append((cumulative_tokens, last_label))
212
+ last_label = label_short
213
+ cumulative_tokens = idx["word"]
214
+ last_id = idx["end_idx"]
215
+ if last_label == "O":
216
+ annotated_text_list.append(cumulative_tokens)
217
+ else:
218
+ annotated_text_list.append((cumulative_tokens, last_label))
219
+
220
+ annotated_text(annotated_text_list)
221
+
222
+ """
223
+ Function: get_event_args(text_input)
224
+ Description: This function extracts predicted event arguments (event nuggets) from the provided input text.
225
+ Inputs:
226
+ - text_input: The input text containing event nuggets to be extracted.
227
+ Output:
228
+ - predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
229
+ subtype, and text content.
230
+ """
231
+ def get_event_args(text_input):
232
+ dataloader, tokenized_dataset_ner = create_dataloader(text_input)
233
+ predicted_label = predict(dataloader)
234
+
235
+ predicted_event_nuggets = []
236
+ text_length = 0
237
+ for idx, labels in enumerate(predicted_label):
238
+ token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
239
+
240
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
241
+ tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
242
+
243
+ text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
244
+ idxs = get_idxs_from_text(text_input[text_length : ], tokens)
245
+
246
+ labels = labels[token_mask]
247
+
248
+ start_idx = 0
249
+ end_idx = 0
250
+ last_label = ""
251
+
252
+ for idx, label in zip(idxs, labels):
253
+ to_label = event_args_list[label]
254
+ if "-" in to_label:
255
+ label_split = to_label.split("-")[1]
256
+ else:
257
+ label_split = to_label
258
+
259
+ if label_split == last_label:
260
+ end_idx = idx["end_idx"]
261
+ else:
262
+ if text_input[start_idx : end_idx] != "" and last_label != "O":
263
+ predicted_event_nuggets.append(
264
+ {
265
+ "startOffset" : text_length + start_idx,
266
+ "endOffset" : text_length + end_idx,
267
+ "subtype" : last_label,
268
+ "text" : text_input[text_length + start_idx : text_length + end_idx]
269
+ }
270
+ )
271
+ start_idx = idx["start_idx"]
272
+ end_idx = idx["start_idx"] + len(idx["word"])
273
+ last_label = label_split
274
+ text_length += idx["end_idx"]
275
+ return predicted_event_nuggets
276
+
277
+
278
+
279
+
280
+
event_arg_role_dataloader.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json
2
+ from cybersecurity_knowledge_graph.utils import get_content, get_event_args, get_event_nugget, get_idxs_from_text, get_args_entity_from_idx, find_dict_by_overlap
3
+ from tqdm import tqdm
4
+ import spacy
5
+ import jsonlines
6
+ from sklearn.model_selection import train_test_split
7
+ import math
8
+ from transformers import pipeline
9
+ from sentence_transformers import SentenceTransformer
10
+ import numpy as np
11
+
12
+ embed_model = SentenceTransformer('all-MiniLM-L6-v2')
13
+
14
+ pipe = pipeline("token-classification", model="CyberPeace-Institute/SecureBERT-NER")
15
+
16
+ nlp = spacy.load('en_core_web_sm')
17
+
18
+ """
19
+ Class: EventArgumentRoleDataset
20
+ Description: This class represents a dataset for training and evaluating event argument role classifiers.
21
+ Attributes:
22
+ - path: The path to the folder containing JSON files with event data.
23
+ - tokenizer: A tokenizer for encoding text data.
24
+ - arg: The specific argument type (subtype) for which the dataset is created.
25
+ - data: A list to store data samples, each consisting of an embedding and a label.
26
+ - train_data, val_data, test_data: Lists to store the split training, validation, and test data samples.
27
+ - datapoint_id: An identifier for tracking data samples.
28
+ Methods:
29
+ - __len__(): Returns the total number of data samples in the dataset.
30
+ - __getitem__(index): Retrieves a data sample at a specified index.
31
+ - to_jsonlines(train_path, val_path, test_path): Writes the dataset to JSON files for train, validation, and test sets.
32
+ - train_val_test_split(): Splits the data into training and test sets.
33
+ - load_data(): Loads and preprocesses event data from JSON files, creating embeddings for argument-role classification.
34
+ """
35
+ class EventArgumentRoleDataset():
36
+ def __init__(self, path, tokenizer, arg):
37
+ self.path = path
38
+ self.tokenizer = tokenizer
39
+ self.arg = arg
40
+ self.data = []
41
+ self.train_data, self.val_data, self.test_data = None, None, None
42
+ self.datapoint_id = 0
43
+
44
+ def __len__(self):
45
+ return len(self.data)
46
+
47
+ def __getitem__(self, index):
48
+ sample = self.data[index]
49
+ return sample
50
+
51
+ def to_jsonlines(self, train_path, val_path, test_path):
52
+ if self.train_data is None or self.test_data is None:
53
+ raise ValueError("Do the train-val-test split")
54
+ with jsonlines.open(train_path, "w") as f:
55
+ f.write_all(self.train_data)
56
+ # with jsonlines.open(val_path, "w") as f:
57
+ # f.write_all(self.val_data)
58
+ with jsonlines.open(test_path, "w") as f:
59
+ f.write_all(self.test_data)
60
+
61
+ def train_val_test_split(self):
62
+ self.train_data, self.test_data = train_test_split(self.data, test_size=0.1, random_state=42, shuffle=True)
63
+ # self.val_data, self.test_data = train_test_split(test_val, test_size=0.5, random_state=42, shuffle=True)
64
+
65
+ def load_data(self):
66
+ folder_path = self.path
67
+ json_files = [file for file in os.listdir(folder_path) if file.endswith('.json')]
68
+
69
+ # Load the nuggets
70
+ for idx, file_path in enumerate(tqdm(json_files)):
71
+ try:
72
+ with open(self.path + file_path, "r") as f:
73
+ file_json = json.load(f)
74
+ except:
75
+ print("Error in ", file_path)
76
+ content = get_content(file_json)
77
+ content = content.replace("\xa0", " ")
78
+
79
+ event_args = get_event_args(file_json)
80
+ doc = nlp(content)
81
+
82
+ sentence_indexes = []
83
+ for sent in doc.sents:
84
+ start_index = sent[0].idx
85
+ end_index = sent[-1].idx + len(sent[-1].text)
86
+ sentence_indexes.append((start_index, end_index))
87
+
88
+ for idx, (start, end) in enumerate(sentence_indexes):
89
+ sentence = content[start:end]
90
+ is_arg_sentence = [event_arg["startOffset"] >= start and event_arg["endOffset"] <= end for event_arg in event_args]
91
+ args = [event_args[idx] for idx, boolean in enumerate(is_arg_sentence) if boolean]
92
+ if args != []:
93
+ sentence_doc = nlp(sentence)
94
+ sentence_embed = embed_model.encode(sentence)
95
+ for arg in args:
96
+ if arg["type"] == self.arg:
97
+ arg_embed = embed_model.encode(arg["text"])
98
+ embedding = np.concatenate((sentence_embed, arg_embed))
99
+
100
+ self.data.append({"embedding" : embedding, "label" : arg["role"]["type"]})
event_arg_role_predict.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cybersecurity_knowledge_graph.event_arg_role_dataloader import EventArgumentRoleDataset
2
+ from cybersecurity_knowledge_graph.utils import arg_2_role
3
+
4
+ import os
5
+ from transformers import AutoTokenizer
6
+ import optuna
7
+ from sklearn.model_selection import StratifiedKFold
8
+ from sklearn.model_selection import cross_val_score
9
+ from sklearn.metrics import make_scorer, f1_score
10
+ from sklearn.ensemble import VotingClassifier
11
+ from sklearn.linear_model import LogisticRegression
12
+ from sklearn.neural_network import MLPClassifier
13
+ from sklearn.svm import SVC
14
+ from joblib import dump, load
15
+ from sentence_transformers import SentenceTransformer
16
+ import numpy as np
17
+
18
+ embed_model = SentenceTransformer('all-MiniLM-L6-v2')
19
+
20
+ model_checkpoint = "ehsanaghaei/SecureBERT"
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
23
+
24
+ classifiers = {}
25
+ folder_path = '/cybersecurity_knowledge_graph/arg_role_models'
26
+
27
+ for filename in os.listdir(os.getcwd() + folder_path):
28
+ if filename.endswith('.joblib'):
29
+ file_path = os.getcwd() + os.path.join(folder_path, filename)
30
+ clf = load(file_path)
31
+ arg = filename.split(".")[0]
32
+ classifiers[arg] = clf
33
+
34
+ """
35
+ Function: fit()
36
+ Description: This function performs a machine learning task to train and evaluate classifiers for multiple argument roles.
37
+ It utilizes Optuna for hyperparameter optimization and creates a Voting Classifier.
38
+ The trained classifiers are saved as joblib files.
39
+ """
40
+ def fit():
41
+ for arg, roles in arg_2_role.items():
42
+ if len(roles) > 1:
43
+
44
+ dataset = EventArgumentRoleDataset(path="./data/annotation/", tokenizer=tokenizer, arg=arg)
45
+ dataset.load_data()
46
+ dataset.train_val_test_split()
47
+
48
+
49
+ X = [datapoint["embedding"] for datapoint in dataset.data]
50
+ y = [roles.index(datapoint["label"]) for datapoint in dataset.data]
51
+
52
+
53
+ # FYI: Objective functions can take additional arguments
54
+ # (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
55
+ def objective(trial):
56
+
57
+ classifier_name = trial.suggest_categorical("classifier", ["voting"])
58
+ if classifier_name == "voting":
59
+ svc_c = trial.suggest_float("svc_c", 1e-3, 1e3, log=True)
60
+ svc_kernel = trial.suggest_categorical("kernel", ['rbf'])
61
+ classifier_obj = VotingClassifier(estimators=[
62
+ ('Logistic Regression', LogisticRegression()),
63
+ ('Neural Network', MLPClassifier(max_iter=500)),
64
+ ('Support Vector Machine', SVC(C=svc_c, kernel=svc_kernel))
65
+ ], voting='hard')
66
+
67
+ f1_scorer = make_scorer(f1_score, average = "weighted")
68
+ stratified_kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
69
+ cv_scores = cross_val_score(classifier_obj, X, y, cv=stratified_kfold, scoring=f1_scorer)
70
+ return cv_scores.mean()
71
+
72
+
73
+ study = optuna.create_study(direction="maximize")
74
+ study.optimize(objective, n_trials=20)
75
+ print(f"{arg} : {study.best_trial.values[0]}")
76
+
77
+ best_clf = VotingClassifier(estimators=[
78
+ ('Logistic Regression', LogisticRegression()),
79
+ ('Neural Network', MLPClassifier(max_iter=500)),
80
+ ('Support Vector Machine', SVC(C=study.best_trial.params["svc_c"], kernel=study.best_trial.params["kernel"]))
81
+ ], voting='hard')
82
+
83
+ best_clf.fit(X, y)
84
+ dump(best_clf, f'{arg}.joblib')
85
+
86
+ """
87
+ Function: get_arg_roles(event_args, doc)
88
+ Description: This function assigns argument roles to a list of event arguments within a document.
89
+ Inputs:
90
+ - event_args: A list of event argument dictionaries, each containing information about an argument.
91
+ - doc: A spaCy document representing the analyzed text.
92
+ Output:
93
+ - The input 'event_args' list with updated 'role' values assigned to each argument.
94
+ """
95
+ def get_arg_roles(event_args, doc):
96
+ for arg in event_args:
97
+ if len(arg_2_role[arg["subtype"]]) > 1:
98
+ sent = next(filter(lambda x : arg["startOffset"] >= x.start_char and arg["endOffset"] <= x.end_char, doc.sents))
99
+
100
+ sent_embed = embed_model.encode(sent.text)
101
+ arg_embed = embed_model.encode(arg["text"])
102
+ embed = np.concatenate((sent_embed, arg_embed))
103
+
104
+ arg_clf = classifiers[arg["subtype"]]
105
+ role_id = arg_clf.predict(embed.reshape(1, -1))
106
+ role = arg_2_role[arg["subtype"]][role_id[0]]
107
+
108
+ arg["role"] = role
109
+ else:
110
+ arg["role"] = arg_2_role[arg["subtype"]][0]
111
+ return event_args
112
+
113
+
event_nugget_predict.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from annotated_text import annotated_text
3
+ import torch
4
+ from torch import nn
5
+ from torch.utils.data import DataLoader
6
+ from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS as NuggetModel
7
+ from cybersecurity_knowledge_graph.nugget_model_utils import tokenize_and_align_labels_with_pos_ner_dep, find_nearest_nugget_features, find_dep_depth
8
+ from cybersecurity_knowledge_graph.utils import get_idxs_from_text, event_nugget_list
9
+ import spacy
10
+ from transformers import AutoTokenizer
11
+ from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
12
+ import os
13
+
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
16
+
17
+ def find_dep_depth(token):
18
+ depth = 0
19
+ current_token = token
20
+ while current_token.head != current_token:
21
+ depth += 1
22
+ current_token = current_token.head
23
+ return min(depth, 16)
24
+
25
+
26
+ nlp = spacy.load('en_core_web_sm')
27
+
28
+ pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
29
+ ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
30
+ dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
31
+
32
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
33
+
34
+ model_checkpoint = "ehsanaghaei/SecureBERT"
35
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
36
+
37
+ model_nugget = NuggetModel(num_classes = 11)
38
+ model_nugget.load_state_dict(torch.load("cybersecurity_knowledge_graph/nugget_model_state_dict.pth", map_location=device))
39
+ model_nugget.eval()
40
+
41
+ """
42
+ Function: create_dataloader(text_input)
43
+ Description: This function prepares a DataLoader for processing text input, including tokenization and alignment of labels.
44
+ Inputs:
45
+ - text_input: The input text to be processed.
46
+ Output:
47
+ - dataloader: A DataLoader for the tokenized and batched text data.
48
+ - tokenized_dataset_ner: The tokenized dataset used for training.
49
+ """
50
+ def create_dataloader(text_input):
51
+
52
+ doc = nlp(text_input)
53
+
54
+ content_as_words_emdash = [tok.text for tok in doc]
55
+ content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
56
+ content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
57
+
58
+ data = []
59
+
60
+ words = []
61
+
62
+ pos_spacy = [tok.pos_ for tok in doc]
63
+ ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
64
+ dep_spacy = [tok.dep_ for tok in doc]
65
+ depth_spacy = [find_dep_depth(tok) for tok in doc]
66
+
67
+ for content_dict in content_idx_dict:
68
+ start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
69
+ words.append(content_dict["word"])
70
+
71
+
72
+ content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
73
+ if content_token_len > tokenizer.model_max_length:
74
+ no_split = (content_token_len // tokenizer.model_max_length) + 2
75
+ split_len = (len(words) // no_split) + 1
76
+
77
+ last_id = 0
78
+ threshold = split_len
79
+
80
+ for id, token in enumerate(words):
81
+ if token == "." and id > threshold:
82
+ data.append(
83
+ {
84
+ "tokens" : words[last_id : id + 1],
85
+ "pos_spacy" : pos_spacy[last_id : id + 1],
86
+ "ner_spacy" : ner_spacy[last_id : id + 1],
87
+ "dep_spacy" : dep_spacy[last_id : id + 1],
88
+ "depth_spacy" : depth_spacy[last_id : id + 1],
89
+ }
90
+ )
91
+ last_id = id + 1
92
+ threshold += split_len
93
+ data.append({"tokens" : words[last_id : ],
94
+ "pos_spacy" : pos_spacy[last_id : ],
95
+ "ner_spacy" : ner_spacy[last_id : ],
96
+ "dep_spacy" : dep_spacy[last_id : ],
97
+ "depth_spacy" : depth_spacy[last_id : ]})
98
+ else:
99
+ data.append(
100
+ {
101
+ "tokens" : words,
102
+ "pos_spacy" : pos_spacy,
103
+ "ner_spacy" : ner_spacy,
104
+ "dep_spacy" : dep_spacy,
105
+ "depth_spacy" : depth_spacy
106
+ }
107
+ )
108
+
109
+
110
+ ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
111
+ 'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
112
+ 'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
113
+ 'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
114
+ 'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None)
115
+ })
116
+
117
+ dataset = Dataset.from_list(data, features=ner_features)
118
+ tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_dep, fn_kwargs={'tokenizer' : tokenizer}, batched=True, load_from_cache_file=False)
119
+ tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
120
+
121
+ tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
122
+
123
+ batch_size = 4 # Number of input texts
124
+ dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
125
+ # TODO : context_idx_dict should be used to index the words
126
+ return dataloader, tokenized_dataset_ner
127
+
128
+ """
129
+ Function: predict(dataloader)
130
+ Description: This function performs inference on a given DataLoader using a trained model and returns the predicted labels.
131
+ Inputs:
132
+ - dataloader: A DataLoader containing input data for prediction.
133
+ Output:
134
+ - predicted_label: A tensor containing the predicted labels for the input data.
135
+ """
136
+ def predict(dataloader):
137
+ predicted_label = []
138
+ for batch in dataloader:
139
+ with torch.no_grad():
140
+ logits = model_nugget(**batch)
141
+ batch_predicted_label = logits.argmax(-1)
142
+ predicted_label.append(batch_predicted_label)
143
+ return torch.cat(predicted_label, dim=-1)
144
+
145
+ """
146
+ Function: show_annotations(text_input)
147
+ Description: This function displays annotated event nuggets in the provided input text using the Streamlit library.
148
+ Inputs:
149
+ - text_input: The input text containing event nuggets to be annotated and displayed.
150
+ Output:
151
+ - An interactive display of annotated event nuggets within the input text.
152
+ """
153
+ def show_annotations(text_input):
154
+ st.title("Event Nuggets")
155
+
156
+ dataloader, tokenized_dataset_ner = create_dataloader(text_input)
157
+ predicted_label = predict(dataloader)
158
+
159
+ for idx, labels in enumerate(predicted_label):
160
+ token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
161
+
162
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
163
+ tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
164
+
165
+ text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
166
+ idxs = get_idxs_from_text(text, tokens)
167
+
168
+ labels = labels[token_mask]
169
+
170
+ annotated_text_list = []
171
+ last_label = ""
172
+ cumulative_tokens = ""
173
+ last_id = 0
174
+
175
+ for idx, label in zip(idxs, labels):
176
+ to_label = event_nugget_list[label]
177
+ label_short = to_label.split("-")[1] if "-" in to_label else to_label
178
+ if last_label == label_short:
179
+ cumulative_tokens += text[last_id : idx["end_idx"]]
180
+ last_id = idx["end_idx"]
181
+ else:
182
+ if last_label != "":
183
+ if last_label == "O":
184
+ annotated_text_list.append(cumulative_tokens)
185
+ else:
186
+ annotated_text_list.append((cumulative_tokens, last_label))
187
+ last_label = label_short
188
+ cumulative_tokens = idx["word"]
189
+ last_id = idx["end_idx"]
190
+ if last_label == "O":
191
+ annotated_text_list.append(cumulative_tokens)
192
+ else:
193
+ annotated_text_list.append((cumulative_tokens, last_label))
194
+ annotated_text(annotated_text_list)
195
+
196
+ """
197
+ Function: get_event_nuggets(text_input)
198
+ Description: This function extracts predicted event nuggets (event entities) from the provided input text.
199
+ Inputs:
200
+ - text_input: The input text containing event nuggets to be extracted.
201
+ Output:
202
+ - predicted_event_nuggets: A list of dictionaries, each representing an extracted event nugget with start and end offsets,
203
+ subtype, and text content.
204
+ """
205
+ def get_event_nuggets(text_input):
206
+ dataloader, tokenized_dataset_ner = create_dataloader(text_input)
207
+ predicted_label = predict(dataloader)
208
+
209
+ predicted_event_nuggets = []
210
+ text_length = 0
211
+ for idx, labels in enumerate(predicted_label):
212
+ token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
213
+
214
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
215
+ tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
216
+
217
+ text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
218
+ idxs = get_idxs_from_text(text_input[text_length : ], tokens)
219
+
220
+ labels = labels[token_mask]
221
+
222
+ start_idx = 0
223
+ end_idx = 0
224
+ last_label = ""
225
+
226
+ for idx, label in zip(idxs, labels):
227
+ to_label = event_nugget_list[label]
228
+ label_short = to_label.split("-")[1] if "-" in to_label else to_label
229
+
230
+ if label_short == last_label:
231
+ end_idx = idx["end_idx"]
232
+ else:
233
+ if text_input[start_idx : end_idx] != "" and last_label != "O":
234
+ predicted_event_nuggets.append(
235
+ {
236
+ "startOffset" : text_length + start_idx,
237
+ "endOffset" : text_length + end_idx,
238
+ "subtype" : last_label,
239
+ "text" : text_input[text_length + start_idx : text_length + end_idx]
240
+ }
241
+ )
242
+ start_idx = idx["start_idx"]
243
+ end_idx = idx["start_idx"] + len(idx["word"])
244
+ last_label = label_short
245
+
246
+ text_length += idx["end_idx"]
247
+ return predicted_event_nuggets
248
+
249
+
250
+
event_realis_predict.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import spacy
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+ from transformers import AutoTokenizer
6
+ from cybersecurity_knowledge_graph.utils import get_idxs_from_text
7
+ import streamlit as st
8
+ from annotated_text import annotated_text
9
+ from cybersecurity_knowledge_graph.nugget_model_utils import CustomRobertaWithPOS
10
+ from cybersecurity_knowledge_graph.event_nugget_predict import get_event_nuggets
11
+ from cybersecurity_knowledge_graph.realis_model_utils import get_entity_for_realis_from_idx, tokenize_and_align_labels_with_pos_ner_realis
12
+ from datasets import load_dataset, Features, ClassLabel, Value, Sequence, Dataset
13
+
14
+ event_nugget_list = ['B-Phishing',
15
+ 'I-Phishing',
16
+ 'O',
17
+ 'B-DiscoverVulnerability',
18
+ 'B-Ransom',
19
+ 'I-Ransom',
20
+ 'B-Databreach',
21
+ 'I-DiscoverVulnerability',
22
+ 'B-PatchVulnerability',
23
+ 'I-PatchVulnerability',
24
+ 'I-Databreach']
25
+
26
+ realis_list = ["O", "Generic", "Other", "Actual"]
27
+
28
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
29
+
30
+
31
+
32
+ def find_dep_depth(token):
33
+ depth = 0
34
+ current_token = token
35
+ while current_token.head != current_token:
36
+ depth += 1
37
+ current_token = current_token.head
38
+ return min(depth, 16)
39
+
40
+
41
+ nlp = spacy.load('en_core_web_sm')
42
+
43
+ pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
44
+ ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
45
+ dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
46
+
47
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
48
+
49
+ model_checkpoint = "ehsanaghaei/SecureBERT"
50
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
51
+
52
+ from cybersecurity_knowledge_graph.realis_model_utils import CustomRobertaWithPOS as RealisModel
53
+ model_realis = RealisModel(num_classes_realis=4)
54
+ model_realis.load_state_dict(torch.load("cybersecurity_knowledge_graph/realis_model_state_dict.pth", map_location=device))
55
+ model_realis.eval()
56
+
57
+ """
58
+ Function: create_dataloader(text_input)
59
+ Description: This function prepares a DataLoader for processing text input, including tokenization and alignment of labels.
60
+ Inputs:
61
+ - text_input: The input text to be processed.
62
+ Output:
63
+ - dataloader: A DataLoader for the tokenized and batched text data.
64
+ - tokenized_dataset_ner: The tokenized dataset used for training.
65
+ """
66
+ def create_dataloader(text_input):
67
+
68
+ event_nuggets = get_event_nuggets(text_input)
69
+ doc = nlp(text_input)
70
+
71
+ content_as_words_emdash = [tok.text for tok in doc]
72
+ content_as_words_emdash = [word.replace("``", '"').replace("''", '"').replace("$", "") for word in content_as_words_emdash]
73
+ content_idx_dict = get_idxs_from_text(text_input, content_as_words_emdash)
74
+
75
+ data = []
76
+
77
+ words = []
78
+ nugget_ner_tags = []
79
+
80
+ pos_spacy = [tok.pos_ for tok in doc]
81
+ ner_spacy = [ent.ent_iob_ + "-" + ent.ent_type_ if ent.ent_iob_ != "O" else ent.ent_iob_ for ent in doc]
82
+ dep_spacy = [tok.dep_ for tok in doc]
83
+ depth_spacy = [find_dep_depth(tok) for tok in doc]
84
+
85
+ for content_dict in content_idx_dict:
86
+ start_idx, end_idx = content_dict["start_idx"], content_dict["end_idx"]
87
+ entity = get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets)
88
+ words.append(content_dict["word"])
89
+ nugget_ner_tags.append(entity)
90
+
91
+
92
+ content_token_len = len(tokenizer(words, truncation=False, is_split_into_words=True)["input_ids"])
93
+ if content_token_len > tokenizer.model_max_length:
94
+ no_split = (content_token_len // tokenizer.model_max_length) + 2
95
+ split_len = (len(words) // no_split) + 1
96
+
97
+ last_id = 0
98
+ threshold = split_len
99
+
100
+ for id, token in enumerate(words):
101
+ if token == "." and id > threshold:
102
+ data.append(
103
+ {
104
+ "tokens" : words[last_id : id + 1],
105
+ "ner_tags" : nugget_ner_tags[last_id : id + 1],
106
+ "pos_spacy" : pos_spacy[last_id : id + 1],
107
+ "ner_spacy" : ner_spacy[last_id : id + 1],
108
+ "dep_spacy" : dep_spacy[last_id : id + 1],
109
+ "depth_spacy" : depth_spacy[last_id : id + 1],
110
+ }
111
+ )
112
+ last_id = id + 1
113
+ threshold += split_len
114
+ data.append({"tokens" : words[last_id : ],
115
+ "ner_tags" : nugget_ner_tags[last_id : ],
116
+ "pos_spacy" : pos_spacy[last_id : ],
117
+ "ner_spacy" : ner_spacy[last_id : ],
118
+ "dep_spacy" : dep_spacy[last_id : ],
119
+ "depth_spacy" : depth_spacy[last_id : ]})
120
+ else:
121
+ data.append(
122
+ {
123
+ "tokens" : words,
124
+ "ner_tags" : nugget_ner_tags,
125
+ "pos_spacy" : pos_spacy,
126
+ "ner_spacy" : ner_spacy,
127
+ "dep_spacy" : dep_spacy,
128
+ "depth_spacy" : depth_spacy
129
+ }
130
+ )
131
+
132
+
133
+ ner_features = Features({'tokens' : Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
134
+ 'ner_tags' : Sequence(feature=ClassLabel(num_classes=len(event_nugget_list), names=event_nugget_list, names_file=None, id=None), length=-1, id=None),
135
+ 'pos_spacy' : Sequence(feature=ClassLabel(num_classes=len(pos_spacy_tag_list), names=pos_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
136
+ 'ner_spacy' : Sequence(feature=ClassLabel(num_classes=len(ner_spacy_tag_list), names=ner_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
137
+ 'dep_spacy' : Sequence(feature=ClassLabel(num_classes=len(dep_spacy_tag_list), names=dep_spacy_tag_list, names_file=None, id=None), length=-1, id=None),
138
+ 'depth_spacy' : Sequence(feature=ClassLabel(num_classes=17, names= list(range(17)), names_file=None, id=None), length=-1, id=None)
139
+ })
140
+
141
+ dataset = Dataset.from_list(data, features=ner_features)
142
+ tokenized_dataset_ner = dataset.map(tokenize_and_align_labels_with_pos_ner_realis, fn_kwargs={'tokenizer' : tokenizer, 'ner_names' : event_nugget_list}, batched=True, load_from_cache_file=False)
143
+ tokenized_dataset_ner = tokenized_dataset_ner.with_format("torch")
144
+
145
+ tokenized_dataset_ner = tokenized_dataset_ner.remove_columns("tokens")
146
+
147
+ batch_size = 4 # Number of input texts
148
+ dataloader = DataLoader(tokenized_dataset_ner, batch_size=batch_size)
149
+ return dataloader, tokenized_dataset_ner
150
+
151
+ """
152
+ Function: predict(dataloader)
153
+ Description: This function performs inference on a given DataLoader using a trained model and returns the predicted labels.
154
+ Inputs:
155
+ - dataloader: A DataLoader containing input data for prediction.
156
+ Output:
157
+ - predicted_label: A tensor containing the predicted labels for the input data.
158
+ """
159
+ def predict(dataloader):
160
+ predicted_label = []
161
+ for batch in dataloader:
162
+ with torch.no_grad():
163
+ logits = model_realis(**batch)
164
+
165
+ batch_predicted_label = logits.argmax(-1)
166
+ predicted_label.append(batch_predicted_label)
167
+ return torch.cat(predicted_label, dim=-1)
168
+
169
+ """
170
+ Function: show_annotations(text_input)
171
+ Description: This function displays annotated event nuggets in the provided input text using the Streamlit library.
172
+ Inputs:
173
+ - text_input: The input text containing event nuggets to be annotated and displayed.
174
+ Output:
175
+ - An interactive display of annotated event nuggets within the input text.
176
+ """
177
+ def show_annotations(text_input):
178
+ st.title("Event Realis")
179
+
180
+ dataloader, tokenized_dataset_ner = create_dataloader(text_input)
181
+ predicted_label = predict(dataloader)
182
+
183
+ for idx, labels in enumerate(predicted_label):
184
+ token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
185
+
186
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
187
+ tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
188
+
189
+ text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
190
+ idxs = get_idxs_from_text(text, tokens)
191
+
192
+ labels = labels[token_mask]
193
+
194
+ annotated_text_list = []
195
+ last_label = ""
196
+ cumulative_tokens = ""
197
+ last_id = 0
198
+
199
+ for idx, label in zip(idxs, labels):
200
+ to_label = realis_list[label]
201
+ label_short = to_label.split("-")[1] if "-" in to_label else to_label
202
+ if last_label == label_short:
203
+ cumulative_tokens += text[last_id : idx["end_idx"]]
204
+ last_id = idx["end_idx"]
205
+ else:
206
+ if last_label != "":
207
+ if last_label == "O":
208
+ annotated_text_list.append(cumulative_tokens)
209
+ else:
210
+ annotated_text_list.append((cumulative_tokens, last_label))
211
+ last_label = label_short
212
+ cumulative_tokens = idx["word"]
213
+ last_id = idx["end_idx"]
214
+ if last_label == "O":
215
+ annotated_text_list.append(cumulative_tokens)
216
+ else:
217
+ annotated_text_list.append((cumulative_tokens, last_label))
218
+ annotated_text(annotated_text_list)
219
+
220
+ """
221
+ Function: get_event_realis(text_input)
222
+ Description: This function extracts predicted event realis (event modality) from the provided input text.
223
+ Inputs:
224
+ - text_input: The input text containing event realis to be extracted.
225
+ Output:
226
+ - predicted_event_realis: A list of dictionaries, each representing an extracted event realis with start and end offsets,
227
+ realis type, and text content.
228
+ """
229
+ def get_event_realis(text_input):
230
+ dataloader, tokenized_dataset_ner = create_dataloader(text_input)
231
+ predicted_label = predict(dataloader)
232
+
233
+ predicted_event_realis = []
234
+ text_length = 0
235
+ for idx, labels in enumerate(predicted_label):
236
+ token_mask = [token > 2 for token in tokenized_dataset_ner[idx]["input_ids"]]
237
+
238
+ tokens = tokenizer.convert_ids_to_tokens(tokenized_dataset_ner[idx]["input_ids"][token_mask], skip_special_tokens=True)
239
+ tokens = [token.replace("Ġ", "").replace("Ċ", "").replace("âĢĻ", "'") for token in tokens]
240
+
241
+ text = tokenizer.decode(tokenized_dataset_ner[idx]["input_ids"][token_mask])
242
+ idxs = get_idxs_from_text(text_input[text_length : ], tokens)
243
+
244
+ labels = labels[token_mask]
245
+
246
+ start_idx = 0
247
+ end_idx = 0
248
+ last_label = ""
249
+
250
+ for idx, label in zip(idxs, labels):
251
+ to_label = realis_list[label]
252
+ label_split = to_label
253
+
254
+ if label_split == last_label:
255
+ end_idx = idx["end_idx"]
256
+ else:
257
+ if text_input[start_idx : end_idx] != "" and last_label != "O":
258
+ predicted_event_realis.append(
259
+ {
260
+ "startOffset" : text_length + start_idx,
261
+ "endOffset" : text_length + end_idx,
262
+ "realis" : last_label,
263
+ "text" : text_input[text_length + start_idx : text_length + end_idx]
264
+ }
265
+ )
266
+ start_idx = idx["start_idx"]
267
+ end_idx = idx["start_idx"] + len(idx["word"])
268
+ last_label = label_split
269
+ text_length += idx["end_idx"]
270
+ return predicted_event_realis
model_59.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09bc24b422adbe6c4c6ca1333a3a8c33146e6152e00a7ad6376cab616b51e53f
3
+ size 498858353
model_64_pos_ner.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76125c5bbce2c32e536fe74d24dc51fb1fce3ba076104b459ee290102ce4bd5d
3
+ size 498746934
model_66.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46531e8ccf92661a025b15c829be791f72416d1b458ae1aa82cc66e069193bf5
3
+ size 498751092
model_97.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6147a98aa2baaa545903103e9e2f0e55fc249ec638cfe27e273ffdd247479c4
3
+ size 498729523
nugget_model_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d04c7ccd654b3af96c1c8e0f391a20d79ae1b5970d5419680f379c6a09e78bf
3
+ size 498703483
nugget_model_utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import spacy
3
+ import en_core_web_sm
4
+ from torch import nn
5
+ import math
6
+
7
+
8
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
9
+
10
+ from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
11
+ from transformers import AutoTokenizer
12
+
13
+ model_checkpoint = "ehsanaghaei/SecureBERT"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
16
+ roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
17
+
18
+ nlp = en_core_web_sm.load()
19
+ pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
20
+ ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
21
+
22
+
23
+ class CustomRobertaWithPOS(nn.Module):
24
+ def __init__(self, num_classes):
25
+ super(CustomRobertaWithPOS, self).__init__()
26
+ self.num_classes = num_classes
27
+ self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
28
+ self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 16)
29
+ self.roberta = roberta_model
30
+ self.dropout1 = nn.Dropout(0.2)
31
+ self.fc1 = nn.Linear(self.roberta.config.hidden_size, num_classes)
32
+
33
+ def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy):
34
+ outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
35
+ last_hidden_output = outputs.last_hidden_state
36
+
37
+ pos_mask = pos_spacy != -100
38
+
39
+ pos_one_hot = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], len(pos_spacy_tag_list)), dtype=torch.long)
40
+ pos_one_hot[pos_mask, pos_spacy[pos_mask]] = 1
41
+ pos_one_hot = pos_one_hot.to(device)
42
+
43
+ ner_mask = ner_spacy != -100
44
+
45
+ ner_one_hot = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], len(ner_spacy_tag_list)), dtype=torch.long)
46
+ ner_one_hot[ner_mask, ner_spacy[ner_mask]] = 1
47
+ ner_one_hot = ner_one_hot.to(device)
48
+
49
+ features_concat = last_hidden_output
50
+ features_concat = self.dropout1(features_concat)
51
+
52
+ logits = self.fc1(features_concat)
53
+
54
+ return logits
55
+
56
+
57
+ def tokenize_and_align_labels_with_pos_ner_dep(examples, tokenizer, label_all_tokens = True):
58
+ tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
59
+ #tokenized_inputs.pop('input_ids')
60
+ ner_spacy = []
61
+ pos_spacy = []
62
+ dep_spacy = []
63
+ depth_spacy = []
64
+
65
+ for i, (pos, ner, dep, depth) in enumerate(zip(examples["pos_spacy"],
66
+ examples["ner_spacy"],
67
+ examples["dep_spacy"],
68
+ examples["depth_spacy"])):
69
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
70
+ previous_word_idx = None
71
+ ner_spacy_ids = []
72
+ pos_spacy_ids = []
73
+ dep_spacy_ids = []
74
+ depth_spacy_ids = []
75
+
76
+ for word_idx in word_ids:
77
+ # Special tokens have a word id that is None. We set the label to -100 so they are automatically
78
+ # ignored in the loss function.
79
+ if word_idx is None:
80
+ ner_spacy_ids.append(-100)
81
+ pos_spacy_ids.append(-100)
82
+ dep_spacy_ids.append(-100)
83
+ depth_spacy_ids.append(-100)
84
+ # We set the label for the first token of each word.
85
+ elif word_idx != previous_word_idx:
86
+ ner_spacy_ids.append(ner[word_idx])
87
+ pos_spacy_ids.append(pos[word_idx])
88
+ dep_spacy_ids.append(dep[word_idx])
89
+ depth_spacy_ids.append(depth[word_idx])
90
+ # For the other tokens in a word, we set the label to either the current label or -100, depending on
91
+ # the label_all_tokens flag.
92
+ else:
93
+ ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
94
+ pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
95
+ dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
96
+ depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
97
+ previous_word_idx = word_idx
98
+
99
+ ner_spacy.append(ner_spacy_ids)
100
+ pos_spacy.append(pos_spacy_ids)
101
+ dep_spacy.append(dep_spacy_ids)
102
+ depth_spacy.append(depth_spacy_ids)
103
+
104
+ tokenized_inputs["pos_spacy"] = pos_spacy
105
+ tokenized_inputs["ner_spacy"] = ner_spacy
106
+ tokenized_inputs["dep_spacy"] = dep_spacy
107
+ tokenized_inputs["depth_spacy"] = depth_spacy
108
+
109
+ return tokenized_inputs
110
+
111
+
112
+ def find_nearest_nugget_features(doc, start_idx, end_idx, event_nuggets):
113
+ nearest_subtype = None
114
+ nearest_dist = math.inf
115
+ relative_pos = None
116
+
117
+ mid_idx = (end_idx + start_idx) / 2
118
+ for nugget in event_nuggets:
119
+ mid_nugget_idx = (nugget["nugget"]["startOffset"] + nugget["nugget"]["endOffset"]) / 2
120
+ dist = abs(mid_nugget_idx - mid_idx)
121
+
122
+ if dist < nearest_dist:
123
+ nearest_dist = dist
124
+ nearest_subtype = nugget["subtype"]
125
+ for sent in doc.sents:
126
+ if between_idxs(mid_idx, sent.start_char, sent.end_char) and between_idxs(mid_nugget_idx, sent.start_char, sent.end_char):
127
+ if mid_idx < mid_nugget_idx:
128
+ relative_pos = "before-same-sentence"
129
+ else:
130
+ relative_pos = "after-same-sentence"
131
+ break
132
+ elif between_idxs(mid_nugget_idx, sent.start_char, sent.end_char) and mid_idx > mid_nugget_idx:
133
+ relative_pos = "after-differ-sentence"
134
+ break
135
+ elif between_idxs(mid_idx, sent.start_char, sent.end_char) and mid_idx < mid_nugget_idx:
136
+ relative_pos = "before-differ-sentence"
137
+ break
138
+
139
+ nearest_dist = int(min(10, nearest_dist // 20))
140
+ return nearest_subtype, nearest_dist, relative_pos
141
+
142
+ def find_dep_depth(token):
143
+ depth = 0
144
+ current_token = token
145
+ while current_token.head != current_token:
146
+ depth += 1
147
+ current_token = current_token.head
148
+ return min(depth, 16)
149
+
150
+ def between_idxs(idx, start_idx, end_idx):
151
+ return idx >= start_idx and idx <= end_idx
realis_model_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ad63eeee95888dc6f22e94e0a8425a99912f7d727cd255881e8630218a3b7f0
3
+ size 498684837
realis_model_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import en_core_web_sm
4
+ from transformers import AutoModel, TrainingArguments, Trainer, RobertaTokenizer, RobertaModel
5
+ from transformers import AutoTokenizer
6
+
7
+ model_checkpoint = "ehsanaghaei/SecureBERT"
8
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
11
+ roberta_model = RobertaModel.from_pretrained(model_checkpoint).to(device)
12
+
13
+ event_nugget_list = ['B-Phishing',
14
+ 'I-Phishing',
15
+ 'O',
16
+ 'B-DiscoverVulnerability',
17
+ 'B-Ransom',
18
+ 'I-Ransom',
19
+ 'B-Databreach',
20
+ 'I-DiscoverVulnerability',
21
+ 'B-PatchVulnerability',
22
+ 'I-PatchVulnerability',
23
+ 'I-Databreach']
24
+
25
+ nlp = en_core_web_sm.load()
26
+ pos_spacy_tag_list = ["ADJ","ADP","ADV","AUX","CCONJ","DET","INTJ","NOUN","NUM","PART","PRON","PROPN","PUNCT","SCONJ","SYM","VERB","SPACE","X"]
27
+ ner_spacy_tag_list = [bio + entity for entity in list(nlp.get_pipe('ner').labels) for bio in ["B-", "I-"]] + ["O"]
28
+ dep_spacy_tag_list = list(nlp.get_pipe("parser").labels)
29
+
30
+ class CustomRobertaWithPOS(nn.Module):
31
+ def __init__(self, num_classes_realis):
32
+ super(CustomRobertaWithPOS, self).__init__()
33
+ self.num_classes_realis = num_classes_realis
34
+ self.pos_embed = nn.Embedding(len(pos_spacy_tag_list), 16)
35
+ self.ner_embed = nn.Embedding(len(ner_spacy_tag_list), 8)
36
+ self.dep_embed = nn.Embedding(len(dep_spacy_tag_list), 8)
37
+ self.depth_embed = nn.Embedding(17, 8)
38
+ self.nugget_embed = nn.Embedding(len(event_nugget_list), 8)
39
+ self.roberta = roberta_model
40
+ self.dropout1 = nn.Dropout(0.2)
41
+ self.fc1 = nn.Linear(self.roberta.config.hidden_size + 48, self.num_classes_realis)
42
+
43
+ def forward(self, input_ids, attention_mask, pos_spacy, ner_spacy, dep_spacy, depth_spacy, ner_tags):
44
+ outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
45
+ last_hidden_output = outputs.last_hidden_state
46
+
47
+ pos_mask = pos_spacy != -100
48
+ pos_embed_masked = self.pos_embed(pos_spacy[pos_mask])
49
+ pos_embed = torch.zeros((pos_spacy.shape[0], pos_spacy.shape[1], 16), dtype=torch.float).to(device)
50
+ pos_embed[pos_mask] = pos_embed_masked
51
+
52
+ ner_mask = ner_spacy != -100
53
+ ner_embed_masked = self.ner_embed(ner_spacy[ner_mask])
54
+ ner_embed = torch.zeros((ner_spacy.shape[0], ner_spacy.shape[1], 8), dtype=torch.float).to(device)
55
+ ner_embed[ner_mask] = ner_embed_masked
56
+
57
+ dep_mask = dep_spacy != -100
58
+ dep_embed_masked = self.dep_embed(dep_spacy[dep_mask])
59
+ dep_embed = torch.zeros((dep_spacy.shape[0], dep_spacy.shape[1], 8), dtype=torch.float).to(device)
60
+ dep_embed[dep_mask] = dep_embed_masked
61
+
62
+ depth_mask = depth_spacy != -100
63
+ depth_embed_masked = self.depth_embed(depth_spacy[depth_mask])
64
+ depth_embed = torch.zeros((depth_spacy.shape[0], depth_spacy.shape[1], 8), dtype=torch.float).to(device)
65
+ depth_embed[dep_mask] = depth_embed_masked
66
+
67
+ nugget_mask = ner_tags != -100
68
+ nugget_embed_masked = self.nugget_embed(ner_tags[nugget_mask])
69
+ nugget_embed = torch.zeros((ner_tags.shape[0], ner_tags.shape[1], 8), dtype=torch.float).to(device)
70
+ nugget_embed[dep_mask] = nugget_embed_masked
71
+
72
+ features_concat = torch.cat((last_hidden_output, pos_embed, ner_embed, dep_embed, depth_embed, nugget_embed), 2).to(device)
73
+ features_concat = self.dropout1(features_concat)
74
+ features_concat = self.dropout1(features_concat)
75
+
76
+ logits = self.fc1(features_concat)
77
+
78
+ return logits
79
+
80
+
81
+ def get_entity_for_realis_from_idx(start_idx, end_idx, event_nuggets):
82
+ event_nuggets_idxs = [(nugget["startOffset"], nugget["endOffset"]) for nugget in event_nuggets]
83
+ for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
84
+ if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
85
+ return "B-" + event_nuggets[idx]["subtype"]
86
+ elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
87
+ return "I-" + event_nuggets[idx]["subtype"]
88
+ return "O"
89
+
90
+ def tokenize_and_align_labels_with_pos_ner_realis(examples, tokenizer, ner_names, label_all_tokens = True):
91
+ tokenized_inputs = tokenizer(examples["tokens"], padding='max_length', truncation=True, is_split_into_words=True)
92
+ #tokenized_inputs.pop('input_ids')
93
+ labels = []
94
+ nuggets = []
95
+ ner_spacy = []
96
+ pos_spacy = []
97
+ dep_spacy = []
98
+ depth_spacy = []
99
+
100
+ for i, (nugget, pos, ner, dep, depth) in enumerate(zip(examples["ner_tags"], examples["pos_spacy"], examples["ner_spacy"], examples["dep_spacy"], examples["depth_spacy"])):
101
+ word_ids = tokenized_inputs.word_ids(batch_index=i)
102
+ previous_word_idx = None
103
+ nugget_ids = []
104
+ ner_spacy_ids = []
105
+ pos_spacy_ids = []
106
+ dep_spacy_ids = []
107
+ depth_spacy_ids = []
108
+
109
+ for word_idx in word_ids:
110
+ # Special tokens have a word id that is None. We set the label to -100 so they are automatically
111
+ # ignored in the loss function.
112
+ if word_idx is None:
113
+ nugget_ids.append(-100)
114
+ ner_spacy_ids.append(-100)
115
+ pos_spacy_ids.append(-100)
116
+ dep_spacy_ids.append(-100)
117
+ depth_spacy_ids.append(-100)
118
+ # We set the label for the first token of each word.
119
+ elif word_idx != previous_word_idx:
120
+ nugget_ids.append(nugget[word_idx])
121
+ ner_spacy_ids.append(ner[word_idx])
122
+ pos_spacy_ids.append(pos[word_idx])
123
+ dep_spacy_ids.append(dep[word_idx])
124
+ depth_spacy_ids.append(depth[word_idx])
125
+ # For the other tokens in a word, we set the label to either the current label or -100, depending on
126
+ # the label_all_tokens flag.
127
+ else:
128
+ nugget_ids.append(nugget[word_idx] if label_all_tokens else -100)
129
+ ner_spacy_ids.append(ner[word_idx] if label_all_tokens else -100)
130
+ pos_spacy_ids.append(pos[word_idx] if label_all_tokens else -100)
131
+ dep_spacy_ids.append(dep[word_idx] if label_all_tokens else -100)
132
+ depth_spacy_ids.append(depth[word_idx] if label_all_tokens else -100)
133
+ previous_word_idx = word_idx
134
+
135
+ nuggets.append(nugget_ids)
136
+ ner_spacy.append(ner_spacy_ids)
137
+ pos_spacy.append(pos_spacy_ids)
138
+ dep_spacy.append(dep_spacy_ids)
139
+ depth_spacy.append(depth_spacy_ids)
140
+
141
+ tokenized_inputs["ner_tags"] = nuggets
142
+ tokenized_inputs["pos_spacy"] = pos_spacy
143
+ tokenized_inputs["ner_spacy"] = ner_spacy
144
+ tokenized_inputs["dep_spacy"] = dep_spacy
145
+ tokenized_inputs["depth_spacy"] = depth_spacy
146
+ return tokenized_inputs
utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ list_of_pos_tags = [
2
+ "ADJ",
3
+ "ADP",
4
+ "ADV",
5
+ "AUX",
6
+ "CCONJ",
7
+ "DET",
8
+ "INTJ",
9
+ "NOUN",
10
+ "NUM",
11
+ "PART",
12
+ "PRON",
13
+ "PROPN",
14
+ "PUNCT",
15
+ "SCONJ",
16
+ "SYM",
17
+ "VERB",
18
+ "X"
19
+ ]
20
+
21
+ realis_list = ["O",
22
+ "Generic",
23
+ "Other",
24
+ "Actual"
25
+ ]
26
+
27
+
28
+ event_args_list = ['O',
29
+ 'B-System',
30
+ 'I-System',
31
+ 'B-Organization',
32
+ 'B-Money',
33
+ 'I-Money',
34
+ 'B-Device',
35
+ 'B-Person',
36
+ 'I-Person',
37
+ 'B-Vulnerability',
38
+ 'I-Vulnerability',
39
+ 'B-Capabilities',
40
+ 'I-Capabilities',
41
+ 'I-Organization',
42
+ 'B-PaymentMethod',
43
+ 'I-PaymentMethod',
44
+ 'B-Data',
45
+ 'I-Data',
46
+ 'B-Number',
47
+ 'I-Number',
48
+ 'B-Malware',
49
+ 'I-Malware',
50
+ 'B-PII',
51
+ 'I-PII',
52
+ 'B-CVE',
53
+ 'I-CVE',
54
+ 'B-Purpose',
55
+ 'I-Purpose',
56
+ 'B-File',
57
+ 'I-File',
58
+ 'I-Device',
59
+ 'B-Time',
60
+ 'I-Time',
61
+ 'B-Software',
62
+ 'I-Software',
63
+ 'B-Patch',
64
+ 'I-Patch',
65
+ 'B-Version',
66
+ 'I-Version',
67
+ 'B-Website',
68
+ 'I-Website',
69
+ 'B-GPE',
70
+ 'I-GPE'
71
+ ]
72
+
73
+ event_nugget_list = ['O',
74
+ 'B-Ransom',
75
+ 'I-Ransom',
76
+ 'B-DiscoverVulnerability',
77
+ 'I-DiscoverVulnerability',
78
+ 'B-PatchVulnerability',
79
+ 'I-PatchVulnerability',
80
+ 'B-Databreach',
81
+ 'I-Databreach',
82
+ 'B-Phishing',
83
+ 'I-Phishing'
84
+ ]
85
+
86
+ arg_2_role = {
87
+ "File" : ['Tool', 'Trusted-Entity'],
88
+ "Person" : ['Victim', 'Attacker', 'Discoverer', 'Releaser', 'Trusted-Entity', 'Vulnerable_System_Owner'],
89
+ "Capabilities" : ['Attack-Pattern', 'Capabilities', 'Issues-Addressed'],
90
+ "Purpose" : ['Purpose'],
91
+ "Time" : ['Time'],
92
+ "PII" : ['Compromised-Data', 'Trusted-Entity'],
93
+ "Data" : ['Compromised-Data', 'Trusted-Entity'],
94
+ "Organization" : ['Victim', 'Releaser', 'Discoverer', 'Attacker', 'Vulnerable_System_Owner', 'Trusted-Entity'],
95
+ "Patch" : ['Patch'],
96
+ "Software" : ['Vulnerable_System', 'Victim', 'Trusted-Entity', 'Supported_Platform'],
97
+ "Vulnerability" : ['Vulnerability'],
98
+ "Version" : ['Patch-Number', 'Vulnerable_System_Version'],
99
+ "Device" : ['Vulnerable_System', 'Victim', 'Supported_Platform'],
100
+ "CVE" : ['CVE'],
101
+ "Number" : ['Number-of-Data', 'Number-of-Victim'],
102
+ "System" : ['Victim', 'Supported_Platform', 'Vulnerable_System', 'Trusted-Entity'],
103
+ "Malware" : ['Tool'],
104
+ "Money" : ['Price', 'Damage-Amount'],
105
+ "PaymentMethod" : ['Payment-Method'],
106
+ "GPE" : ['Place'],
107
+ "Website" : ['Trusted-Entity', 'Tool', 'Vulnerable_System', 'Victim', 'Supported_Platform'],
108
+ }
109
+
110
+ def get_content(data):
111
+ return data["content"]
112
+
113
+ def get_event_nugget(data):
114
+ return [
115
+ {"nugget" : event["nugget"], "type" : event["type"], "subtype" : event["subtype"], "realis" : event["realis"]}
116
+ for hopper in data["cyberevent"]["hopper"] for event in hopper["events"]
117
+ ]
118
+ def get_event_args(data):
119
+ events = [event for hopper in data["cyberevent"]["hopper"] for event in hopper["events"]]
120
+ args = []
121
+ for event in events:
122
+ if "argument" in event.keys():
123
+ args.extend(event["argument"])
124
+ return args
125
+
126
+ def get_idxs_from_text(text, text_tokenized):
127
+ rest_text = text
128
+ last_idx = 0
129
+ result_dict = []
130
+
131
+ for substring in text_tokenized:
132
+ index = rest_text.find(substring)
133
+ result_dict.append(
134
+ {
135
+ "word" : substring,
136
+ "start_idx" : last_idx + index,
137
+ "end_idx" : last_idx + index + len(substring)
138
+ }
139
+ )
140
+ rest_text = rest_text[index + len(substring) : ]
141
+ last_idx += index + len(substring)
142
+ return result_dict
143
+
144
+ def get_entity_from_idx(start_idx, end_idx, event_nuggets):
145
+ event_nuggets_idxs = [(nugget["nugget"]["startOffset"], nugget["nugget"]["endOffset"]) for nugget in event_nuggets]
146
+ for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
147
+ if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
148
+ return "B-" + event_nuggets[idx]["subtype"]
149
+ elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
150
+ return "I-" + event_nuggets[idx]["subtype"]
151
+ return "O"
152
+
153
+ def get_entity_and_realis_from_idx(start_idx, end_idx, event_nuggets):
154
+ event_nuggets_idxs = [(nugget["nugget"]["startOffset"], nugget["nugget"]["endOffset"]) for nugget in event_nuggets]
155
+ for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
156
+ if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
157
+ return "B-" + event_nuggets[idx]["subtype"], "B-" + event_nuggets[idx]["realis"]
158
+ elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
159
+ return "I-" + event_nuggets[idx]["subtype"], "I-" + event_nuggets[idx]["realis"]
160
+ return "O", "O"
161
+
162
+ def get_args_entity_from_idx(start_idx, end_idx, event_args):
163
+ event_nuggets_idxs = [(nugget["startOffset"], nugget["endOffset"]) for nugget in event_args]
164
+ for idx, (nugget_start, nugget_end) in enumerate(event_nuggets_idxs):
165
+ if (start_idx == nugget_start and end_idx == nugget_end) or (start_idx == nugget_start and end_idx <= nugget_end) or (start_idx == nugget_start and end_idx > nugget_end) or (end_idx == nugget_end and start_idx < nugget_start) or (start_idx <= nugget_start and end_idx <= nugget_end and end_idx > nugget_start):
166
+ return "B-" + event_args[idx]["type"]
167
+ elif (start_idx > nugget_start and end_idx <= nugget_end) or (start_idx > nugget_start and start_idx < nugget_end):
168
+ return "I-" + event_args[idx]["type"]
169
+ return "O"
170
+
171
+ def split_with_character(string, char):
172
+ result = []
173
+ start = 0
174
+ for i, c in enumerate(string):
175
+ if c == char:
176
+ result.append(string[start:i])
177
+ result.append(char)
178
+ start = i + 1
179
+ result.append(string[start:])
180
+ return [x for x in result if x != '']
181
+
182
+ def extend_list_with_character(content_list, character):
183
+ content_as_words = []
184
+ for word in content_list:
185
+ if character in word:
186
+ split_list = split_with_character(word, character)
187
+ content_as_words.extend(split_list)
188
+ else:
189
+ content_as_words.append(word)
190
+ return content_as_words
191
+
192
+ def find_dict_by_overlap(list_of_dicts, key_value_pairs):
193
+ for dictionary in list_of_dicts:
194
+ if max(dictionary["start"], dictionary["end"]) >= min(key_value_pairs["start"], key_value_pairs["end"]) and max(key_value_pairs["start"], key_value_pairs["end"]) >= min(dictionary["start"], dictionary["end"]):
195
+ return dictionary
196
+ return None