DeepMount00 commited on
Commit
cc8997b
·
verified ·
1 Parent(s): b1d7709

Upload 12 files

Browse files
GLiNER/README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Card for GLiNER-base
2
+
3
+ GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios.
4
+
5
+ ## Models Status
6
+
7
+ ### Available Models on Hugging Face
8
+ - [x] [GLiNER-Base](https://huggingface.co/urchade/gliner_base) (CC BY NC 4.0)
9
+ - [x] [GLiNER-Multi](https://huggingface.co/urchade/gliner_multi) (CC BY NC 4.0)
10
+ - [x] [GLiNER-small](https://huggingface.co/urchade/gliner_small) (CC BY NC 4.0)
11
+ - [x] [GLiNER-small-v2](https://huggingface.co/urchade/gliner_smallv2) (Apache)
12
+ - [x] [GLiNER-medium](https://huggingface.co/urchade/gliner_medium) (CC BY NC 4.0)
13
+ - [x] [GLiNER-medium-v2](https://huggingface.co/urchade/gliner_mediumv2) (Apache)
14
+ - [x] [GLiNER-large](https://huggingface.co/urchade/gliner_large) (CC BY NC 4.0)
15
+ - [x] [GLiNER-ledium-v2](https://huggingface.co/urchade/gliner_largev2) (Apache)
16
+
17
+ ### To Release
18
+ - [ ] ⏳ GLiNER-Multiv2
19
+ - [ ] ⏳ GLiNER-Sup (trained on mixture of NER datasets)
20
+
21
+ ## Links
22
+
23
+ * Paper: https://arxiv.org/abs/2311.08526
24
+ * Repository: https://github.com/urchade/GLiNER
25
+
26
+ ## Installation
27
+ To use this model, you must download the GLiNER repository and install its dependencies:
28
+ ```
29
+ !git clone https://github.com/urchade/GLiNER.git
30
+ %cd GLiNER
31
+ !pip install -r requirements.txt
32
+ ```
33
+
34
+ ## Usage
35
+ Once you've downloaded the GLiNER repository, you can import the GLiNER class from the `model` file. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
36
+
37
+ ```python
38
+ from model import GLiNER
39
+
40
+ model = GLiNER.from_pretrained("urchade/gliner_base")
41
+
42
+ text = """
43
+ Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
44
+ """
45
+
46
+ labels = ["person", "award", "date", "competitions", "teams"]
47
+
48
+ entities = model.predict_entities(text, labels)
49
+
50
+ for entity in entities:
51
+ print(entity["text"], "=>", entity["label"])
52
+ ```
53
+
54
+ ```
55
+ Cristiano Ronaldo dos Santos Aveiro => person
56
+ 5 February 1985 => date
57
+ Al Nassr => teams
58
+ Portugal national team => teams
59
+ Ballon d'Or => award
60
+ UEFA Men's Player of the Year Awards => award
61
+ European Golden Shoes => award
62
+ UEFA Champions Leagues => competitions
63
+ UEFA European Championship => competitions
64
+ UEFA Nations League => competitions
65
+ Champions League => competitions
66
+ European Championship => competitions
67
+ ```
68
+
69
+ ## Named Entity Recognition benchmark result
70
+
71
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/6317233cc92fd6fee317e030/Y5f7tK8lonGqeeO6L6bVI.png)
72
+
73
+ ## Model Authors
74
+ The model authors are:
75
+ * [Urchade Zaratiana](https://huggingface.co/urchade)
76
+ * Nadi Tomeh
77
+ * Pierre Holat
78
+ * Thierry Charnois
79
+
80
+ ## Citation
81
+ ```bibtex
82
+ @misc{zaratiana2023gliner,
83
+ title={GLiNER: Generalist Model for Named Entity Recognition using Bidirectional Transformer},
84
+ author={Urchade Zaratiana and Nadi Tomeh and Pierre Holat and Thierry Charnois},
85
+ year={2023},
86
+ eprint={2311.08526},
87
+ archivePrefix={arXiv},
88
+ primaryClass={cs.CL}
89
+ }
90
+ ```
GLiNER/model.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import re
5
+ from typing import Dict, Optional, Union
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from modules.layers import LstmSeq2SeqEncoder
9
+ from modules.base import InstructBase
10
+ from modules.evaluator import Evaluator, greedy_search
11
+ from modules.span_rep import SpanRepLayer
12
+ from modules.token_rep import TokenRepLayer
13
+ from torch import nn
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
16
+ from huggingface_hub.utils import HfHubHTTPError
17
+
18
+
19
+
20
+ class GLiNER(InstructBase, PyTorchModelHubMixin):
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+
24
+ self.config = config
25
+
26
+ # [ENT] token
27
+ self.entity_token = "<<ENT>>"
28
+ self.sep_token = "<<SEP>>"
29
+
30
+ # usually a pretrained bidirectional transformer, returns first subtoken representation
31
+ self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
32
+ subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
33
+ add_tokens=[self.entity_token, self.sep_token])
34
+
35
+ # hierarchical representation of tokens
36
+ self.rnn = LstmSeq2SeqEncoder(
37
+ input_size=config.hidden_size,
38
+ hidden_size=config.hidden_size // 2,
39
+ num_layers=1,
40
+ bidirectional=True,
41
+ )
42
+
43
+ # span representation
44
+ self.span_rep_layer = SpanRepLayer(
45
+ span_mode=config.span_mode,
46
+ hidden_size=config.hidden_size,
47
+ max_width=config.max_width,
48
+ dropout=config.dropout,
49
+ )
50
+
51
+ # prompt representation (FFN)
52
+ self.prompt_rep_layer = nn.Sequential(
53
+ nn.Linear(config.hidden_size, config.hidden_size * 4),
54
+ nn.Dropout(config.dropout),
55
+ nn.ReLU(),
56
+ nn.Linear(config.hidden_size * 4, config.hidden_size)
57
+ )
58
+
59
+ def compute_score_train(self, x):
60
+ span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
61
+
62
+ new_length = x['seq_length'].clone()
63
+ new_tokens = []
64
+ all_len_prompt = []
65
+ num_classes_all = []
66
+
67
+ # add prompt to the tokens
68
+ for i in range(len(x['tokens'])):
69
+ all_types_i = list(x['classes_to_id'][i].keys())
70
+ # multiple entity types in all_types. Prompt is appended at the start of tokens
71
+ entity_prompt = []
72
+ num_classes_all.append(len(all_types_i))
73
+ # add enity types to prompt
74
+ for entity_type in all_types_i:
75
+ entity_prompt.append(self.entity_token) # [ENT] token
76
+ entity_prompt.append(entity_type) # entity type
77
+ entity_prompt.append(self.sep_token) # [SEP] token
78
+
79
+ # prompt format:
80
+ # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]
81
+
82
+ # add prompt to the tokens
83
+ tokens_p = entity_prompt + x['tokens'][i]
84
+
85
+ # input format:
86
+ # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n
87
+
88
+ # update length of the sequence (add prompt length to the original length)
89
+ new_length[i] = new_length[i] + len(entity_prompt)
90
+ # update tokens
91
+ new_tokens.append(tokens_p)
92
+ # store prompt length
93
+ all_len_prompt.append(len(entity_prompt))
94
+
95
+ # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
96
+ max_num_classes = max(num_classes_all)
97
+ entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
98
+ x['span_mask'].device)
99
+ entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
100
+ x['span_mask'].device) # [batch_size, max_num_classes]
101
+
102
+ # compute all token representations
103
+ bert_output = self.token_rep_layer(new_tokens, new_length)
104
+ word_rep_w_prompt = bert_output["embeddings"] # embeddings for all tokens (with prompt)
105
+ mask_w_prompt = bert_output["mask"] # mask for all tokens (with prompt)
106
+
107
+ # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
108
+ word_rep = [] # word representation (after [SEP])
109
+ mask = [] # mask (after [SEP])
110
+ entity_type_rep = [] # entity type representation (before [SEP])
111
+ for i in range(len(x['tokens'])):
112
+ prompt_entity_length = all_len_prompt[i] # length of prompt for this example
113
+ # get word representation (after [SEP])
114
+ word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
115
+ # get mask (after [SEP])
116
+ mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
117
+
118
+ # get entity type representation (before [SEP])
119
+ entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1] # remove [SEP]
120
+ entity_rep = entity_rep[0::2] # it means that we take every second element starting from the second one
121
+ entity_type_rep.append(entity_rep)
122
+
123
+ # padding for word_rep, mask and entity_type_rep
124
+ word_rep = pad_sequence(word_rep, batch_first=True) # [batch_size, seq_len, hidden_size]
125
+ mask = pad_sequence(mask, batch_first=True) # [batch_size, seq_len]
126
+ entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) # [batch_size, len_types, hidden_size]
127
+
128
+ # compute span representation
129
+ word_rep = self.rnn(word_rep, mask)
130
+ span_rep = self.span_rep_layer(word_rep, span_idx)
131
+
132
+ # compute final entity type representation (FFN)
133
+ entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
134
+ num_classes = entity_type_rep.shape[1] # number of entity types
135
+
136
+ # similarity score
137
+ scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
138
+
139
+ return scores, num_classes, entity_type_mask
140
+
141
+ def forward(self, x):
142
+ # compute span representation
143
+ scores, num_classes, entity_type_mask = self.compute_score_train(x)
144
+ batch_size = scores.shape[0]
145
+
146
+ # loss for filtering classifier
147
+ logits_label = scores.view(-1, num_classes)
148
+ labels = x["span_label"].view(-1) # (batch_size * num_spans)
149
+ mask_label = labels != -1 # (batch_size * num_spans)
150
+ labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0
151
+
152
+ # one-hot encoding
153
+ labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
154
+ labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1
155
+ labels_one_hot = labels_one_hot[:, 1:] # Remove the first column
156
+ # Shape of labels_one_hot: (batch_size * num_spans, num_classes)
157
+
158
+ # compute loss (without reduction)
159
+ all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
160
+ reduction='none')
161
+ # mask loss using entity_type_mask (B, C)
162
+ masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
163
+ all_losses = masked_loss.view(-1, num_classes)
164
+ # expand mask_label to all_losses
165
+ mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
166
+ # put lower loss for in label_one_hot (2 for positive, 1 for negative)
167
+ weight_c = labels_one_hot + 1
168
+ # apply mask
169
+ all_losses = all_losses * mask_label.float() * weight_c
170
+ return all_losses.sum()
171
+
172
+ def compute_score_eval(self, x, device):
173
+ # check if classes_to_id is dict
174
+ assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"
175
+
176
+ span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
177
+
178
+ all_types = list(x['classes_to_id'].keys())
179
+ # multiple entity types in all_types. Prompt is appended at the start of tokens
180
+ entity_prompt = []
181
+
182
+ # add enity types to prompt
183
+ for entity_type in all_types:
184
+ entity_prompt.append(self.entity_token)
185
+ entity_prompt.append(entity_type)
186
+
187
+ entity_prompt.append(self.sep_token)
188
+
189
+ prompt_entity_length = len(entity_prompt)
190
+
191
+ # add prompt
192
+ tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
193
+ seq_length_p = x['seq_length'] + prompt_entity_length
194
+
195
+ out = self.token_rep_layer(tokens_p, seq_length_p)
196
+
197
+ word_rep_w_prompt = out["embeddings"]
198
+ mask_w_prompt = out["mask"]
199
+
200
+ # remove prompt
201
+ word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
202
+ mask = mask_w_prompt[:, prompt_entity_length:]
203
+
204
+ # get_entity_type_rep
205
+ entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
206
+ # extract [ENT] tokens (which are at even positions in entity_type_rep)
207
+ entity_type_rep = entity_type_rep[:, 0::2, :]
208
+
209
+ entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
210
+
211
+ word_rep = self.rnn(word_rep, mask)
212
+
213
+ span_rep = self.span_rep_layer(word_rep, span_idx)
214
+
215
+ local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
216
+
217
+ return local_scores
218
+
219
+ @torch.no_grad()
220
+ def predict(self, x, flat_ner=False, threshold=0.5):
221
+ self.eval()
222
+ local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
223
+ spans = []
224
+ for i, _ in enumerate(x["tokens"]):
225
+ local_i = local_scores[i]
226
+ wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
227
+ span_i = []
228
+ for s, k, c in zip(*wh_i):
229
+ if s + k < len(x["tokens"][i]):
230
+ span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
231
+ span_i = greedy_search(span_i, flat_ner)
232
+ spans.append(span_i)
233
+ return spans
234
+
235
+ def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
236
+ tokens = []
237
+ start_token_idx_to_text_idx = []
238
+ end_token_idx_to_text_idx = []
239
+ for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
240
+ tokens.append(match.group())
241
+ start_token_idx_to_text_idx.append(match.start())
242
+ end_token_idx_to_text_idx.append(match.end())
243
+
244
+ input_x = {"tokenized_text": tokens, "ner": None}
245
+ x = self.collate_fn([input_x], labels)
246
+ output = self.predict(x, flat_ner=flat_ner, threshold=threshold)
247
+
248
+ entities = []
249
+ for start_token_idx, end_token_idx, ent_type in output[0]:
250
+ start_text_idx = start_token_idx_to_text_idx[start_token_idx]
251
+ end_text_idx = end_token_idx_to_text_idx[end_token_idx]
252
+ entities.append({
253
+ "start": start_token_idx_to_text_idx[start_token_idx],
254
+ "end": end_token_idx_to_text_idx[end_token_idx],
255
+ "text": text[start_text_idx:end_text_idx],
256
+ "label": ent_type,
257
+ })
258
+ return entities
259
+
260
+ def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
261
+ self.eval()
262
+ data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
263
+ device = next(self.parameters()).device
264
+ all_preds = []
265
+ all_trues = []
266
+ for x in data_loader:
267
+ for k, v in x.items():
268
+ if isinstance(v, torch.Tensor):
269
+ x[k] = v.to(device)
270
+ batch_predictions = self.predict(x, flat_ner, threshold)
271
+ all_preds.extend(batch_predictions)
272
+ all_trues.extend(x["entities"])
273
+ evaluator = Evaluator(all_trues, all_preds)
274
+ out, f1 = evaluator.evaluate()
275
+ return out, f1
276
+
277
+ @classmethod
278
+ def _from_pretrained(
279
+ cls,
280
+ *,
281
+ model_id: str,
282
+ revision: Optional[str],
283
+ cache_dir: Optional[Union[str, Path]],
284
+ force_download: bool,
285
+ proxies: Optional[Dict],
286
+ resume_download: bool,
287
+ local_files_only: bool,
288
+ token: Union[str, bool, None],
289
+ map_location: str = "cpu",
290
+ strict: bool = False,
291
+ **model_kwargs,
292
+ ):
293
+ # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
294
+ filenames = ["gliner_base.pt", "gliner_multi.pt"]
295
+ for filename in filenames:
296
+ model_file = Path(model_id) / filename
297
+ if not model_file.exists():
298
+ try:
299
+ model_file = hf_hub_download(
300
+ repo_id=model_id,
301
+ filename=filename,
302
+ revision=revision,
303
+ cache_dir=cache_dir,
304
+ force_download=force_download,
305
+ proxies=proxies,
306
+ resume_download=resume_download,
307
+ token=token,
308
+ local_files_only=local_files_only,
309
+ )
310
+ except HfHubHTTPError:
311
+ continue
312
+ dict_load = torch.load(model_file, map_location=torch.device(map_location))
313
+ config = dict_load["config"]
314
+ state_dict = dict_load["model_weights"]
315
+ config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
316
+ model = cls(config)
317
+ model.load_state_dict(state_dict, strict=strict, assign=True)
318
+ # Required to update flair's internals as well:
319
+ model.to(map_location)
320
+ return model
321
+
322
+ # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
323
+ from train import load_config_as_namespace
324
+
325
+ model_file = Path(model_id) / "pytorch_model.bin"
326
+ if not model_file.exists():
327
+ model_file = hf_hub_download(
328
+ repo_id=model_id,
329
+ filename="pytorch_model.bin",
330
+ revision=revision,
331
+ cache_dir=cache_dir,
332
+ force_download=force_download,
333
+ proxies=proxies,
334
+ resume_download=resume_download,
335
+ token=token,
336
+ local_files_only=local_files_only,
337
+ )
338
+ config_file = Path(model_id) / "gliner_config.json"
339
+ if not config_file.exists():
340
+ config_file = hf_hub_download(
341
+ repo_id=model_id,
342
+ filename="gliner_config.json",
343
+ revision=revision,
344
+ cache_dir=cache_dir,
345
+ force_download=force_download,
346
+ proxies=proxies,
347
+ resume_download=resume_download,
348
+ token=token,
349
+ local_files_only=local_files_only,
350
+ )
351
+ config = load_config_as_namespace(config_file)
352
+ model = cls(config)
353
+ state_dict = torch.load(model_file, map_location=torch.device(map_location))
354
+ model.load_state_dict(state_dict, strict=strict, assign=True)
355
+ model.to(map_location)
356
+ return model
357
+
358
+ def save_pretrained(
359
+ self,
360
+ save_directory: Union[str, Path],
361
+ *,
362
+ config: Optional[Union[dict, "DataclassInstance"]] = None,
363
+ repo_id: Optional[str] = None,
364
+ push_to_hub: bool = False,
365
+ **push_to_hub_kwargs,
366
+ ) -> Optional[str]:
367
+ """
368
+ Save weights in local directory.
369
+
370
+ Args:
371
+ save_directory (`str` or `Path`):
372
+ Path to directory in which the model weights and configuration will be saved.
373
+ config (`dict` or `DataclassInstance`, *optional*):
374
+ Model configuration specified as a key/value dictionary or a dataclass instance.
375
+ push_to_hub (`bool`, *optional*, defaults to `False`):
376
+ Whether or not to push your model to the Huggingface Hub after saving it.
377
+ repo_id (`str`, *optional*):
378
+ ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
379
+ not provided.
380
+ kwargs:
381
+ Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
382
+ """
383
+ save_directory = Path(save_directory)
384
+ save_directory.mkdir(parents=True, exist_ok=True)
385
+
386
+ # save model weights/files
387
+ torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
388
+
389
+ # save config (if provided)
390
+ if config is None:
391
+ config = self.config
392
+ if config is not None:
393
+ if isinstance(config, argparse.Namespace):
394
+ config = vars(config)
395
+ (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))
396
+
397
+ # push to the Hub if required
398
+ if push_to_hub:
399
+ kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
400
+ if config is not None: # kwarg for `push_to_hub`
401
+ kwargs["config"] = config
402
+ if repo_id is None:
403
+ repo_id = save_directory.name # Defaults to `save_directory` name
404
+ return self.push_to_hub(repo_id=repo_id, **kwargs)
405
+ return None
406
+
407
+ def to(self, device):
408
+ super().to(device)
409
+ import flair
410
+
411
+ flair.device = device
412
+ return self
GLiNER/modules/base.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import List, Tuple, Dict
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ from torch.utils.data import DataLoader
8
+ import random
9
+
10
+
11
+ class InstructBase(nn.Module):
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ self.max_width = config.max_width
15
+ self.base_config = config
16
+
17
+ def get_dict(self, spans, classes_to_id):
18
+ dict_tag = defaultdict(int)
19
+ for span in spans:
20
+ if span[2] in classes_to_id:
21
+ dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
22
+ return dict_tag
23
+
24
+ def preprocess_spans(self, tokens, ner, classes_to_id):
25
+
26
+ max_len = self.base_config.max_len
27
+
28
+ if len(tokens) > max_len:
29
+ length = max_len
30
+ tokens = tokens[:max_len]
31
+ else:
32
+ length = len(tokens)
33
+
34
+ spans_idx = []
35
+ for i in range(length):
36
+ spans_idx.extend([(i, i + j) for j in range(self.max_width)])
37
+
38
+ dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)
39
+
40
+ # 0 for null labels
41
+ span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
42
+ spans_idx = torch.LongTensor(spans_idx)
43
+
44
+ # mask for valid spans
45
+ valid_span_mask = spans_idx[:, 1] > length - 1
46
+
47
+ # mask invalid positions
48
+ span_label = span_label.masked_fill(valid_span_mask, -1)
49
+
50
+ return {
51
+ 'tokens': tokens,
52
+ 'span_idx': spans_idx,
53
+ 'span_label': span_label,
54
+ 'seq_length': length,
55
+ 'entities': ner,
56
+ }
57
+
58
+ def collate_fn(self, batch_list, entity_types=None):
59
+ # batch_list: list of dict containing tokens, ner
60
+ if entity_types is None:
61
+ negs = self.get_negatives(batch_list, 100)
62
+ class_to_ids = []
63
+ id_to_classes = []
64
+ for b in batch_list:
65
+ # negs = b["negative"]
66
+ random.shuffle(negs)
67
+
68
+ # negs = negs[:sampled_neg]
69
+ max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
70
+
71
+ if max_neg_type_ratio == 0:
72
+ # no negatives
73
+ neg_type_ratio = 0
74
+ else:
75
+ neg_type_ratio = random.randint(0, max_neg_type_ratio)
76
+
77
+ if neg_type_ratio == 0:
78
+ # no negatives
79
+ negs_i = []
80
+ else:
81
+ negs_i = negs[:len(b['ner']) * neg_type_ratio]
82
+
83
+ # this is the list of all possible entity types (positive and negative)
84
+ types = list(set([el[-1] for el in b['ner']] + negs_i))
85
+
86
+ # shuffle (every epoch)
87
+ random.shuffle(types)
88
+
89
+ if len(types) != 0:
90
+ # prob of higher number shoul
91
+ # random drop
92
+ if self.base_config.random_drop:
93
+ num_ents = random.randint(1, len(types))
94
+ types = types[:num_ents]
95
+
96
+ # maximum number of entities types
97
+ types = types[:int(self.base_config.max_types)]
98
+
99
+ # supervised training
100
+ if "label" in b:
101
+ types = sorted(b["label"])
102
+
103
+ class_to_id = {k: v for v, k in enumerate(types, start=1)}
104
+ id_to_class = {k: v for v, k in class_to_id.items()}
105
+ class_to_ids.append(class_to_id)
106
+ id_to_classes.append(id_to_class)
107
+
108
+ batch = [
109
+ self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
110
+ ]
111
+
112
+ else:
113
+ class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
114
+ id_to_classes = {k: v for v, k in class_to_ids.items()}
115
+ batch = [
116
+ self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
117
+ ]
118
+
119
+ span_idx = pad_sequence(
120
+ [b['span_idx'] for b in batch], batch_first=True, padding_value=0
121
+ )
122
+
123
+ span_label = pad_sequence(
124
+ [el['span_label'] for el in batch], batch_first=True, padding_value=-1
125
+ )
126
+
127
+ return {
128
+ 'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
129
+ 'span_idx': span_idx,
130
+ 'tokens': [el['tokens'] for el in batch],
131
+ 'span_mask': span_label != -1,
132
+ 'span_label': span_label,
133
+ 'entities': [el['entities'] for el in batch],
134
+ 'classes_to_id': class_to_ids,
135
+ 'id_to_classes': id_to_classes,
136
+ }
137
+
138
+ @staticmethod
139
+ def get_negatives(batch_list, sampled_neg=5):
140
+ ent_types = []
141
+ for b in batch_list:
142
+ types = set([el[-1] for el in b['ner']])
143
+ ent_types.extend(list(types))
144
+ ent_types = list(set(ent_types))
145
+ # sample negatives
146
+ random.shuffle(ent_types)
147
+ return ent_types[:sampled_neg]
148
+
149
+ def create_dataloader(self, data, entity_types=None, **kwargs):
150
+ return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
GLiNER/modules/data_proc.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from tqdm import tqdm
3
+ # ast.literal_eval
4
+ import ast, re
5
+
6
+ path = 'train.json'
7
+
8
+ with open(path, 'r') as f:
9
+ data = json.load(f)
10
+
11
+ def tokenize_text(text):
12
+ return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
13
+
14
+ def extract_entity_spans(entry):
15
+ text = ""
16
+ len_start = len("What describes ")
17
+ len_end = len(" in the text?")
18
+ entity_types = []
19
+ entity_texts = []
20
+
21
+ for c in entry['conversations']:
22
+ if c['from'] == 'human' and c['value'].startswith('Text: '):
23
+ text = c['value'][len('Text: '):]
24
+ tokenized_text = tokenize_text(text)
25
+
26
+ if c['from'] == 'human' and c['value'].startswith('What describes '):
27
+
28
+ c_type = c['value'][len_start:-len_end]
29
+ c_type = c_type.replace(' ', '_')
30
+ entity_types.append(c_type)
31
+
32
+ elif c['from'] == 'gpt' and c['value'].startswith('['):
33
+ if c['value'] == '[]':
34
+ entity_types = entity_types[:-1]
35
+ continue
36
+
37
+ texts_ents = ast.literal_eval(c['value'])
38
+ # replace space to _ in texts_ents
39
+ entity_texts.extend(texts_ents)
40
+ num_repeat = len(texts_ents) - 1
41
+ entity_types.extend([entity_types[-1]] * num_repeat)
42
+
43
+ entity_spans = []
44
+ for j, entity_text in enumerate(entity_texts):
45
+ entity_tokens = tokenize_text(entity_text)
46
+ matches = []
47
+ for i in range(len(tokenized_text) - len(entity_tokens) + 1):
48
+ if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
49
+ matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
50
+ if matches:
51
+ entity_spans.extend(matches)
52
+
53
+ return entity_spans, tokenized_text
54
+
55
+ # Usage:
56
+ # Replace 'entry' with the specific entry from your JSON data
57
+ entry = data[17818] # For example, taking the first entry
58
+ entity_spans, tokenized_text = extract_entity_spans(entry)
59
+ print("Entity Spans:", entity_spans)
60
+ #print("Tokenized Text:", tokenized_text)
61
+
62
+ # create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}
63
+
64
+ all_data = []
65
+
66
+ for entry in tqdm(data):
67
+ entity_spans, tokenized_text = extract_entity_spans(entry)
68
+ all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})
69
+
70
+
71
+ with open('train_instruct.json', 'w') as f:
72
+ json.dump(all_data, f)
73
+
GLiNER/modules/evaluator.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
+ import numpy as np
4
+ import torch
5
+ from seqeval.metrics.v1 import _prf_divide
6
+
7
+
8
+ def extract_tp_actual_correct(y_true, y_pred):
9
+ entities_true = defaultdict(set)
10
+ entities_pred = defaultdict(set)
11
+
12
+ for type_name, (start, end), idx in y_true:
13
+ entities_true[type_name].add((start, end, idx))
14
+ for type_name, (start, end), idx in y_pred:
15
+ entities_pred[type_name].add((start, end, idx))
16
+
17
+ target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
18
+
19
+ tp_sum = np.array([], dtype=np.int32)
20
+ pred_sum = np.array([], dtype=np.int32)
21
+ true_sum = np.array([], dtype=np.int32)
22
+ for type_name in target_names:
23
+ entities_true_type = entities_true.get(type_name, set())
24
+ entities_pred_type = entities_pred.get(type_name, set())
25
+ tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
26
+ pred_sum = np.append(pred_sum, len(entities_pred_type))
27
+ true_sum = np.append(true_sum, len(entities_true_type))
28
+
29
+ return pred_sum, tp_sum, true_sum, target_names
30
+
31
+
32
+ def flatten_for_eval(y_true, y_pred):
33
+ all_true = []
34
+ all_pred = []
35
+
36
+ for i, (true, pred) in enumerate(zip(y_true, y_pred)):
37
+ all_true.extend([t + [i] for t in true])
38
+ all_pred.extend([p + [i] for p in pred])
39
+
40
+ return all_true, all_pred
41
+
42
+
43
+ def compute_prf(y_true, y_pred, average='micro'):
44
+ y_true, y_pred = flatten_for_eval(y_true, y_pred)
45
+
46
+ pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
47
+
48
+ if average == 'micro':
49
+ tp_sum = np.array([tp_sum.sum()])
50
+ pred_sum = np.array([pred_sum.sum()])
51
+ true_sum = np.array([true_sum.sum()])
52
+
53
+ precision = _prf_divide(
54
+ numerator=tp_sum,
55
+ denominator=pred_sum,
56
+ metric='precision',
57
+ modifier='predicted',
58
+ average=average,
59
+ warn_for=('precision', 'recall', 'f-score'),
60
+ zero_division='warn'
61
+ )
62
+
63
+ recall = _prf_divide(
64
+ numerator=tp_sum,
65
+ denominator=true_sum,
66
+ metric='recall',
67
+ modifier='true',
68
+ average=average,
69
+ warn_for=('precision', 'recall', 'f-score'),
70
+ zero_division='warn'
71
+ )
72
+
73
+ denominator = precision + recall
74
+ denominator[denominator == 0.] = 1
75
+ f_score = 2 * (precision * recall) / denominator
76
+
77
+ return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
78
+
79
+
80
+ class Evaluator:
81
+ def __init__(self, all_true, all_outs):
82
+ self.all_true = all_true
83
+ self.all_outs = all_outs
84
+
85
+ def get_entities_fr(self, ents):
86
+ all_ents = []
87
+ for s, e, lab in ents:
88
+ all_ents.append([lab, (s, e)])
89
+ return all_ents
90
+
91
+ def transform_data(self):
92
+ all_true_ent = []
93
+ all_outs_ent = []
94
+ for i, j in zip(self.all_true, self.all_outs):
95
+ e = self.get_entities_fr(i)
96
+ all_true_ent.append(e)
97
+ e = self.get_entities_fr(j)
98
+ all_outs_ent.append(e)
99
+ return all_true_ent, all_outs_ent
100
+
101
+ @torch.no_grad()
102
+ def evaluate(self):
103
+ all_true_typed, all_outs_typed = self.transform_data()
104
+ precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
105
+ output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
106
+ return output_str, f1
107
+
108
+
109
+ def is_nested(idx1, idx2):
110
+ # Return True if idx2 is nested inside idx1 or vice versa
111
+ return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
112
+
113
+
114
+ def has_overlapping(idx1, idx2):
115
+ overlapping = True
116
+ if idx1[:2] == idx2[:2]:
117
+ return overlapping
118
+ if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
119
+ overlapping = False
120
+ return overlapping
121
+
122
+
123
+ def has_overlapping_nested(idx1, idx2):
124
+ # Return True if idx1 and idx2 overlap, but neither is nested inside the other
125
+ if idx1[:2] == idx2[:2]:
126
+ return True
127
+ if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
128
+ return False
129
+ else:
130
+ return True
131
+
132
+
133
+ def greedy_search(spans, flat_ner=True): # start, end, class, score
134
+
135
+ if flat_ner:
136
+ has_ov = has_overlapping
137
+ else:
138
+ has_ov = has_overlapping_nested
139
+
140
+ new_list = []
141
+ span_prob = sorted(spans, key=lambda x: -x[-1])
142
+ for i in range(len(spans)):
143
+ b = span_prob[i]
144
+ flag = False
145
+ for new in new_list:
146
+ if has_ov(b[:-1], new):
147
+ flag = True
148
+ break
149
+ if not flag:
150
+ new_list.append(b[:-1])
151
+ new_list = sorted(new_list, key=lambda x: x[0])
152
+ return new_list
GLiNER/modules/layers.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5
+
6
+
7
+ class LstmSeq2SeqEncoder(nn.Module):
8
+ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
9
+ super(LstmSeq2SeqEncoder, self).__init__()
10
+ self.lstm = nn.LSTM(input_size=input_size,
11
+ hidden_size=hidden_size,
12
+ num_layers=num_layers,
13
+ dropout=dropout,
14
+ bidirectional=bidirectional,
15
+ batch_first=True)
16
+
17
+ def forward(self, x, mask, hidden=None):
18
+ # Packing the input sequence
19
+ lengths = mask.sum(dim=1).cpu()
20
+ packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
21
+
22
+ # Passing packed sequence through LSTM
23
+ packed_output, hidden = self.lstm(packed_x, hidden)
24
+
25
+ # Unpacking the output sequence
26
+ output, _ = pad_packed_sequence(packed_output, batch_first=True)
27
+
28
+ return output
GLiNER/modules/run_evaluation.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import os
4
+ import os
5
+
6
+ import torch
7
+ from tqdm import tqdm
8
+ import random
9
+
10
+
11
+ def open_content(path):
12
+ paths = glob.glob(os.path.join(path, "*.json"))
13
+ train, dev, test, labels = None, None, None, None
14
+ for p in paths:
15
+ if "train" in p:
16
+ with open(p, "r") as f:
17
+ train = json.load(f)
18
+ elif "dev" in p:
19
+ with open(p, "r") as f:
20
+ dev = json.load(f)
21
+ elif "test" in p:
22
+ with open(p, "r") as f:
23
+ test = json.load(f)
24
+ elif "labels" in p:
25
+ with open(p, "r") as f:
26
+ labels = json.load(f)
27
+ return train, dev, test, labels
28
+
29
+
30
+ def process(data):
31
+ words = data['sentence'].split()
32
+ entities = [] # List of entities (start, end, type)
33
+
34
+ for entity in data['entities']:
35
+ start_char, end_char = entity['pos']
36
+
37
+ # Initialize variables to keep track of word positions
38
+ start_word = None
39
+ end_word = None
40
+
41
+ # Iterate through words and find the word positions
42
+ char_count = 0
43
+ for i, word in enumerate(words):
44
+ word_length = len(word)
45
+ if char_count == start_char:
46
+ start_word = i
47
+ if char_count + word_length == end_char:
48
+ end_word = i
49
+ break
50
+ char_count += word_length + 1 # Add 1 for the space
51
+
52
+ # Append the word positions to the list
53
+ entities.append((start_word, end_word, entity['type']))
54
+
55
+ # Create a list of word positions for each entity
56
+ sample = {
57
+ "tokenized_text": words,
58
+ "ner": entities
59
+ }
60
+
61
+ return sample
62
+
63
+
64
+ # create dataset
65
+ def create_dataset(path):
66
+ train, dev, test, labels = open_content(path)
67
+ train_dataset = []
68
+ dev_dataset = []
69
+ test_dataset = []
70
+ for data in train:
71
+ train_dataset.append(process(data))
72
+ for data in dev:
73
+ dev_dataset.append(process(data))
74
+ for data in test:
75
+ test_dataset.append(process(data))
76
+ return train_dataset, dev_dataset, test_dataset, labels
77
+
78
+
79
+ @torch.no_grad()
80
+ def get_for_one_path(path, model):
81
+ # load the dataset
82
+ _, _, test_dataset, entity_types = create_dataset(path)
83
+
84
+ data_name = path.split("/")[-1] # get the name of the dataset
85
+
86
+ # check if the dataset is flat_ner
87
+ flat_ner = True
88
+ if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
89
+ flat_ner = False
90
+
91
+ # evaluate the model
92
+ results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
93
+ entity_types=entity_types)
94
+ return data_name, results, f1
95
+
96
+
97
+ def get_for_all_path(model, steps, log_dir, data_paths):
98
+ all_paths = glob.glob(f"{data_paths}/*")
99
+
100
+ all_paths = sorted(all_paths)
101
+
102
+ # move the model to the device
103
+ device = next(model.parameters()).device
104
+ model.to(device)
105
+ # set the model to eval mode
106
+ model.eval()
107
+
108
+ # log the results
109
+ save_path = os.path.join(log_dir, "results.txt")
110
+
111
+ with open(save_path, "a") as f:
112
+ f.write("##############################################\n")
113
+ # write step
114
+ f.write("step: " + str(steps) + "\n")
115
+
116
+ zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
117
+ "CrossNER_politics", "CrossNER_science"]
118
+
119
+ zero_shot_benc_results = {}
120
+ all_results = {} # without crossNER
121
+
122
+ for p in tqdm(all_paths):
123
+ if "sample_" not in p:
124
+ data_name, results, f1 = get_for_one_path(p, model)
125
+ # write to file
126
+ with open(save_path, "a") as f:
127
+ f.write(data_name + "\n")
128
+ f.write(str(results) + "\n")
129
+
130
+ if data_name in zero_shot_benc:
131
+ zero_shot_benc_results[data_name] = f1
132
+ else:
133
+ all_results[data_name] = f1
134
+
135
+ avg_all = sum(all_results.values()) / len(all_results)
136
+ avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
137
+
138
+ save_path_table = os.path.join(log_dir, "tables.txt")
139
+
140
+ # results for all datasets except crossNER
141
+ table_bench_all = ""
142
+ for k, v in all_results.items():
143
+ table_bench_all += f"{k:20}: {v:.1%}\n"
144
+ # (20 size aswell for average i.e. :20)
145
+ table_bench_all += f"{'Average':20}: {avg_all:.1%}"
146
+
147
+ # results for zero-shot benchmark
148
+ table_bench_zeroshot = ""
149
+ for k, v in zero_shot_benc_results.items():
150
+ table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
151
+ table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
152
+
153
+ # write to file
154
+ with open(save_path_table, "a") as f:
155
+ f.write("##############################################\n")
156
+ f.write("step: " + str(steps) + "\n")
157
+ f.write("Table for all datasets except crossNER\n")
158
+ f.write(table_bench_all + "\n\n")
159
+ f.write("Table for zero-shot benchmark\n")
160
+ f.write(table_bench_zeroshot + "\n")
161
+ f.write("##############################################\n\n")
162
+
163
+
164
+ def sample_train_data(data_paths, sample_size=10000):
165
+ all_paths = glob.glob(f"{data_paths}/*")
166
+
167
+ all_paths = sorted(all_paths)
168
+
169
+ # to exclude the zero-shot benchmark datasets
170
+ zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
171
+ "CrossNER_politics", "CrossNER_science", "ACE 2004"]
172
+
173
+ new_train = []
174
+ # take 10k samples from each dataset
175
+ for p in tqdm(all_paths):
176
+ if any([i in p for i in zero_shot_benc]):
177
+ continue
178
+ train, dev, test, labels = create_dataset(p)
179
+
180
+ # add label key to the train data
181
+ for i in range(len(train)):
182
+ train[i]["label"] = labels
183
+
184
+ random.shuffle(train)
185
+ train = train[:sample_size]
186
+ new_train.extend(train)
187
+
188
+ return new_train
GLiNER/modules/span_rep.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
6
+ """
7
+ Creates a projection layer with specified configurations.
8
+ """
9
+ if out_dim is None:
10
+ out_dim = hidden_size
11
+
12
+ return nn.Sequential(
13
+ nn.Linear(hidden_size, out_dim * 4),
14
+ nn.ReLU(),
15
+ nn.Dropout(dropout),
16
+ nn.Linear(out_dim * 4, out_dim)
17
+ )
18
+
19
+
20
+ class SpanQuery(nn.Module):
21
+
22
+ def __init__(self, hidden_size, max_width, trainable=True):
23
+ super().__init__()
24
+
25
+ self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
26
+
27
+ nn.init.uniform_(self.query_seg, a=-1, b=1)
28
+
29
+ if not trainable:
30
+ self.query_seg.requires_grad = False
31
+
32
+ self.project = nn.Sequential(
33
+ nn.Linear(hidden_size, hidden_size),
34
+ nn.ReLU()
35
+ )
36
+
37
+ def forward(self, h, *args):
38
+ # h of shape [B, L, D]
39
+ # query_seg of shape [D, max_width]
40
+
41
+ span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
42
+
43
+ return self.project(span_rep)
44
+
45
+
46
+ class SpanMLP(nn.Module):
47
+
48
+ def __init__(self, hidden_size, max_width):
49
+ super().__init__()
50
+
51
+ self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
52
+
53
+ def forward(self, h, *args):
54
+ # h of shape [B, L, D]
55
+ # query_seg of shape [D, max_width]
56
+
57
+ B, L, D = h.size()
58
+
59
+ span_rep = self.mlp(h)
60
+
61
+ span_rep = span_rep.view(B, L, -1, D)
62
+
63
+ return span_rep.relu()
64
+
65
+
66
+ class SpanCAT(nn.Module):
67
+
68
+ def __init__(self, hidden_size, max_width):
69
+ super().__init__()
70
+
71
+ self.max_width = max_width
72
+
73
+ self.query_seg = nn.Parameter(torch.randn(128, max_width))
74
+
75
+ self.project = nn.Sequential(
76
+ nn.Linear(hidden_size + 128, hidden_size),
77
+ nn.ReLU()
78
+ )
79
+
80
+ def forward(self, h, *args):
81
+ # h of shape [B, L, D]
82
+ # query_seg of shape [D, max_width]
83
+
84
+ B, L, D = h.size()
85
+
86
+ h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
87
+
88
+ q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
89
+
90
+ span_rep = torch.cat([h, q], dim=-1)
91
+
92
+ span_rep = self.project(span_rep)
93
+
94
+ return span_rep
95
+
96
+
97
+ class SpanConvBlock(nn.Module):
98
+ def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
99
+ super().__init__()
100
+
101
+ if span_mode == 'conv_conv':
102
+ self.conv = nn.Conv1d(hidden_size, hidden_size,
103
+ kernel_size=kernel_size)
104
+
105
+ # initialize the weights
106
+ nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
107
+
108
+ elif span_mode == 'conv_max':
109
+ self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
110
+ elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
111
+ self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
112
+
113
+ self.span_mode = span_mode
114
+
115
+ self.pad = kernel_size - 1
116
+
117
+ def forward(self, x):
118
+
119
+ x = torch.einsum('bld->bdl', x)
120
+
121
+ if self.pad > 0:
122
+ x = F.pad(x, (0, self.pad), "constant", 0)
123
+
124
+ x = self.conv(x)
125
+
126
+ if self.span_mode == "conv_sum":
127
+ x = x * (self.pad + 1)
128
+
129
+ return torch.einsum('bdl->bld', x)
130
+
131
+
132
+ class SpanConv(nn.Module):
133
+ def __init__(self, hidden_size, max_width, span_mode):
134
+ super().__init__()
135
+
136
+ kernels = [i + 2 for i in range(max_width - 1)]
137
+
138
+ self.convs = nn.ModuleList()
139
+
140
+ for kernel in kernels:
141
+ self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
142
+
143
+ self.project = nn.Sequential(
144
+ nn.ReLU(),
145
+ nn.Linear(hidden_size, hidden_size)
146
+ )
147
+
148
+ def forward(self, x, *args):
149
+
150
+ span_reps = [x]
151
+
152
+ for conv in self.convs:
153
+ h = conv(x)
154
+ span_reps.append(h)
155
+
156
+ span_reps = torch.stack(span_reps, dim=-2)
157
+
158
+ return self.project(span_reps)
159
+
160
+
161
+ class SpanEndpointsBlock(nn.Module):
162
+ def __init__(self, kernel_size):
163
+ super().__init__()
164
+
165
+ self.kernel_size = kernel_size
166
+
167
+ def forward(self, x):
168
+ B, L, D = x.size()
169
+
170
+ span_idx = torch.LongTensor(
171
+ [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
172
+
173
+ x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
174
+
175
+ # endrep
176
+ start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
177
+
178
+ start_end_rep = start_end_rep.view(B, L, 2, D)
179
+
180
+ return start_end_rep
181
+
182
+
183
+ class ConvShare(nn.Module):
184
+ def __init__(self, hidden_size, max_width):
185
+ super().__init__()
186
+
187
+ self.max_width = max_width
188
+
189
+ self.conv_weigth = nn.Parameter(
190
+ torch.randn(hidden_size, hidden_size, max_width))
191
+
192
+ nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
193
+
194
+ self.project = nn.Sequential(
195
+ nn.ReLU(),
196
+ nn.Linear(hidden_size, hidden_size)
197
+ )
198
+
199
+ def forward(self, x, *args):
200
+ span_reps = []
201
+
202
+ x = torch.einsum('bld->bdl', x)
203
+
204
+ for i in range(self.max_width):
205
+ pad = i
206
+ x_i = F.pad(x, (0, pad), "constant", 0)
207
+ conv_w = self.conv_weigth[:, :, :i + 1]
208
+ out_i = F.conv1d(x_i, conv_w)
209
+ span_reps.append(out_i.transpose(-1, -2))
210
+
211
+ out = torch.stack(span_reps, dim=-2)
212
+
213
+ return self.project(out)
214
+
215
+
216
+ def extract_elements(sequence, indices):
217
+ B, L, D = sequence.shape
218
+ K = indices.shape[1]
219
+
220
+ # Expand indices to [B, K, D]
221
+ expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
222
+
223
+ # Gather the elements
224
+ extracted_elements = torch.gather(sequence, 1, expanded_indices)
225
+
226
+ return extracted_elements
227
+
228
+
229
+ class SpanMarker(nn.Module):
230
+
231
+ def __init__(self, hidden_size, max_width, dropout=0.4):
232
+ super().__init__()
233
+
234
+ self.max_width = max_width
235
+
236
+ self.project_start = nn.Sequential(
237
+ nn.Linear(hidden_size, hidden_size * 2, bias=True),
238
+ nn.ReLU(),
239
+ nn.Dropout(dropout),
240
+ nn.Linear(hidden_size * 2, hidden_size, bias=True),
241
+ )
242
+
243
+ self.project_end = nn.Sequential(
244
+ nn.Linear(hidden_size, hidden_size * 2, bias=True),
245
+ nn.ReLU(),
246
+ nn.Dropout(dropout),
247
+ nn.Linear(hidden_size * 2, hidden_size, bias=True),
248
+ )
249
+
250
+ self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
251
+
252
+ def forward(self, h, span_idx):
253
+ # h of shape [B, L, D]
254
+ # query_seg of shape [D, max_width]
255
+
256
+ B, L, D = h.size()
257
+
258
+ # project start and end
259
+ start_rep = self.project_start(h)
260
+ end_rep = self.project_end(h)
261
+
262
+ start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
263
+ end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
264
+
265
+ # concat start and end
266
+ cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
267
+
268
+ # project
269
+ cat = self.out_project(cat)
270
+
271
+ # reshape
272
+ return cat.view(B, L, self.max_width, D)
273
+
274
+
275
+ class SpanMarkerV0(nn.Module):
276
+ """
277
+ Marks and projects span endpoints using an MLP.
278
+ """
279
+
280
+ def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
281
+ super().__init__()
282
+ self.max_width = max_width
283
+ self.project_start = create_projection_layer(hidden_size, dropout)
284
+ self.project_end = create_projection_layer(hidden_size, dropout)
285
+
286
+ self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
287
+
288
+ def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
289
+ B, L, D = h.size()
290
+
291
+ start_rep = self.project_start(h)
292
+ end_rep = self.project_end(h)
293
+
294
+ start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
295
+ end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
296
+
297
+ cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
298
+
299
+ return self.out_project(cat).view(B, L, self.max_width, D)
300
+
301
+
302
+ class ConvShareV2(nn.Module):
303
+ def __init__(self, hidden_size, max_width):
304
+ super().__init__()
305
+
306
+ self.max_width = max_width
307
+
308
+ self.conv_weigth = nn.Parameter(
309
+ torch.randn(hidden_size, hidden_size, max_width)
310
+ )
311
+
312
+ nn.init.xavier_normal_(self.conv_weigth)
313
+
314
+ def forward(self, x, *args):
315
+ span_reps = []
316
+
317
+ x = torch.einsum('bld->bdl', x)
318
+
319
+ for i in range(self.max_width):
320
+ pad = i
321
+ x_i = F.pad(x, (0, pad), "constant", 0)
322
+ conv_w = self.conv_weigth[:, :, :i + 1]
323
+ out_i = F.conv1d(x_i, conv_w)
324
+ span_reps.append(out_i.transpose(-1, -2))
325
+
326
+ out = torch.stack(span_reps, dim=-2)
327
+
328
+ return out
329
+
330
+
331
+ class SpanRepLayer(nn.Module):
332
+ """
333
+ Various span representation approaches
334
+ """
335
+
336
+ def __init__(self, hidden_size, max_width, span_mode, **kwargs):
337
+ super().__init__()
338
+
339
+ if span_mode == 'marker':
340
+ self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
341
+ elif span_mode == 'markerV0':
342
+ self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
343
+ elif span_mode == 'query':
344
+ self.span_rep_layer = SpanQuery(
345
+ hidden_size, max_width, trainable=True)
346
+ elif span_mode == 'mlp':
347
+ self.span_rep_layer = SpanMLP(hidden_size, max_width)
348
+ elif span_mode == 'cat':
349
+ self.span_rep_layer = SpanCAT(hidden_size, max_width)
350
+ elif span_mode == 'conv_conv':
351
+ self.span_rep_layer = SpanConv(
352
+ hidden_size, max_width, span_mode='conv_conv')
353
+ elif span_mode == 'conv_max':
354
+ self.span_rep_layer = SpanConv(
355
+ hidden_size, max_width, span_mode='conv_max')
356
+ elif span_mode == 'conv_mean':
357
+ self.span_rep_layer = SpanConv(
358
+ hidden_size, max_width, span_mode='conv_mean')
359
+ elif span_mode == 'conv_sum':
360
+ self.span_rep_layer = SpanConv(
361
+ hidden_size, max_width, span_mode='conv_sum')
362
+ elif span_mode == 'conv_share':
363
+ self.span_rep_layer = ConvShare(hidden_size, max_width)
364
+ else:
365
+ raise ValueError(f'Unknown span mode {span_mode}')
366
+
367
+ def forward(self, x, *args):
368
+
369
+ return self.span_rep_layer(x, *args)
GLiNER/modules/token_rep.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ from flair.data import Sentence
5
+ from flair.embeddings import TransformerWordEmbeddings
6
+ from torch import nn
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
+ # flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
11
+
12
+
13
+ class TokenRepLayer(nn.Module):
14
+ def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
15
+ hidden_size: int = 768,
16
+ add_tokens=["[SEP]", "[ENT]"]
17
+ ):
18
+ super().__init__()
19
+
20
+ self.bert_layer = TransformerWordEmbeddings(
21
+ model_name,
22
+ fine_tune=fine_tune,
23
+ subtoken_pooling=subtoken_pooling,
24
+ allow_long_sentences=True
25
+ )
26
+
27
+ # add tokens to vocabulary
28
+ self.bert_layer.tokenizer.add_tokens(add_tokens)
29
+
30
+ # resize token embeddings
31
+ self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
32
+
33
+ bert_hidden_size = self.bert_layer.embedding_length
34
+
35
+ if hidden_size != bert_hidden_size:
36
+ self.projection = nn.Linear(bert_hidden_size, hidden_size)
37
+
38
+ def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
39
+ token_embeddings = self.compute_word_embedding(tokens)
40
+
41
+ if hasattr(self, "projection"):
42
+ token_embeddings = self.projection(token_embeddings)
43
+
44
+ B = len(lengths)
45
+ max_length = lengths.max()
46
+ mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
47
+ token_embeddings.device).long()
48
+ return {"embeddings": token_embeddings, "mask": mask}
49
+
50
+ def compute_word_embedding(self, tokens):
51
+ sentences = [Sentence(i) for i in tokens]
52
+ self.bert_layer.embed(sentences)
53
+ token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
54
+ return token_embeddings
GLiNER/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ huggingface_hub
4
+ flair
5
+ seqeval
6
+ tqdm
GLiNER/save_load.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import GLiNER
3
+
4
+
5
+ def save_model(current_model, path):
6
+ config = current_model.config
7
+ dict_save = {"model_weights": current_model.state_dict(), "config": config}
8
+ torch.save(dict_save, path)
9
+
10
+
11
+ def load_model(path, model_name=None, device=None):
12
+ dict_load = torch.load(path, map_location=torch.device('cpu'))
13
+ config = dict_load["config"]
14
+
15
+ if model_name is not None:
16
+ config.model_name = model_name
17
+
18
+ loaded_model = GLiNER(config)
19
+ loaded_model.load_state_dict(dict_load["model_weights"])
20
+ return loaded_model.to(device) if device is not None else loaded_model
GLiNER/train.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import yaml
6
+ from tqdm import tqdm
7
+ from transformers import get_cosine_schedule_with_warmup
8
+
9
+ # from model_nested import NerFilteredSemiCRF
10
+ from model import GLiNER
11
+ from modules.run_evaluation import get_for_all_path, sample_train_data
12
+ from save_load import save_model, load_model
13
+ import json
14
+
15
+
16
+ # train function
17
+ def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
18
+ train_batch_size=8, device='cuda'):
19
+ model.train()
20
+
21
+ # initialize data loaders
22
+ train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
23
+
24
+ pbar = tqdm(range(num_steps))
25
+
26
+ if warmup_ratio < 1:
27
+ num_warmup_steps = int(num_steps * warmup_ratio)
28
+ else:
29
+ num_warmup_steps = int(warmup_ratio)
30
+
31
+ scheduler = get_cosine_schedule_with_warmup(
32
+ optimizer,
33
+ num_warmup_steps=num_warmup_steps,
34
+ num_training_steps=num_steps
35
+ )
36
+
37
+ iter_train_loader = iter(train_loader)
38
+
39
+ for step in pbar:
40
+ try:
41
+ x = next(iter_train_loader)
42
+ except StopIteration:
43
+ iter_train_loader = iter(train_loader)
44
+ x = next(iter_train_loader)
45
+
46
+ for k, v in x.items():
47
+ if isinstance(v, torch.Tensor):
48
+ x[k] = v.to(device)
49
+
50
+ try:
51
+ loss = model(x) # Forward pass
52
+ except:
53
+ continue
54
+
55
+ # check if loss is nan
56
+ if torch.isnan(loss):
57
+ continue
58
+
59
+ loss.backward() # Compute gradients
60
+ optimizer.step() # Update parameters
61
+ scheduler.step() # Update learning rate schedule
62
+ optimizer.zero_grad() # Reset gradients
63
+
64
+ description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
65
+
66
+ if (step + 1) % eval_every == 0:
67
+ current_path = os.path.join(log_dir, f'model_{step + 1}')
68
+ save_model(model, current_path)
69
+ #val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
70
+ #get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
71
+
72
+ model.train()
73
+
74
+ pbar.set_description(description)
75
+
76
+
77
+ def create_parser():
78
+ parser = argparse.ArgumentParser(description="Span-based NER")
79
+ parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
80
+ parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
81
+ return parser
82
+
83
+
84
+ def load_config_as_namespace(config_file):
85
+ with open(config_file, 'r') as f:
86
+ config_dict = yaml.safe_load(f)
87
+ return argparse.Namespace(**config_dict)
88
+
89
+
90
+ if __name__ == "__main__":
91
+ # parse args
92
+ parser = create_parser()
93
+ args = parser.parse_args()
94
+
95
+ # load config
96
+ config = load_config_as_namespace(args.config)
97
+
98
+ config.log_dir = args.log_dir
99
+
100
+ try:
101
+ with open(config.train_data, 'r') as f:
102
+ data = json.load(f)
103
+ except:
104
+ data = sample_train_data(config.train_data, 10000)
105
+
106
+ if config.prev_path != "none":
107
+ model = load_model(config.prev_path)
108
+ model.config = config
109
+ else:
110
+ model = GLiNER(config)
111
+
112
+ if torch.cuda.is_available():
113
+ model = model.cuda()
114
+
115
+ lr_encoder = float(config.lr_encoder)
116
+ lr_others = float(config.lr_others)
117
+
118
+ optimizer = torch.optim.AdamW([
119
+ # encoder
120
+ {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
121
+ {'params': model.rnn.parameters(), 'lr': lr_others},
122
+ # projection layers
123
+ {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
124
+ {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
125
+ ])
126
+
127
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
128
+
129
+ train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
130
+ log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
131
+ device=device)