Spaces:
Running
Running
DeepMount00
commited on
Upload 12 files
Browse files- GLiNER/README.md +90 -0
- GLiNER/model.py +412 -0
- GLiNER/modules/base.py +150 -0
- GLiNER/modules/data_proc.py +73 -0
- GLiNER/modules/evaluator.py +152 -0
- GLiNER/modules/layers.py +28 -0
- GLiNER/modules/run_evaluation.py +188 -0
- GLiNER/modules/span_rep.py +369 -0
- GLiNER/modules/token_rep.py +54 -0
- GLiNER/requirements.txt +6 -0
- GLiNER/save_load.py +20 -0
- GLiNER/train.py +131 -0
GLiNER/README.md
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Card for GLiNER-base
|
2 |
+
|
3 |
+
GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios.
|
4 |
+
|
5 |
+
## Models Status
|
6 |
+
|
7 |
+
### Available Models on Hugging Face
|
8 |
+
- [x] [GLiNER-Base](https://huggingface.co/urchade/gliner_base) (CC BY NC 4.0)
|
9 |
+
- [x] [GLiNER-Multi](https://huggingface.co/urchade/gliner_multi) (CC BY NC 4.0)
|
10 |
+
- [x] [GLiNER-small](https://huggingface.co/urchade/gliner_small) (CC BY NC 4.0)
|
11 |
+
- [x] [GLiNER-small-v2](https://huggingface.co/urchade/gliner_smallv2) (Apache)
|
12 |
+
- [x] [GLiNER-medium](https://huggingface.co/urchade/gliner_medium) (CC BY NC 4.0)
|
13 |
+
- [x] [GLiNER-medium-v2](https://huggingface.co/urchade/gliner_mediumv2) (Apache)
|
14 |
+
- [x] [GLiNER-large](https://huggingface.co/urchade/gliner_large) (CC BY NC 4.0)
|
15 |
+
- [x] [GLiNER-ledium-v2](https://huggingface.co/urchade/gliner_largev2) (Apache)
|
16 |
+
|
17 |
+
### To Release
|
18 |
+
- [ ] ⏳ GLiNER-Multiv2
|
19 |
+
- [ ] ⏳ GLiNER-Sup (trained on mixture of NER datasets)
|
20 |
+
|
21 |
+
## Links
|
22 |
+
|
23 |
+
* Paper: https://arxiv.org/abs/2311.08526
|
24 |
+
* Repository: https://github.com/urchade/GLiNER
|
25 |
+
|
26 |
+
## Installation
|
27 |
+
To use this model, you must download the GLiNER repository and install its dependencies:
|
28 |
+
```
|
29 |
+
!git clone https://github.com/urchade/GLiNER.git
|
30 |
+
%cd GLiNER
|
31 |
+
!pip install -r requirements.txt
|
32 |
+
```
|
33 |
+
|
34 |
+
## Usage
|
35 |
+
Once you've downloaded the GLiNER repository, you can import the GLiNER class from the `model` file. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
|
36 |
+
|
37 |
+
```python
|
38 |
+
from model import GLiNER
|
39 |
+
|
40 |
+
model = GLiNER.from_pretrained("urchade/gliner_base")
|
41 |
+
|
42 |
+
text = """
|
43 |
+
Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
|
44 |
+
"""
|
45 |
+
|
46 |
+
labels = ["person", "award", "date", "competitions", "teams"]
|
47 |
+
|
48 |
+
entities = model.predict_entities(text, labels)
|
49 |
+
|
50 |
+
for entity in entities:
|
51 |
+
print(entity["text"], "=>", entity["label"])
|
52 |
+
```
|
53 |
+
|
54 |
+
```
|
55 |
+
Cristiano Ronaldo dos Santos Aveiro => person
|
56 |
+
5 February 1985 => date
|
57 |
+
Al Nassr => teams
|
58 |
+
Portugal national team => teams
|
59 |
+
Ballon d'Or => award
|
60 |
+
UEFA Men's Player of the Year Awards => award
|
61 |
+
European Golden Shoes => award
|
62 |
+
UEFA Champions Leagues => competitions
|
63 |
+
UEFA European Championship => competitions
|
64 |
+
UEFA Nations League => competitions
|
65 |
+
Champions League => competitions
|
66 |
+
European Championship => competitions
|
67 |
+
```
|
68 |
+
|
69 |
+
## Named Entity Recognition benchmark result
|
70 |
+
|
71 |
+
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6317233cc92fd6fee317e030/Y5f7tK8lonGqeeO6L6bVI.png)
|
72 |
+
|
73 |
+
## Model Authors
|
74 |
+
The model authors are:
|
75 |
+
* [Urchade Zaratiana](https://huggingface.co/urchade)
|
76 |
+
* Nadi Tomeh
|
77 |
+
* Pierre Holat
|
78 |
+
* Thierry Charnois
|
79 |
+
|
80 |
+
## Citation
|
81 |
+
```bibtex
|
82 |
+
@misc{zaratiana2023gliner,
|
83 |
+
title={GLiNER: Generalist Model for Named Entity Recognition using Bidirectional Transformer},
|
84 |
+
author={Urchade Zaratiana and Nadi Tomeh and Pierre Holat and Thierry Charnois},
|
85 |
+
year={2023},
|
86 |
+
eprint={2311.08526},
|
87 |
+
archivePrefix={arXiv},
|
88 |
+
primaryClass={cs.CL}
|
89 |
+
}
|
90 |
+
```
|
GLiNER/model.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
import re
|
5 |
+
from typing import Dict, Optional, Union
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from modules.layers import LstmSeq2SeqEncoder
|
9 |
+
from modules.base import InstructBase
|
10 |
+
from modules.evaluator import Evaluator, greedy_search
|
11 |
+
from modules.span_rep import SpanRepLayer
|
12 |
+
from modules.token_rep import TokenRepLayer
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn.utils.rnn import pad_sequence
|
15 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
16 |
+
from huggingface_hub.utils import HfHubHTTPError
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
class GLiNER(InstructBase, PyTorchModelHubMixin):
|
21 |
+
def __init__(self, config):
|
22 |
+
super().__init__(config)
|
23 |
+
|
24 |
+
self.config = config
|
25 |
+
|
26 |
+
# [ENT] token
|
27 |
+
self.entity_token = "<<ENT>>"
|
28 |
+
self.sep_token = "<<SEP>>"
|
29 |
+
|
30 |
+
# usually a pretrained bidirectional transformer, returns first subtoken representation
|
31 |
+
self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
|
32 |
+
subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
|
33 |
+
add_tokens=[self.entity_token, self.sep_token])
|
34 |
+
|
35 |
+
# hierarchical representation of tokens
|
36 |
+
self.rnn = LstmSeq2SeqEncoder(
|
37 |
+
input_size=config.hidden_size,
|
38 |
+
hidden_size=config.hidden_size // 2,
|
39 |
+
num_layers=1,
|
40 |
+
bidirectional=True,
|
41 |
+
)
|
42 |
+
|
43 |
+
# span representation
|
44 |
+
self.span_rep_layer = SpanRepLayer(
|
45 |
+
span_mode=config.span_mode,
|
46 |
+
hidden_size=config.hidden_size,
|
47 |
+
max_width=config.max_width,
|
48 |
+
dropout=config.dropout,
|
49 |
+
)
|
50 |
+
|
51 |
+
# prompt representation (FFN)
|
52 |
+
self.prompt_rep_layer = nn.Sequential(
|
53 |
+
nn.Linear(config.hidden_size, config.hidden_size * 4),
|
54 |
+
nn.Dropout(config.dropout),
|
55 |
+
nn.ReLU(),
|
56 |
+
nn.Linear(config.hidden_size * 4, config.hidden_size)
|
57 |
+
)
|
58 |
+
|
59 |
+
def compute_score_train(self, x):
|
60 |
+
span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
|
61 |
+
|
62 |
+
new_length = x['seq_length'].clone()
|
63 |
+
new_tokens = []
|
64 |
+
all_len_prompt = []
|
65 |
+
num_classes_all = []
|
66 |
+
|
67 |
+
# add prompt to the tokens
|
68 |
+
for i in range(len(x['tokens'])):
|
69 |
+
all_types_i = list(x['classes_to_id'][i].keys())
|
70 |
+
# multiple entity types in all_types. Prompt is appended at the start of tokens
|
71 |
+
entity_prompt = []
|
72 |
+
num_classes_all.append(len(all_types_i))
|
73 |
+
# add enity types to prompt
|
74 |
+
for entity_type in all_types_i:
|
75 |
+
entity_prompt.append(self.entity_token) # [ENT] token
|
76 |
+
entity_prompt.append(entity_type) # entity type
|
77 |
+
entity_prompt.append(self.sep_token) # [SEP] token
|
78 |
+
|
79 |
+
# prompt format:
|
80 |
+
# [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]
|
81 |
+
|
82 |
+
# add prompt to the tokens
|
83 |
+
tokens_p = entity_prompt + x['tokens'][i]
|
84 |
+
|
85 |
+
# input format:
|
86 |
+
# [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n
|
87 |
+
|
88 |
+
# update length of the sequence (add prompt length to the original length)
|
89 |
+
new_length[i] = new_length[i] + len(entity_prompt)
|
90 |
+
# update tokens
|
91 |
+
new_tokens.append(tokens_p)
|
92 |
+
# store prompt length
|
93 |
+
all_len_prompt.append(len(entity_prompt))
|
94 |
+
|
95 |
+
# create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
|
96 |
+
max_num_classes = max(num_classes_all)
|
97 |
+
entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
|
98 |
+
x['span_mask'].device)
|
99 |
+
entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
|
100 |
+
x['span_mask'].device) # [batch_size, max_num_classes]
|
101 |
+
|
102 |
+
# compute all token representations
|
103 |
+
bert_output = self.token_rep_layer(new_tokens, new_length)
|
104 |
+
word_rep_w_prompt = bert_output["embeddings"] # embeddings for all tokens (with prompt)
|
105 |
+
mask_w_prompt = bert_output["mask"] # mask for all tokens (with prompt)
|
106 |
+
|
107 |
+
# get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
|
108 |
+
word_rep = [] # word representation (after [SEP])
|
109 |
+
mask = [] # mask (after [SEP])
|
110 |
+
entity_type_rep = [] # entity type representation (before [SEP])
|
111 |
+
for i in range(len(x['tokens'])):
|
112 |
+
prompt_entity_length = all_len_prompt[i] # length of prompt for this example
|
113 |
+
# get word representation (after [SEP])
|
114 |
+
word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
|
115 |
+
# get mask (after [SEP])
|
116 |
+
mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
|
117 |
+
|
118 |
+
# get entity type representation (before [SEP])
|
119 |
+
entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1] # remove [SEP]
|
120 |
+
entity_rep = entity_rep[0::2] # it means that we take every second element starting from the second one
|
121 |
+
entity_type_rep.append(entity_rep)
|
122 |
+
|
123 |
+
# padding for word_rep, mask and entity_type_rep
|
124 |
+
word_rep = pad_sequence(word_rep, batch_first=True) # [batch_size, seq_len, hidden_size]
|
125 |
+
mask = pad_sequence(mask, batch_first=True) # [batch_size, seq_len]
|
126 |
+
entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) # [batch_size, len_types, hidden_size]
|
127 |
+
|
128 |
+
# compute span representation
|
129 |
+
word_rep = self.rnn(word_rep, mask)
|
130 |
+
span_rep = self.span_rep_layer(word_rep, span_idx)
|
131 |
+
|
132 |
+
# compute final entity type representation (FFN)
|
133 |
+
entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
|
134 |
+
num_classes = entity_type_rep.shape[1] # number of entity types
|
135 |
+
|
136 |
+
# similarity score
|
137 |
+
scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
|
138 |
+
|
139 |
+
return scores, num_classes, entity_type_mask
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
# compute span representation
|
143 |
+
scores, num_classes, entity_type_mask = self.compute_score_train(x)
|
144 |
+
batch_size = scores.shape[0]
|
145 |
+
|
146 |
+
# loss for filtering classifier
|
147 |
+
logits_label = scores.view(-1, num_classes)
|
148 |
+
labels = x["span_label"].view(-1) # (batch_size * num_spans)
|
149 |
+
mask_label = labels != -1 # (batch_size * num_spans)
|
150 |
+
labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0
|
151 |
+
|
152 |
+
# one-hot encoding
|
153 |
+
labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
|
154 |
+
labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1
|
155 |
+
labels_one_hot = labels_one_hot[:, 1:] # Remove the first column
|
156 |
+
# Shape of labels_one_hot: (batch_size * num_spans, num_classes)
|
157 |
+
|
158 |
+
# compute loss (without reduction)
|
159 |
+
all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
|
160 |
+
reduction='none')
|
161 |
+
# mask loss using entity_type_mask (B, C)
|
162 |
+
masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
|
163 |
+
all_losses = masked_loss.view(-1, num_classes)
|
164 |
+
# expand mask_label to all_losses
|
165 |
+
mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
|
166 |
+
# put lower loss for in label_one_hot (2 for positive, 1 for negative)
|
167 |
+
weight_c = labels_one_hot + 1
|
168 |
+
# apply mask
|
169 |
+
all_losses = all_losses * mask_label.float() * weight_c
|
170 |
+
return all_losses.sum()
|
171 |
+
|
172 |
+
def compute_score_eval(self, x, device):
|
173 |
+
# check if classes_to_id is dict
|
174 |
+
assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"
|
175 |
+
|
176 |
+
span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
|
177 |
+
|
178 |
+
all_types = list(x['classes_to_id'].keys())
|
179 |
+
# multiple entity types in all_types. Prompt is appended at the start of tokens
|
180 |
+
entity_prompt = []
|
181 |
+
|
182 |
+
# add enity types to prompt
|
183 |
+
for entity_type in all_types:
|
184 |
+
entity_prompt.append(self.entity_token)
|
185 |
+
entity_prompt.append(entity_type)
|
186 |
+
|
187 |
+
entity_prompt.append(self.sep_token)
|
188 |
+
|
189 |
+
prompt_entity_length = len(entity_prompt)
|
190 |
+
|
191 |
+
# add prompt
|
192 |
+
tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
|
193 |
+
seq_length_p = x['seq_length'] + prompt_entity_length
|
194 |
+
|
195 |
+
out = self.token_rep_layer(tokens_p, seq_length_p)
|
196 |
+
|
197 |
+
word_rep_w_prompt = out["embeddings"]
|
198 |
+
mask_w_prompt = out["mask"]
|
199 |
+
|
200 |
+
# remove prompt
|
201 |
+
word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
|
202 |
+
mask = mask_w_prompt[:, prompt_entity_length:]
|
203 |
+
|
204 |
+
# get_entity_type_rep
|
205 |
+
entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
|
206 |
+
# extract [ENT] tokens (which are at even positions in entity_type_rep)
|
207 |
+
entity_type_rep = entity_type_rep[:, 0::2, :]
|
208 |
+
|
209 |
+
entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
|
210 |
+
|
211 |
+
word_rep = self.rnn(word_rep, mask)
|
212 |
+
|
213 |
+
span_rep = self.span_rep_layer(word_rep, span_idx)
|
214 |
+
|
215 |
+
local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
|
216 |
+
|
217 |
+
return local_scores
|
218 |
+
|
219 |
+
@torch.no_grad()
|
220 |
+
def predict(self, x, flat_ner=False, threshold=0.5):
|
221 |
+
self.eval()
|
222 |
+
local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
|
223 |
+
spans = []
|
224 |
+
for i, _ in enumerate(x["tokens"]):
|
225 |
+
local_i = local_scores[i]
|
226 |
+
wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
|
227 |
+
span_i = []
|
228 |
+
for s, k, c in zip(*wh_i):
|
229 |
+
if s + k < len(x["tokens"][i]):
|
230 |
+
span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
|
231 |
+
span_i = greedy_search(span_i, flat_ner)
|
232 |
+
spans.append(span_i)
|
233 |
+
return spans
|
234 |
+
|
235 |
+
def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
|
236 |
+
tokens = []
|
237 |
+
start_token_idx_to_text_idx = []
|
238 |
+
end_token_idx_to_text_idx = []
|
239 |
+
for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
|
240 |
+
tokens.append(match.group())
|
241 |
+
start_token_idx_to_text_idx.append(match.start())
|
242 |
+
end_token_idx_to_text_idx.append(match.end())
|
243 |
+
|
244 |
+
input_x = {"tokenized_text": tokens, "ner": None}
|
245 |
+
x = self.collate_fn([input_x], labels)
|
246 |
+
output = self.predict(x, flat_ner=flat_ner, threshold=threshold)
|
247 |
+
|
248 |
+
entities = []
|
249 |
+
for start_token_idx, end_token_idx, ent_type in output[0]:
|
250 |
+
start_text_idx = start_token_idx_to_text_idx[start_token_idx]
|
251 |
+
end_text_idx = end_token_idx_to_text_idx[end_token_idx]
|
252 |
+
entities.append({
|
253 |
+
"start": start_token_idx_to_text_idx[start_token_idx],
|
254 |
+
"end": end_token_idx_to_text_idx[end_token_idx],
|
255 |
+
"text": text[start_text_idx:end_text_idx],
|
256 |
+
"label": ent_type,
|
257 |
+
})
|
258 |
+
return entities
|
259 |
+
|
260 |
+
def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
|
261 |
+
self.eval()
|
262 |
+
data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
|
263 |
+
device = next(self.parameters()).device
|
264 |
+
all_preds = []
|
265 |
+
all_trues = []
|
266 |
+
for x in data_loader:
|
267 |
+
for k, v in x.items():
|
268 |
+
if isinstance(v, torch.Tensor):
|
269 |
+
x[k] = v.to(device)
|
270 |
+
batch_predictions = self.predict(x, flat_ner, threshold)
|
271 |
+
all_preds.extend(batch_predictions)
|
272 |
+
all_trues.extend(x["entities"])
|
273 |
+
evaluator = Evaluator(all_trues, all_preds)
|
274 |
+
out, f1 = evaluator.evaluate()
|
275 |
+
return out, f1
|
276 |
+
|
277 |
+
@classmethod
|
278 |
+
def _from_pretrained(
|
279 |
+
cls,
|
280 |
+
*,
|
281 |
+
model_id: str,
|
282 |
+
revision: Optional[str],
|
283 |
+
cache_dir: Optional[Union[str, Path]],
|
284 |
+
force_download: bool,
|
285 |
+
proxies: Optional[Dict],
|
286 |
+
resume_download: bool,
|
287 |
+
local_files_only: bool,
|
288 |
+
token: Union[str, bool, None],
|
289 |
+
map_location: str = "cpu",
|
290 |
+
strict: bool = False,
|
291 |
+
**model_kwargs,
|
292 |
+
):
|
293 |
+
# 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
|
294 |
+
filenames = ["gliner_base.pt", "gliner_multi.pt"]
|
295 |
+
for filename in filenames:
|
296 |
+
model_file = Path(model_id) / filename
|
297 |
+
if not model_file.exists():
|
298 |
+
try:
|
299 |
+
model_file = hf_hub_download(
|
300 |
+
repo_id=model_id,
|
301 |
+
filename=filename,
|
302 |
+
revision=revision,
|
303 |
+
cache_dir=cache_dir,
|
304 |
+
force_download=force_download,
|
305 |
+
proxies=proxies,
|
306 |
+
resume_download=resume_download,
|
307 |
+
token=token,
|
308 |
+
local_files_only=local_files_only,
|
309 |
+
)
|
310 |
+
except HfHubHTTPError:
|
311 |
+
continue
|
312 |
+
dict_load = torch.load(model_file, map_location=torch.device(map_location))
|
313 |
+
config = dict_load["config"]
|
314 |
+
state_dict = dict_load["model_weights"]
|
315 |
+
config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
|
316 |
+
model = cls(config)
|
317 |
+
model.load_state_dict(state_dict, strict=strict, assign=True)
|
318 |
+
# Required to update flair's internals as well:
|
319 |
+
model.to(map_location)
|
320 |
+
return model
|
321 |
+
|
322 |
+
# 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
|
323 |
+
from train import load_config_as_namespace
|
324 |
+
|
325 |
+
model_file = Path(model_id) / "pytorch_model.bin"
|
326 |
+
if not model_file.exists():
|
327 |
+
model_file = hf_hub_download(
|
328 |
+
repo_id=model_id,
|
329 |
+
filename="pytorch_model.bin",
|
330 |
+
revision=revision,
|
331 |
+
cache_dir=cache_dir,
|
332 |
+
force_download=force_download,
|
333 |
+
proxies=proxies,
|
334 |
+
resume_download=resume_download,
|
335 |
+
token=token,
|
336 |
+
local_files_only=local_files_only,
|
337 |
+
)
|
338 |
+
config_file = Path(model_id) / "gliner_config.json"
|
339 |
+
if not config_file.exists():
|
340 |
+
config_file = hf_hub_download(
|
341 |
+
repo_id=model_id,
|
342 |
+
filename="gliner_config.json",
|
343 |
+
revision=revision,
|
344 |
+
cache_dir=cache_dir,
|
345 |
+
force_download=force_download,
|
346 |
+
proxies=proxies,
|
347 |
+
resume_download=resume_download,
|
348 |
+
token=token,
|
349 |
+
local_files_only=local_files_only,
|
350 |
+
)
|
351 |
+
config = load_config_as_namespace(config_file)
|
352 |
+
model = cls(config)
|
353 |
+
state_dict = torch.load(model_file, map_location=torch.device(map_location))
|
354 |
+
model.load_state_dict(state_dict, strict=strict, assign=True)
|
355 |
+
model.to(map_location)
|
356 |
+
return model
|
357 |
+
|
358 |
+
def save_pretrained(
|
359 |
+
self,
|
360 |
+
save_directory: Union[str, Path],
|
361 |
+
*,
|
362 |
+
config: Optional[Union[dict, "DataclassInstance"]] = None,
|
363 |
+
repo_id: Optional[str] = None,
|
364 |
+
push_to_hub: bool = False,
|
365 |
+
**push_to_hub_kwargs,
|
366 |
+
) -> Optional[str]:
|
367 |
+
"""
|
368 |
+
Save weights in local directory.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
save_directory (`str` or `Path`):
|
372 |
+
Path to directory in which the model weights and configuration will be saved.
|
373 |
+
config (`dict` or `DataclassInstance`, *optional*):
|
374 |
+
Model configuration specified as a key/value dictionary or a dataclass instance.
|
375 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
376 |
+
Whether or not to push your model to the Huggingface Hub after saving it.
|
377 |
+
repo_id (`str`, *optional*):
|
378 |
+
ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
|
379 |
+
not provided.
|
380 |
+
kwargs:
|
381 |
+
Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
|
382 |
+
"""
|
383 |
+
save_directory = Path(save_directory)
|
384 |
+
save_directory.mkdir(parents=True, exist_ok=True)
|
385 |
+
|
386 |
+
# save model weights/files
|
387 |
+
torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
|
388 |
+
|
389 |
+
# save config (if provided)
|
390 |
+
if config is None:
|
391 |
+
config = self.config
|
392 |
+
if config is not None:
|
393 |
+
if isinstance(config, argparse.Namespace):
|
394 |
+
config = vars(config)
|
395 |
+
(save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))
|
396 |
+
|
397 |
+
# push to the Hub if required
|
398 |
+
if push_to_hub:
|
399 |
+
kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
|
400 |
+
if config is not None: # kwarg for `push_to_hub`
|
401 |
+
kwargs["config"] = config
|
402 |
+
if repo_id is None:
|
403 |
+
repo_id = save_directory.name # Defaults to `save_directory` name
|
404 |
+
return self.push_to_hub(repo_id=repo_id, **kwargs)
|
405 |
+
return None
|
406 |
+
|
407 |
+
def to(self, device):
|
408 |
+
super().to(device)
|
409 |
+
import flair
|
410 |
+
|
411 |
+
flair.device = device
|
412 |
+
return self
|
GLiNER/modules/base.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from typing import List, Tuple, Dict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn.utils.rnn import pad_sequence
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
import random
|
9 |
+
|
10 |
+
|
11 |
+
class InstructBase(nn.Module):
|
12 |
+
def __init__(self, config):
|
13 |
+
super().__init__()
|
14 |
+
self.max_width = config.max_width
|
15 |
+
self.base_config = config
|
16 |
+
|
17 |
+
def get_dict(self, spans, classes_to_id):
|
18 |
+
dict_tag = defaultdict(int)
|
19 |
+
for span in spans:
|
20 |
+
if span[2] in classes_to_id:
|
21 |
+
dict_tag[(span[0], span[1])] = classes_to_id[span[2]]
|
22 |
+
return dict_tag
|
23 |
+
|
24 |
+
def preprocess_spans(self, tokens, ner, classes_to_id):
|
25 |
+
|
26 |
+
max_len = self.base_config.max_len
|
27 |
+
|
28 |
+
if len(tokens) > max_len:
|
29 |
+
length = max_len
|
30 |
+
tokens = tokens[:max_len]
|
31 |
+
else:
|
32 |
+
length = len(tokens)
|
33 |
+
|
34 |
+
spans_idx = []
|
35 |
+
for i in range(length):
|
36 |
+
spans_idx.extend([(i, i + j) for j in range(self.max_width)])
|
37 |
+
|
38 |
+
dict_lab = self.get_dict(ner, classes_to_id) if ner else defaultdict(int)
|
39 |
+
|
40 |
+
# 0 for null labels
|
41 |
+
span_label = torch.LongTensor([dict_lab[i] for i in spans_idx])
|
42 |
+
spans_idx = torch.LongTensor(spans_idx)
|
43 |
+
|
44 |
+
# mask for valid spans
|
45 |
+
valid_span_mask = spans_idx[:, 1] > length - 1
|
46 |
+
|
47 |
+
# mask invalid positions
|
48 |
+
span_label = span_label.masked_fill(valid_span_mask, -1)
|
49 |
+
|
50 |
+
return {
|
51 |
+
'tokens': tokens,
|
52 |
+
'span_idx': spans_idx,
|
53 |
+
'span_label': span_label,
|
54 |
+
'seq_length': length,
|
55 |
+
'entities': ner,
|
56 |
+
}
|
57 |
+
|
58 |
+
def collate_fn(self, batch_list, entity_types=None):
|
59 |
+
# batch_list: list of dict containing tokens, ner
|
60 |
+
if entity_types is None:
|
61 |
+
negs = self.get_negatives(batch_list, 100)
|
62 |
+
class_to_ids = []
|
63 |
+
id_to_classes = []
|
64 |
+
for b in batch_list:
|
65 |
+
# negs = b["negative"]
|
66 |
+
random.shuffle(negs)
|
67 |
+
|
68 |
+
# negs = negs[:sampled_neg]
|
69 |
+
max_neg_type_ratio = int(self.base_config.max_neg_type_ratio)
|
70 |
+
|
71 |
+
if max_neg_type_ratio == 0:
|
72 |
+
# no negatives
|
73 |
+
neg_type_ratio = 0
|
74 |
+
else:
|
75 |
+
neg_type_ratio = random.randint(0, max_neg_type_ratio)
|
76 |
+
|
77 |
+
if neg_type_ratio == 0:
|
78 |
+
# no negatives
|
79 |
+
negs_i = []
|
80 |
+
else:
|
81 |
+
negs_i = negs[:len(b['ner']) * neg_type_ratio]
|
82 |
+
|
83 |
+
# this is the list of all possible entity types (positive and negative)
|
84 |
+
types = list(set([el[-1] for el in b['ner']] + negs_i))
|
85 |
+
|
86 |
+
# shuffle (every epoch)
|
87 |
+
random.shuffle(types)
|
88 |
+
|
89 |
+
if len(types) != 0:
|
90 |
+
# prob of higher number shoul
|
91 |
+
# random drop
|
92 |
+
if self.base_config.random_drop:
|
93 |
+
num_ents = random.randint(1, len(types))
|
94 |
+
types = types[:num_ents]
|
95 |
+
|
96 |
+
# maximum number of entities types
|
97 |
+
types = types[:int(self.base_config.max_types)]
|
98 |
+
|
99 |
+
# supervised training
|
100 |
+
if "label" in b:
|
101 |
+
types = sorted(b["label"])
|
102 |
+
|
103 |
+
class_to_id = {k: v for v, k in enumerate(types, start=1)}
|
104 |
+
id_to_class = {k: v for v, k in class_to_id.items()}
|
105 |
+
class_to_ids.append(class_to_id)
|
106 |
+
id_to_classes.append(id_to_class)
|
107 |
+
|
108 |
+
batch = [
|
109 |
+
self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids[i]) for i, b in enumerate(batch_list)
|
110 |
+
]
|
111 |
+
|
112 |
+
else:
|
113 |
+
class_to_ids = {k: v for v, k in enumerate(entity_types, start=1)}
|
114 |
+
id_to_classes = {k: v for v, k in class_to_ids.items()}
|
115 |
+
batch = [
|
116 |
+
self.preprocess_spans(b["tokenized_text"], b["ner"], class_to_ids) for b in batch_list
|
117 |
+
]
|
118 |
+
|
119 |
+
span_idx = pad_sequence(
|
120 |
+
[b['span_idx'] for b in batch], batch_first=True, padding_value=0
|
121 |
+
)
|
122 |
+
|
123 |
+
span_label = pad_sequence(
|
124 |
+
[el['span_label'] for el in batch], batch_first=True, padding_value=-1
|
125 |
+
)
|
126 |
+
|
127 |
+
return {
|
128 |
+
'seq_length': torch.LongTensor([el['seq_length'] for el in batch]),
|
129 |
+
'span_idx': span_idx,
|
130 |
+
'tokens': [el['tokens'] for el in batch],
|
131 |
+
'span_mask': span_label != -1,
|
132 |
+
'span_label': span_label,
|
133 |
+
'entities': [el['entities'] for el in batch],
|
134 |
+
'classes_to_id': class_to_ids,
|
135 |
+
'id_to_classes': id_to_classes,
|
136 |
+
}
|
137 |
+
|
138 |
+
@staticmethod
|
139 |
+
def get_negatives(batch_list, sampled_neg=5):
|
140 |
+
ent_types = []
|
141 |
+
for b in batch_list:
|
142 |
+
types = set([el[-1] for el in b['ner']])
|
143 |
+
ent_types.extend(list(types))
|
144 |
+
ent_types = list(set(ent_types))
|
145 |
+
# sample negatives
|
146 |
+
random.shuffle(ent_types)
|
147 |
+
return ent_types[:sampled_neg]
|
148 |
+
|
149 |
+
def create_dataloader(self, data, entity_types=None, **kwargs):
|
150 |
+
return DataLoader(data, collate_fn=lambda x: self.collate_fn(x, entity_types), **kwargs)
|
GLiNER/modules/data_proc.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from tqdm import tqdm
|
3 |
+
# ast.literal_eval
|
4 |
+
import ast, re
|
5 |
+
|
6 |
+
path = 'train.json'
|
7 |
+
|
8 |
+
with open(path, 'r') as f:
|
9 |
+
data = json.load(f)
|
10 |
+
|
11 |
+
def tokenize_text(text):
|
12 |
+
return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
|
13 |
+
|
14 |
+
def extract_entity_spans(entry):
|
15 |
+
text = ""
|
16 |
+
len_start = len("What describes ")
|
17 |
+
len_end = len(" in the text?")
|
18 |
+
entity_types = []
|
19 |
+
entity_texts = []
|
20 |
+
|
21 |
+
for c in entry['conversations']:
|
22 |
+
if c['from'] == 'human' and c['value'].startswith('Text: '):
|
23 |
+
text = c['value'][len('Text: '):]
|
24 |
+
tokenized_text = tokenize_text(text)
|
25 |
+
|
26 |
+
if c['from'] == 'human' and c['value'].startswith('What describes '):
|
27 |
+
|
28 |
+
c_type = c['value'][len_start:-len_end]
|
29 |
+
c_type = c_type.replace(' ', '_')
|
30 |
+
entity_types.append(c_type)
|
31 |
+
|
32 |
+
elif c['from'] == 'gpt' and c['value'].startswith('['):
|
33 |
+
if c['value'] == '[]':
|
34 |
+
entity_types = entity_types[:-1]
|
35 |
+
continue
|
36 |
+
|
37 |
+
texts_ents = ast.literal_eval(c['value'])
|
38 |
+
# replace space to _ in texts_ents
|
39 |
+
entity_texts.extend(texts_ents)
|
40 |
+
num_repeat = len(texts_ents) - 1
|
41 |
+
entity_types.extend([entity_types[-1]] * num_repeat)
|
42 |
+
|
43 |
+
entity_spans = []
|
44 |
+
for j, entity_text in enumerate(entity_texts):
|
45 |
+
entity_tokens = tokenize_text(entity_text)
|
46 |
+
matches = []
|
47 |
+
for i in range(len(tokenized_text) - len(entity_tokens) + 1):
|
48 |
+
if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
|
49 |
+
matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
|
50 |
+
if matches:
|
51 |
+
entity_spans.extend(matches)
|
52 |
+
|
53 |
+
return entity_spans, tokenized_text
|
54 |
+
|
55 |
+
# Usage:
|
56 |
+
# Replace 'entry' with the specific entry from your JSON data
|
57 |
+
entry = data[17818] # For example, taking the first entry
|
58 |
+
entity_spans, tokenized_text = extract_entity_spans(entry)
|
59 |
+
print("Entity Spans:", entity_spans)
|
60 |
+
#print("Tokenized Text:", tokenized_text)
|
61 |
+
|
62 |
+
# create a dict: {"tokenized_text": tokenized_text, "entity_spans": entity_spans}
|
63 |
+
|
64 |
+
all_data = []
|
65 |
+
|
66 |
+
for entry in tqdm(data):
|
67 |
+
entity_spans, tokenized_text = extract_entity_spans(entry)
|
68 |
+
all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})
|
69 |
+
|
70 |
+
|
71 |
+
with open('train_instruct.json', 'w') as f:
|
72 |
+
json.dump(all_data, f)
|
73 |
+
|
GLiNER/modules/evaluator.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from seqeval.metrics.v1 import _prf_divide
|
6 |
+
|
7 |
+
|
8 |
+
def extract_tp_actual_correct(y_true, y_pred):
|
9 |
+
entities_true = defaultdict(set)
|
10 |
+
entities_pred = defaultdict(set)
|
11 |
+
|
12 |
+
for type_name, (start, end), idx in y_true:
|
13 |
+
entities_true[type_name].add((start, end, idx))
|
14 |
+
for type_name, (start, end), idx in y_pred:
|
15 |
+
entities_pred[type_name].add((start, end, idx))
|
16 |
+
|
17 |
+
target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
|
18 |
+
|
19 |
+
tp_sum = np.array([], dtype=np.int32)
|
20 |
+
pred_sum = np.array([], dtype=np.int32)
|
21 |
+
true_sum = np.array([], dtype=np.int32)
|
22 |
+
for type_name in target_names:
|
23 |
+
entities_true_type = entities_true.get(type_name, set())
|
24 |
+
entities_pred_type = entities_pred.get(type_name, set())
|
25 |
+
tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
|
26 |
+
pred_sum = np.append(pred_sum, len(entities_pred_type))
|
27 |
+
true_sum = np.append(true_sum, len(entities_true_type))
|
28 |
+
|
29 |
+
return pred_sum, tp_sum, true_sum, target_names
|
30 |
+
|
31 |
+
|
32 |
+
def flatten_for_eval(y_true, y_pred):
|
33 |
+
all_true = []
|
34 |
+
all_pred = []
|
35 |
+
|
36 |
+
for i, (true, pred) in enumerate(zip(y_true, y_pred)):
|
37 |
+
all_true.extend([t + [i] for t in true])
|
38 |
+
all_pred.extend([p + [i] for p in pred])
|
39 |
+
|
40 |
+
return all_true, all_pred
|
41 |
+
|
42 |
+
|
43 |
+
def compute_prf(y_true, y_pred, average='micro'):
|
44 |
+
y_true, y_pred = flatten_for_eval(y_true, y_pred)
|
45 |
+
|
46 |
+
pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
|
47 |
+
|
48 |
+
if average == 'micro':
|
49 |
+
tp_sum = np.array([tp_sum.sum()])
|
50 |
+
pred_sum = np.array([pred_sum.sum()])
|
51 |
+
true_sum = np.array([true_sum.sum()])
|
52 |
+
|
53 |
+
precision = _prf_divide(
|
54 |
+
numerator=tp_sum,
|
55 |
+
denominator=pred_sum,
|
56 |
+
metric='precision',
|
57 |
+
modifier='predicted',
|
58 |
+
average=average,
|
59 |
+
warn_for=('precision', 'recall', 'f-score'),
|
60 |
+
zero_division='warn'
|
61 |
+
)
|
62 |
+
|
63 |
+
recall = _prf_divide(
|
64 |
+
numerator=tp_sum,
|
65 |
+
denominator=true_sum,
|
66 |
+
metric='recall',
|
67 |
+
modifier='true',
|
68 |
+
average=average,
|
69 |
+
warn_for=('precision', 'recall', 'f-score'),
|
70 |
+
zero_division='warn'
|
71 |
+
)
|
72 |
+
|
73 |
+
denominator = precision + recall
|
74 |
+
denominator[denominator == 0.] = 1
|
75 |
+
f_score = 2 * (precision * recall) / denominator
|
76 |
+
|
77 |
+
return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
|
78 |
+
|
79 |
+
|
80 |
+
class Evaluator:
|
81 |
+
def __init__(self, all_true, all_outs):
|
82 |
+
self.all_true = all_true
|
83 |
+
self.all_outs = all_outs
|
84 |
+
|
85 |
+
def get_entities_fr(self, ents):
|
86 |
+
all_ents = []
|
87 |
+
for s, e, lab in ents:
|
88 |
+
all_ents.append([lab, (s, e)])
|
89 |
+
return all_ents
|
90 |
+
|
91 |
+
def transform_data(self):
|
92 |
+
all_true_ent = []
|
93 |
+
all_outs_ent = []
|
94 |
+
for i, j in zip(self.all_true, self.all_outs):
|
95 |
+
e = self.get_entities_fr(i)
|
96 |
+
all_true_ent.append(e)
|
97 |
+
e = self.get_entities_fr(j)
|
98 |
+
all_outs_ent.append(e)
|
99 |
+
return all_true_ent, all_outs_ent
|
100 |
+
|
101 |
+
@torch.no_grad()
|
102 |
+
def evaluate(self):
|
103 |
+
all_true_typed, all_outs_typed = self.transform_data()
|
104 |
+
precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
|
105 |
+
output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
|
106 |
+
return output_str, f1
|
107 |
+
|
108 |
+
|
109 |
+
def is_nested(idx1, idx2):
|
110 |
+
# Return True if idx2 is nested inside idx1 or vice versa
|
111 |
+
return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
|
112 |
+
|
113 |
+
|
114 |
+
def has_overlapping(idx1, idx2):
|
115 |
+
overlapping = True
|
116 |
+
if idx1[:2] == idx2[:2]:
|
117 |
+
return overlapping
|
118 |
+
if (idx1[0] > idx2[1] or idx2[0] > idx1[1]):
|
119 |
+
overlapping = False
|
120 |
+
return overlapping
|
121 |
+
|
122 |
+
|
123 |
+
def has_overlapping_nested(idx1, idx2):
|
124 |
+
# Return True if idx1 and idx2 overlap, but neither is nested inside the other
|
125 |
+
if idx1[:2] == idx2[:2]:
|
126 |
+
return True
|
127 |
+
if ((idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2)) and idx1 != idx2:
|
128 |
+
return False
|
129 |
+
else:
|
130 |
+
return True
|
131 |
+
|
132 |
+
|
133 |
+
def greedy_search(spans, flat_ner=True): # start, end, class, score
|
134 |
+
|
135 |
+
if flat_ner:
|
136 |
+
has_ov = has_overlapping
|
137 |
+
else:
|
138 |
+
has_ov = has_overlapping_nested
|
139 |
+
|
140 |
+
new_list = []
|
141 |
+
span_prob = sorted(spans, key=lambda x: -x[-1])
|
142 |
+
for i in range(len(spans)):
|
143 |
+
b = span_prob[i]
|
144 |
+
flag = False
|
145 |
+
for new in new_list:
|
146 |
+
if has_ov(b[:-1], new):
|
147 |
+
flag = True
|
148 |
+
break
|
149 |
+
if not flag:
|
150 |
+
new_list.append(b[:-1])
|
151 |
+
new_list = sorted(new_list, key=lambda x: x[0])
|
152 |
+
return new_list
|
GLiNER/modules/layers.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
5 |
+
|
6 |
+
|
7 |
+
class LstmSeq2SeqEncoder(nn.Module):
|
8 |
+
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
|
9 |
+
super(LstmSeq2SeqEncoder, self).__init__()
|
10 |
+
self.lstm = nn.LSTM(input_size=input_size,
|
11 |
+
hidden_size=hidden_size,
|
12 |
+
num_layers=num_layers,
|
13 |
+
dropout=dropout,
|
14 |
+
bidirectional=bidirectional,
|
15 |
+
batch_first=True)
|
16 |
+
|
17 |
+
def forward(self, x, mask, hidden=None):
|
18 |
+
# Packing the input sequence
|
19 |
+
lengths = mask.sum(dim=1).cpu()
|
20 |
+
packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
|
21 |
+
|
22 |
+
# Passing packed sequence through LSTM
|
23 |
+
packed_output, hidden = self.lstm(packed_x, hidden)
|
24 |
+
|
25 |
+
# Unpacking the output sequence
|
26 |
+
output, _ = pad_packed_sequence(packed_output, batch_first=True)
|
27 |
+
|
28 |
+
return output
|
GLiNER/modules/run_evaluation.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import os
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from tqdm import tqdm
|
8 |
+
import random
|
9 |
+
|
10 |
+
|
11 |
+
def open_content(path):
|
12 |
+
paths = glob.glob(os.path.join(path, "*.json"))
|
13 |
+
train, dev, test, labels = None, None, None, None
|
14 |
+
for p in paths:
|
15 |
+
if "train" in p:
|
16 |
+
with open(p, "r") as f:
|
17 |
+
train = json.load(f)
|
18 |
+
elif "dev" in p:
|
19 |
+
with open(p, "r") as f:
|
20 |
+
dev = json.load(f)
|
21 |
+
elif "test" in p:
|
22 |
+
with open(p, "r") as f:
|
23 |
+
test = json.load(f)
|
24 |
+
elif "labels" in p:
|
25 |
+
with open(p, "r") as f:
|
26 |
+
labels = json.load(f)
|
27 |
+
return train, dev, test, labels
|
28 |
+
|
29 |
+
|
30 |
+
def process(data):
|
31 |
+
words = data['sentence'].split()
|
32 |
+
entities = [] # List of entities (start, end, type)
|
33 |
+
|
34 |
+
for entity in data['entities']:
|
35 |
+
start_char, end_char = entity['pos']
|
36 |
+
|
37 |
+
# Initialize variables to keep track of word positions
|
38 |
+
start_word = None
|
39 |
+
end_word = None
|
40 |
+
|
41 |
+
# Iterate through words and find the word positions
|
42 |
+
char_count = 0
|
43 |
+
for i, word in enumerate(words):
|
44 |
+
word_length = len(word)
|
45 |
+
if char_count == start_char:
|
46 |
+
start_word = i
|
47 |
+
if char_count + word_length == end_char:
|
48 |
+
end_word = i
|
49 |
+
break
|
50 |
+
char_count += word_length + 1 # Add 1 for the space
|
51 |
+
|
52 |
+
# Append the word positions to the list
|
53 |
+
entities.append((start_word, end_word, entity['type']))
|
54 |
+
|
55 |
+
# Create a list of word positions for each entity
|
56 |
+
sample = {
|
57 |
+
"tokenized_text": words,
|
58 |
+
"ner": entities
|
59 |
+
}
|
60 |
+
|
61 |
+
return sample
|
62 |
+
|
63 |
+
|
64 |
+
# create dataset
|
65 |
+
def create_dataset(path):
|
66 |
+
train, dev, test, labels = open_content(path)
|
67 |
+
train_dataset = []
|
68 |
+
dev_dataset = []
|
69 |
+
test_dataset = []
|
70 |
+
for data in train:
|
71 |
+
train_dataset.append(process(data))
|
72 |
+
for data in dev:
|
73 |
+
dev_dataset.append(process(data))
|
74 |
+
for data in test:
|
75 |
+
test_dataset.append(process(data))
|
76 |
+
return train_dataset, dev_dataset, test_dataset, labels
|
77 |
+
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
def get_for_one_path(path, model):
|
81 |
+
# load the dataset
|
82 |
+
_, _, test_dataset, entity_types = create_dataset(path)
|
83 |
+
|
84 |
+
data_name = path.split("/")[-1] # get the name of the dataset
|
85 |
+
|
86 |
+
# check if the dataset is flat_ner
|
87 |
+
flat_ner = True
|
88 |
+
if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
|
89 |
+
flat_ner = False
|
90 |
+
|
91 |
+
# evaluate the model
|
92 |
+
results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
|
93 |
+
entity_types=entity_types)
|
94 |
+
return data_name, results, f1
|
95 |
+
|
96 |
+
|
97 |
+
def get_for_all_path(model, steps, log_dir, data_paths):
|
98 |
+
all_paths = glob.glob(f"{data_paths}/*")
|
99 |
+
|
100 |
+
all_paths = sorted(all_paths)
|
101 |
+
|
102 |
+
# move the model to the device
|
103 |
+
device = next(model.parameters()).device
|
104 |
+
model.to(device)
|
105 |
+
# set the model to eval mode
|
106 |
+
model.eval()
|
107 |
+
|
108 |
+
# log the results
|
109 |
+
save_path = os.path.join(log_dir, "results.txt")
|
110 |
+
|
111 |
+
with open(save_path, "a") as f:
|
112 |
+
f.write("##############################################\n")
|
113 |
+
# write step
|
114 |
+
f.write("step: " + str(steps) + "\n")
|
115 |
+
|
116 |
+
zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
|
117 |
+
"CrossNER_politics", "CrossNER_science"]
|
118 |
+
|
119 |
+
zero_shot_benc_results = {}
|
120 |
+
all_results = {} # without crossNER
|
121 |
+
|
122 |
+
for p in tqdm(all_paths):
|
123 |
+
if "sample_" not in p:
|
124 |
+
data_name, results, f1 = get_for_one_path(p, model)
|
125 |
+
# write to file
|
126 |
+
with open(save_path, "a") as f:
|
127 |
+
f.write(data_name + "\n")
|
128 |
+
f.write(str(results) + "\n")
|
129 |
+
|
130 |
+
if data_name in zero_shot_benc:
|
131 |
+
zero_shot_benc_results[data_name] = f1
|
132 |
+
else:
|
133 |
+
all_results[data_name] = f1
|
134 |
+
|
135 |
+
avg_all = sum(all_results.values()) / len(all_results)
|
136 |
+
avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
|
137 |
+
|
138 |
+
save_path_table = os.path.join(log_dir, "tables.txt")
|
139 |
+
|
140 |
+
# results for all datasets except crossNER
|
141 |
+
table_bench_all = ""
|
142 |
+
for k, v in all_results.items():
|
143 |
+
table_bench_all += f"{k:20}: {v:.1%}\n"
|
144 |
+
# (20 size aswell for average i.e. :20)
|
145 |
+
table_bench_all += f"{'Average':20}: {avg_all:.1%}"
|
146 |
+
|
147 |
+
# results for zero-shot benchmark
|
148 |
+
table_bench_zeroshot = ""
|
149 |
+
for k, v in zero_shot_benc_results.items():
|
150 |
+
table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
|
151 |
+
table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
|
152 |
+
|
153 |
+
# write to file
|
154 |
+
with open(save_path_table, "a") as f:
|
155 |
+
f.write("##############################################\n")
|
156 |
+
f.write("step: " + str(steps) + "\n")
|
157 |
+
f.write("Table for all datasets except crossNER\n")
|
158 |
+
f.write(table_bench_all + "\n\n")
|
159 |
+
f.write("Table for zero-shot benchmark\n")
|
160 |
+
f.write(table_bench_zeroshot + "\n")
|
161 |
+
f.write("##############################################\n\n")
|
162 |
+
|
163 |
+
|
164 |
+
def sample_train_data(data_paths, sample_size=10000):
|
165 |
+
all_paths = glob.glob(f"{data_paths}/*")
|
166 |
+
|
167 |
+
all_paths = sorted(all_paths)
|
168 |
+
|
169 |
+
# to exclude the zero-shot benchmark datasets
|
170 |
+
zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
|
171 |
+
"CrossNER_politics", "CrossNER_science", "ACE 2004"]
|
172 |
+
|
173 |
+
new_train = []
|
174 |
+
# take 10k samples from each dataset
|
175 |
+
for p in tqdm(all_paths):
|
176 |
+
if any([i in p for i in zero_shot_benc]):
|
177 |
+
continue
|
178 |
+
train, dev, test, labels = create_dataset(p)
|
179 |
+
|
180 |
+
# add label key to the train data
|
181 |
+
for i in range(len(train)):
|
182 |
+
train[i]["label"] = labels
|
183 |
+
|
184 |
+
random.shuffle(train)
|
185 |
+
train = train[:sample_size]
|
186 |
+
new_train.extend(train)
|
187 |
+
|
188 |
+
return new_train
|
GLiNER/modules/span_rep.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
|
6 |
+
"""
|
7 |
+
Creates a projection layer with specified configurations.
|
8 |
+
"""
|
9 |
+
if out_dim is None:
|
10 |
+
out_dim = hidden_size
|
11 |
+
|
12 |
+
return nn.Sequential(
|
13 |
+
nn.Linear(hidden_size, out_dim * 4),
|
14 |
+
nn.ReLU(),
|
15 |
+
nn.Dropout(dropout),
|
16 |
+
nn.Linear(out_dim * 4, out_dim)
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class SpanQuery(nn.Module):
|
21 |
+
|
22 |
+
def __init__(self, hidden_size, max_width, trainable=True):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
|
26 |
+
|
27 |
+
nn.init.uniform_(self.query_seg, a=-1, b=1)
|
28 |
+
|
29 |
+
if not trainable:
|
30 |
+
self.query_seg.requires_grad = False
|
31 |
+
|
32 |
+
self.project = nn.Sequential(
|
33 |
+
nn.Linear(hidden_size, hidden_size),
|
34 |
+
nn.ReLU()
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, h, *args):
|
38 |
+
# h of shape [B, L, D]
|
39 |
+
# query_seg of shape [D, max_width]
|
40 |
+
|
41 |
+
span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
|
42 |
+
|
43 |
+
return self.project(span_rep)
|
44 |
+
|
45 |
+
|
46 |
+
class SpanMLP(nn.Module):
|
47 |
+
|
48 |
+
def __init__(self, hidden_size, max_width):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
|
52 |
+
|
53 |
+
def forward(self, h, *args):
|
54 |
+
# h of shape [B, L, D]
|
55 |
+
# query_seg of shape [D, max_width]
|
56 |
+
|
57 |
+
B, L, D = h.size()
|
58 |
+
|
59 |
+
span_rep = self.mlp(h)
|
60 |
+
|
61 |
+
span_rep = span_rep.view(B, L, -1, D)
|
62 |
+
|
63 |
+
return span_rep.relu()
|
64 |
+
|
65 |
+
|
66 |
+
class SpanCAT(nn.Module):
|
67 |
+
|
68 |
+
def __init__(self, hidden_size, max_width):
|
69 |
+
super().__init__()
|
70 |
+
|
71 |
+
self.max_width = max_width
|
72 |
+
|
73 |
+
self.query_seg = nn.Parameter(torch.randn(128, max_width))
|
74 |
+
|
75 |
+
self.project = nn.Sequential(
|
76 |
+
nn.Linear(hidden_size + 128, hidden_size),
|
77 |
+
nn.ReLU()
|
78 |
+
)
|
79 |
+
|
80 |
+
def forward(self, h, *args):
|
81 |
+
# h of shape [B, L, D]
|
82 |
+
# query_seg of shape [D, max_width]
|
83 |
+
|
84 |
+
B, L, D = h.size()
|
85 |
+
|
86 |
+
h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
|
87 |
+
|
88 |
+
q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
|
89 |
+
|
90 |
+
span_rep = torch.cat([h, q], dim=-1)
|
91 |
+
|
92 |
+
span_rep = self.project(span_rep)
|
93 |
+
|
94 |
+
return span_rep
|
95 |
+
|
96 |
+
|
97 |
+
class SpanConvBlock(nn.Module):
|
98 |
+
def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
if span_mode == 'conv_conv':
|
102 |
+
self.conv = nn.Conv1d(hidden_size, hidden_size,
|
103 |
+
kernel_size=kernel_size)
|
104 |
+
|
105 |
+
# initialize the weights
|
106 |
+
nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
|
107 |
+
|
108 |
+
elif span_mode == 'conv_max':
|
109 |
+
self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
|
110 |
+
elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
|
111 |
+
self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
|
112 |
+
|
113 |
+
self.span_mode = span_mode
|
114 |
+
|
115 |
+
self.pad = kernel_size - 1
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
|
119 |
+
x = torch.einsum('bld->bdl', x)
|
120 |
+
|
121 |
+
if self.pad > 0:
|
122 |
+
x = F.pad(x, (0, self.pad), "constant", 0)
|
123 |
+
|
124 |
+
x = self.conv(x)
|
125 |
+
|
126 |
+
if self.span_mode == "conv_sum":
|
127 |
+
x = x * (self.pad + 1)
|
128 |
+
|
129 |
+
return torch.einsum('bdl->bld', x)
|
130 |
+
|
131 |
+
|
132 |
+
class SpanConv(nn.Module):
|
133 |
+
def __init__(self, hidden_size, max_width, span_mode):
|
134 |
+
super().__init__()
|
135 |
+
|
136 |
+
kernels = [i + 2 for i in range(max_width - 1)]
|
137 |
+
|
138 |
+
self.convs = nn.ModuleList()
|
139 |
+
|
140 |
+
for kernel in kernels:
|
141 |
+
self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
|
142 |
+
|
143 |
+
self.project = nn.Sequential(
|
144 |
+
nn.ReLU(),
|
145 |
+
nn.Linear(hidden_size, hidden_size)
|
146 |
+
)
|
147 |
+
|
148 |
+
def forward(self, x, *args):
|
149 |
+
|
150 |
+
span_reps = [x]
|
151 |
+
|
152 |
+
for conv in self.convs:
|
153 |
+
h = conv(x)
|
154 |
+
span_reps.append(h)
|
155 |
+
|
156 |
+
span_reps = torch.stack(span_reps, dim=-2)
|
157 |
+
|
158 |
+
return self.project(span_reps)
|
159 |
+
|
160 |
+
|
161 |
+
class SpanEndpointsBlock(nn.Module):
|
162 |
+
def __init__(self, kernel_size):
|
163 |
+
super().__init__()
|
164 |
+
|
165 |
+
self.kernel_size = kernel_size
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
B, L, D = x.size()
|
169 |
+
|
170 |
+
span_idx = torch.LongTensor(
|
171 |
+
[[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
|
172 |
+
|
173 |
+
x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
|
174 |
+
|
175 |
+
# endrep
|
176 |
+
start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
|
177 |
+
|
178 |
+
start_end_rep = start_end_rep.view(B, L, 2, D)
|
179 |
+
|
180 |
+
return start_end_rep
|
181 |
+
|
182 |
+
|
183 |
+
class ConvShare(nn.Module):
|
184 |
+
def __init__(self, hidden_size, max_width):
|
185 |
+
super().__init__()
|
186 |
+
|
187 |
+
self.max_width = max_width
|
188 |
+
|
189 |
+
self.conv_weigth = nn.Parameter(
|
190 |
+
torch.randn(hidden_size, hidden_size, max_width))
|
191 |
+
|
192 |
+
nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
|
193 |
+
|
194 |
+
self.project = nn.Sequential(
|
195 |
+
nn.ReLU(),
|
196 |
+
nn.Linear(hidden_size, hidden_size)
|
197 |
+
)
|
198 |
+
|
199 |
+
def forward(self, x, *args):
|
200 |
+
span_reps = []
|
201 |
+
|
202 |
+
x = torch.einsum('bld->bdl', x)
|
203 |
+
|
204 |
+
for i in range(self.max_width):
|
205 |
+
pad = i
|
206 |
+
x_i = F.pad(x, (0, pad), "constant", 0)
|
207 |
+
conv_w = self.conv_weigth[:, :, :i + 1]
|
208 |
+
out_i = F.conv1d(x_i, conv_w)
|
209 |
+
span_reps.append(out_i.transpose(-1, -2))
|
210 |
+
|
211 |
+
out = torch.stack(span_reps, dim=-2)
|
212 |
+
|
213 |
+
return self.project(out)
|
214 |
+
|
215 |
+
|
216 |
+
def extract_elements(sequence, indices):
|
217 |
+
B, L, D = sequence.shape
|
218 |
+
K = indices.shape[1]
|
219 |
+
|
220 |
+
# Expand indices to [B, K, D]
|
221 |
+
expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
|
222 |
+
|
223 |
+
# Gather the elements
|
224 |
+
extracted_elements = torch.gather(sequence, 1, expanded_indices)
|
225 |
+
|
226 |
+
return extracted_elements
|
227 |
+
|
228 |
+
|
229 |
+
class SpanMarker(nn.Module):
|
230 |
+
|
231 |
+
def __init__(self, hidden_size, max_width, dropout=0.4):
|
232 |
+
super().__init__()
|
233 |
+
|
234 |
+
self.max_width = max_width
|
235 |
+
|
236 |
+
self.project_start = nn.Sequential(
|
237 |
+
nn.Linear(hidden_size, hidden_size * 2, bias=True),
|
238 |
+
nn.ReLU(),
|
239 |
+
nn.Dropout(dropout),
|
240 |
+
nn.Linear(hidden_size * 2, hidden_size, bias=True),
|
241 |
+
)
|
242 |
+
|
243 |
+
self.project_end = nn.Sequential(
|
244 |
+
nn.Linear(hidden_size, hidden_size * 2, bias=True),
|
245 |
+
nn.ReLU(),
|
246 |
+
nn.Dropout(dropout),
|
247 |
+
nn.Linear(hidden_size * 2, hidden_size, bias=True),
|
248 |
+
)
|
249 |
+
|
250 |
+
self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
|
251 |
+
|
252 |
+
def forward(self, h, span_idx):
|
253 |
+
# h of shape [B, L, D]
|
254 |
+
# query_seg of shape [D, max_width]
|
255 |
+
|
256 |
+
B, L, D = h.size()
|
257 |
+
|
258 |
+
# project start and end
|
259 |
+
start_rep = self.project_start(h)
|
260 |
+
end_rep = self.project_end(h)
|
261 |
+
|
262 |
+
start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
|
263 |
+
end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
|
264 |
+
|
265 |
+
# concat start and end
|
266 |
+
cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
|
267 |
+
|
268 |
+
# project
|
269 |
+
cat = self.out_project(cat)
|
270 |
+
|
271 |
+
# reshape
|
272 |
+
return cat.view(B, L, self.max_width, D)
|
273 |
+
|
274 |
+
|
275 |
+
class SpanMarkerV0(nn.Module):
|
276 |
+
"""
|
277 |
+
Marks and projects span endpoints using an MLP.
|
278 |
+
"""
|
279 |
+
|
280 |
+
def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
|
281 |
+
super().__init__()
|
282 |
+
self.max_width = max_width
|
283 |
+
self.project_start = create_projection_layer(hidden_size, dropout)
|
284 |
+
self.project_end = create_projection_layer(hidden_size, dropout)
|
285 |
+
|
286 |
+
self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
|
287 |
+
|
288 |
+
def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
|
289 |
+
B, L, D = h.size()
|
290 |
+
|
291 |
+
start_rep = self.project_start(h)
|
292 |
+
end_rep = self.project_end(h)
|
293 |
+
|
294 |
+
start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
|
295 |
+
end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
|
296 |
+
|
297 |
+
cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
|
298 |
+
|
299 |
+
return self.out_project(cat).view(B, L, self.max_width, D)
|
300 |
+
|
301 |
+
|
302 |
+
class ConvShareV2(nn.Module):
|
303 |
+
def __init__(self, hidden_size, max_width):
|
304 |
+
super().__init__()
|
305 |
+
|
306 |
+
self.max_width = max_width
|
307 |
+
|
308 |
+
self.conv_weigth = nn.Parameter(
|
309 |
+
torch.randn(hidden_size, hidden_size, max_width)
|
310 |
+
)
|
311 |
+
|
312 |
+
nn.init.xavier_normal_(self.conv_weigth)
|
313 |
+
|
314 |
+
def forward(self, x, *args):
|
315 |
+
span_reps = []
|
316 |
+
|
317 |
+
x = torch.einsum('bld->bdl', x)
|
318 |
+
|
319 |
+
for i in range(self.max_width):
|
320 |
+
pad = i
|
321 |
+
x_i = F.pad(x, (0, pad), "constant", 0)
|
322 |
+
conv_w = self.conv_weigth[:, :, :i + 1]
|
323 |
+
out_i = F.conv1d(x_i, conv_w)
|
324 |
+
span_reps.append(out_i.transpose(-1, -2))
|
325 |
+
|
326 |
+
out = torch.stack(span_reps, dim=-2)
|
327 |
+
|
328 |
+
return out
|
329 |
+
|
330 |
+
|
331 |
+
class SpanRepLayer(nn.Module):
|
332 |
+
"""
|
333 |
+
Various span representation approaches
|
334 |
+
"""
|
335 |
+
|
336 |
+
def __init__(self, hidden_size, max_width, span_mode, **kwargs):
|
337 |
+
super().__init__()
|
338 |
+
|
339 |
+
if span_mode == 'marker':
|
340 |
+
self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
|
341 |
+
elif span_mode == 'markerV0':
|
342 |
+
self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
|
343 |
+
elif span_mode == 'query':
|
344 |
+
self.span_rep_layer = SpanQuery(
|
345 |
+
hidden_size, max_width, trainable=True)
|
346 |
+
elif span_mode == 'mlp':
|
347 |
+
self.span_rep_layer = SpanMLP(hidden_size, max_width)
|
348 |
+
elif span_mode == 'cat':
|
349 |
+
self.span_rep_layer = SpanCAT(hidden_size, max_width)
|
350 |
+
elif span_mode == 'conv_conv':
|
351 |
+
self.span_rep_layer = SpanConv(
|
352 |
+
hidden_size, max_width, span_mode='conv_conv')
|
353 |
+
elif span_mode == 'conv_max':
|
354 |
+
self.span_rep_layer = SpanConv(
|
355 |
+
hidden_size, max_width, span_mode='conv_max')
|
356 |
+
elif span_mode == 'conv_mean':
|
357 |
+
self.span_rep_layer = SpanConv(
|
358 |
+
hidden_size, max_width, span_mode='conv_mean')
|
359 |
+
elif span_mode == 'conv_sum':
|
360 |
+
self.span_rep_layer = SpanConv(
|
361 |
+
hidden_size, max_width, span_mode='conv_sum')
|
362 |
+
elif span_mode == 'conv_share':
|
363 |
+
self.span_rep_layer = ConvShare(hidden_size, max_width)
|
364 |
+
else:
|
365 |
+
raise ValueError(f'Unknown span mode {span_mode}')
|
366 |
+
|
367 |
+
def forward(self, x, *args):
|
368 |
+
|
369 |
+
return self.span_rep_layer(x, *args)
|
GLiNER/modules/token_rep.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from flair.data import Sentence
|
5 |
+
from flair.embeddings import TransformerWordEmbeddings
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn.utils.rnn import pad_sequence
|
8 |
+
|
9 |
+
|
10 |
+
# flair.cache_root = '/gpfswork/rech/pds/upa43yu/.cache'
|
11 |
+
|
12 |
+
|
13 |
+
class TokenRepLayer(nn.Module):
|
14 |
+
def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
|
15 |
+
hidden_size: int = 768,
|
16 |
+
add_tokens=["[SEP]", "[ENT]"]
|
17 |
+
):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
self.bert_layer = TransformerWordEmbeddings(
|
21 |
+
model_name,
|
22 |
+
fine_tune=fine_tune,
|
23 |
+
subtoken_pooling=subtoken_pooling,
|
24 |
+
allow_long_sentences=True
|
25 |
+
)
|
26 |
+
|
27 |
+
# add tokens to vocabulary
|
28 |
+
self.bert_layer.tokenizer.add_tokens(add_tokens)
|
29 |
+
|
30 |
+
# resize token embeddings
|
31 |
+
self.bert_layer.model.resize_token_embeddings(len(self.bert_layer.tokenizer))
|
32 |
+
|
33 |
+
bert_hidden_size = self.bert_layer.embedding_length
|
34 |
+
|
35 |
+
if hidden_size != bert_hidden_size:
|
36 |
+
self.projection = nn.Linear(bert_hidden_size, hidden_size)
|
37 |
+
|
38 |
+
def forward(self, tokens: List[List[str]], lengths: torch.Tensor):
|
39 |
+
token_embeddings = self.compute_word_embedding(tokens)
|
40 |
+
|
41 |
+
if hasattr(self, "projection"):
|
42 |
+
token_embeddings = self.projection(token_embeddings)
|
43 |
+
|
44 |
+
B = len(lengths)
|
45 |
+
max_length = lengths.max()
|
46 |
+
mask = (torch.arange(max_length).view(1, -1).repeat(B, 1) < lengths.cpu().unsqueeze(1)).to(
|
47 |
+
token_embeddings.device).long()
|
48 |
+
return {"embeddings": token_embeddings, "mask": mask}
|
49 |
+
|
50 |
+
def compute_word_embedding(self, tokens):
|
51 |
+
sentences = [Sentence(i) for i in tokens]
|
52 |
+
self.bert_layer.embed(sentences)
|
53 |
+
token_embeddings = pad_sequence([torch.stack([t.embedding for t in k]) for k in sentences], batch_first=True)
|
54 |
+
return token_embeddings
|
GLiNER/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
huggingface_hub
|
4 |
+
flair
|
5 |
+
seqeval
|
6 |
+
tqdm
|
GLiNER/save_load.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from model import GLiNER
|
3 |
+
|
4 |
+
|
5 |
+
def save_model(current_model, path):
|
6 |
+
config = current_model.config
|
7 |
+
dict_save = {"model_weights": current_model.state_dict(), "config": config}
|
8 |
+
torch.save(dict_save, path)
|
9 |
+
|
10 |
+
|
11 |
+
def load_model(path, model_name=None, device=None):
|
12 |
+
dict_load = torch.load(path, map_location=torch.device('cpu'))
|
13 |
+
config = dict_load["config"]
|
14 |
+
|
15 |
+
if model_name is not None:
|
16 |
+
config.model_name = model_name
|
17 |
+
|
18 |
+
loaded_model = GLiNER(config)
|
19 |
+
loaded_model.load_state_dict(dict_load["model_weights"])
|
20 |
+
return loaded_model.to(device) if device is not None else loaded_model
|
GLiNER/train.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import yaml
|
6 |
+
from tqdm import tqdm
|
7 |
+
from transformers import get_cosine_schedule_with_warmup
|
8 |
+
|
9 |
+
# from model_nested import NerFilteredSemiCRF
|
10 |
+
from model import GLiNER
|
11 |
+
from modules.run_evaluation import get_for_all_path, sample_train_data
|
12 |
+
from save_load import save_model, load_model
|
13 |
+
import json
|
14 |
+
|
15 |
+
|
16 |
+
# train function
|
17 |
+
def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
|
18 |
+
train_batch_size=8, device='cuda'):
|
19 |
+
model.train()
|
20 |
+
|
21 |
+
# initialize data loaders
|
22 |
+
train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
|
23 |
+
|
24 |
+
pbar = tqdm(range(num_steps))
|
25 |
+
|
26 |
+
if warmup_ratio < 1:
|
27 |
+
num_warmup_steps = int(num_steps * warmup_ratio)
|
28 |
+
else:
|
29 |
+
num_warmup_steps = int(warmup_ratio)
|
30 |
+
|
31 |
+
scheduler = get_cosine_schedule_with_warmup(
|
32 |
+
optimizer,
|
33 |
+
num_warmup_steps=num_warmup_steps,
|
34 |
+
num_training_steps=num_steps
|
35 |
+
)
|
36 |
+
|
37 |
+
iter_train_loader = iter(train_loader)
|
38 |
+
|
39 |
+
for step in pbar:
|
40 |
+
try:
|
41 |
+
x = next(iter_train_loader)
|
42 |
+
except StopIteration:
|
43 |
+
iter_train_loader = iter(train_loader)
|
44 |
+
x = next(iter_train_loader)
|
45 |
+
|
46 |
+
for k, v in x.items():
|
47 |
+
if isinstance(v, torch.Tensor):
|
48 |
+
x[k] = v.to(device)
|
49 |
+
|
50 |
+
try:
|
51 |
+
loss = model(x) # Forward pass
|
52 |
+
except:
|
53 |
+
continue
|
54 |
+
|
55 |
+
# check if loss is nan
|
56 |
+
if torch.isnan(loss):
|
57 |
+
continue
|
58 |
+
|
59 |
+
loss.backward() # Compute gradients
|
60 |
+
optimizer.step() # Update parameters
|
61 |
+
scheduler.step() # Update learning rate schedule
|
62 |
+
optimizer.zero_grad() # Reset gradients
|
63 |
+
|
64 |
+
description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
|
65 |
+
|
66 |
+
if (step + 1) % eval_every == 0:
|
67 |
+
current_path = os.path.join(log_dir, f'model_{step + 1}')
|
68 |
+
save_model(model, current_path)
|
69 |
+
#val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
|
70 |
+
#get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
|
71 |
+
|
72 |
+
model.train()
|
73 |
+
|
74 |
+
pbar.set_description(description)
|
75 |
+
|
76 |
+
|
77 |
+
def create_parser():
|
78 |
+
parser = argparse.ArgumentParser(description="Span-based NER")
|
79 |
+
parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
|
80 |
+
parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
|
81 |
+
return parser
|
82 |
+
|
83 |
+
|
84 |
+
def load_config_as_namespace(config_file):
|
85 |
+
with open(config_file, 'r') as f:
|
86 |
+
config_dict = yaml.safe_load(f)
|
87 |
+
return argparse.Namespace(**config_dict)
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
# parse args
|
92 |
+
parser = create_parser()
|
93 |
+
args = parser.parse_args()
|
94 |
+
|
95 |
+
# load config
|
96 |
+
config = load_config_as_namespace(args.config)
|
97 |
+
|
98 |
+
config.log_dir = args.log_dir
|
99 |
+
|
100 |
+
try:
|
101 |
+
with open(config.train_data, 'r') as f:
|
102 |
+
data = json.load(f)
|
103 |
+
except:
|
104 |
+
data = sample_train_data(config.train_data, 10000)
|
105 |
+
|
106 |
+
if config.prev_path != "none":
|
107 |
+
model = load_model(config.prev_path)
|
108 |
+
model.config = config
|
109 |
+
else:
|
110 |
+
model = GLiNER(config)
|
111 |
+
|
112 |
+
if torch.cuda.is_available():
|
113 |
+
model = model.cuda()
|
114 |
+
|
115 |
+
lr_encoder = float(config.lr_encoder)
|
116 |
+
lr_others = float(config.lr_others)
|
117 |
+
|
118 |
+
optimizer = torch.optim.AdamW([
|
119 |
+
# encoder
|
120 |
+
{'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
|
121 |
+
{'params': model.rnn.parameters(), 'lr': lr_others},
|
122 |
+
# projection layers
|
123 |
+
{'params': model.span_rep_layer.parameters(), 'lr': lr_others},
|
124 |
+
{'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
|
125 |
+
])
|
126 |
+
|
127 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
128 |
+
|
129 |
+
train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
|
130 |
+
log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
|
131 |
+
device=device)
|