WebashalarForML commited on
Commit
9ae46f4
·
verified ·
1 Parent(s): ddc0102

Upload 5 files

Browse files
backup/backup.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .model import GLiNER
2
+
3
+ # Initialize GLiNER with the base model
4
+ model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1")
5
+
6
+ # Sample text for entity prediction
7
+ text = """
8
+ lenskart m: (0)9428002330 Lenskart Store,Surat m: (0)9723817060) e:[email protected] Store Address UG-4.Ascon City.Opp.Maheshwari Bhavan,Citylight,Surat-395007"""
9
+
10
+ # Labels for entity prediction
11
+ # # Most GLiNER models should work best when entity types are in lower case or title case
12
+ # labels = ["Person", "Mail", "Number", "Address", "Organization","Designation"]
13
+
14
+ # # Perform entity prediction
15
+ # entities = model.predict_entities(text, labels, threshold=0.5)
16
+
17
+
18
+ def NER_Model(text):
19
+
20
+ labels = ["Person", "Mail", "Number", "Address", "Organization","Designation","Link"]
21
+
22
+ # Perform entity prediction
23
+ entities = model.predict_entities(text, labels, threshold=0.5)
24
+
25
+ # Initialize the processed data dictionary
26
+ processed_data = {
27
+ "Name": [],
28
+ "Contact": [],
29
+ "Designation": [],
30
+ "Address": [],
31
+ "Link": [],
32
+ "Company": [],
33
+ "Email": [],
34
+ "extracted_text": "",
35
+ }
36
+
37
+ for entity in entities:
38
+
39
+ print(entity["text"], "=>", entity["label"])
40
+
41
+ #loading the data into json
42
+ if entity["label"]==labels[0]:
43
+ processed_data['Name'].extend([entity["text"]])
44
+
45
+ if entity["label"]==labels[1]:
46
+ processed_data['Email'].extend([entity["text"]])
47
+
48
+ if entity["label"]==labels[2]:
49
+ processed_data['Contact'].extend([entity["text"]])
50
+
51
+ if entity["label"]==labels[3]:
52
+ processed_data['Address'].extend([entity["text"]])
53
+
54
+ if entity["label"]==labels[4]:
55
+ processed_data['Company'].extend([entity["text"]])
56
+
57
+ if entity["label"]==labels[5]:
58
+ processed_data['Designation'].extend([entity["text"]])
59
+
60
+ if entity["label"]==labels[6]:
61
+ processed_data['Link'].extend([entity["text"]])
62
+
63
+
64
+ processed_data['Address']=[', '.join(processed_data['Address'])]
65
+ processed_data['extracted_text']=[text]
66
+
67
+ return processed_data
68
+
69
+ # result=NER_Model(text)
70
+ # print(result)
71
+
72
+
73
+
74
+
75
+
backup/model.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import re
5
+ from typing import Dict, Optional, Union
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from .modules.layers import LstmSeq2SeqEncoder
9
+ from .modules.base import InstructBase
10
+ from .modules.evaluator import Evaluator, greedy_search
11
+ from .modules.span_rep import SpanRepLayer
12
+ from .modules.token_rep import TokenRepLayer
13
+ from torch import nn
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
16
+ from huggingface_hub.utils import HfHubHTTPError
17
+
18
+
19
+
20
+ class GLiNER(InstructBase, PyTorchModelHubMixin):
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+
24
+ self.config = config
25
+
26
+ # [ENT] token
27
+ self.entity_token = "<<ENT>>"
28
+ self.sep_token = "<<SEP>>"
29
+
30
+ # usually a pretrained bidirectional transformer, returns first subtoken representation
31
+ self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
32
+ subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
33
+ add_tokens=[self.entity_token, self.sep_token])
34
+
35
+ # hierarchical representation of tokens
36
+ self.rnn = LstmSeq2SeqEncoder(
37
+ input_size=config.hidden_size,
38
+ hidden_size=config.hidden_size // 2,
39
+ num_layers=1,
40
+ bidirectional=True,
41
+ )
42
+
43
+ # span representation
44
+ self.span_rep_layer = SpanRepLayer(
45
+ span_mode=config.span_mode,
46
+ hidden_size=config.hidden_size,
47
+ max_width=config.max_width,
48
+ dropout=config.dropout,
49
+ )
50
+
51
+ # prompt representation (FFN)
52
+ self.prompt_rep_layer = nn.Sequential(
53
+ nn.Linear(config.hidden_size, config.hidden_size * 4),
54
+ nn.Dropout(config.dropout),
55
+ nn.ReLU(),
56
+ nn.Linear(config.hidden_size * 4, config.hidden_size)
57
+ )
58
+
59
+ def compute_score_train(self, x):
60
+ span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)
61
+
62
+ new_length = x['seq_length'].clone()
63
+ new_tokens = []
64
+ all_len_prompt = []
65
+ num_classes_all = []
66
+
67
+ # add prompt to the tokens
68
+ for i in range(len(x['tokens'])):
69
+ all_types_i = list(x['classes_to_id'][i].keys())
70
+ # multiple entity types in all_types. Prompt is appended at the start of tokens
71
+ entity_prompt = []
72
+ num_classes_all.append(len(all_types_i))
73
+ # add enity types to prompt
74
+ for entity_type in all_types_i:
75
+ entity_prompt.append(self.entity_token) # [ENT] token
76
+ entity_prompt.append(entity_type) # entity type
77
+ entity_prompt.append(self.sep_token) # [SEP] token
78
+
79
+ # prompt format:
80
+ # [ENT] entity_type [ENT] entity_type ... [ENT] entity_type [SEP]
81
+
82
+ # add prompt to the tokens
83
+ tokens_p = entity_prompt + x['tokens'][i]
84
+
85
+ # input format:
86
+ # [ENT] entity_type_1 [ENT] entity_type_2 ... [ENT] entity_type_m [SEP] token_1 token_2 ... token_n
87
+
88
+ # update length of the sequence (add prompt length to the original length)
89
+ new_length[i] = new_length[i] + len(entity_prompt)
90
+ # update tokens
91
+ new_tokens.append(tokens_p)
92
+ # store prompt length
93
+ all_len_prompt.append(len(entity_prompt))
94
+
95
+ # create a mask using num_classes_all (0, if it exceeds the number of classes, 1 otherwise)
96
+ max_num_classes = max(num_classes_all)
97
+ entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(len(num_classes_all), -1).to(
98
+ x['span_mask'].device)
99
+ entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
100
+ x['span_mask'].device) # [batch_size, max_num_classes]
101
+
102
+ # compute all token representations
103
+ bert_output = self.token_rep_layer(new_tokens, new_length)
104
+ word_rep_w_prompt = bert_output["embeddings"] # embeddings for all tokens (with prompt)
105
+ mask_w_prompt = bert_output["mask"] # mask for all tokens (with prompt)
106
+
107
+ # get word representation (after [SEP]), mask (after [SEP]) and entity type representation (before [SEP])
108
+ word_rep = [] # word representation (after [SEP])
109
+ mask = [] # mask (after [SEP])
110
+ entity_type_rep = [] # entity type representation (before [SEP])
111
+ for i in range(len(x['tokens'])):
112
+ prompt_entity_length = all_len_prompt[i] # length of prompt for this example
113
+ # get word representation (after [SEP])
114
+ word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
115
+ # get mask (after [SEP])
116
+ mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
117
+
118
+ # get entity type representation (before [SEP])
119
+ entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1] # remove [SEP]
120
+ entity_rep = entity_rep[0::2] # it means that we take every second element starting from the second one
121
+ entity_type_rep.append(entity_rep)
122
+
123
+ # padding for word_rep, mask and entity_type_rep
124
+ word_rep = pad_sequence(word_rep, batch_first=True) # [batch_size, seq_len, hidden_size]
125
+ mask = pad_sequence(mask, batch_first=True) # [batch_size, seq_len]
126
+ entity_type_rep = pad_sequence(entity_type_rep, batch_first=True) # [batch_size, len_types, hidden_size]
127
+
128
+ # compute span representation
129
+ word_rep = self.rnn(word_rep, mask)
130
+ span_rep = self.span_rep_layer(word_rep, span_idx)
131
+
132
+ # compute final entity type representation (FFN)
133
+ entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
134
+ num_classes = entity_type_rep.shape[1] # number of entity types
135
+
136
+ # similarity score
137
+ scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
138
+
139
+ return scores, num_classes, entity_type_mask
140
+
141
+ def forward(self, x):
142
+ # compute span representation
143
+ scores, num_classes, entity_type_mask = self.compute_score_train(x)
144
+ batch_size = scores.shape[0]
145
+
146
+ # loss for filtering classifier
147
+ logits_label = scores.view(-1, num_classes)
148
+ labels = x["span_label"].view(-1) # (batch_size * num_spans)
149
+ mask_label = labels != -1 # (batch_size * num_spans)
150
+ labels.masked_fill_(~mask_label, 0) # Set the labels of padding tokens to 0
151
+
152
+ # one-hot encoding
153
+ labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
154
+ labels_one_hot.scatter_(1, labels.unsqueeze(1), 1) # Set the corresponding index to 1
155
+ labels_one_hot = labels_one_hot[:, 1:] # Remove the first column
156
+ # Shape of labels_one_hot: (batch_size * num_spans, num_classes)
157
+
158
+ # compute loss (without reduction)
159
+ all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
160
+ reduction='none')
161
+ # mask loss using entity_type_mask (B, C)
162
+ masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
163
+ all_losses = masked_loss.view(-1, num_classes)
164
+ # expand mask_label to all_losses
165
+ mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)
166
+ # put lower loss for in label_one_hot (2 for positive, 1 for negative)
167
+ weight_c = labels_one_hot + 1
168
+ # apply mask
169
+ all_losses = all_losses * mask_label.float() * weight_c
170
+ return all_losses.sum()
171
+
172
+ def compute_score_eval(self, x, device):
173
+ # check if classes_to_id is dict
174
+ assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"
175
+
176
+ span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)
177
+
178
+ all_types = list(x['classes_to_id'].keys())
179
+ # multiple entity types in all_types. Prompt is appended at the start of tokens
180
+ entity_prompt = []
181
+
182
+ # add enity types to prompt
183
+ for entity_type in all_types:
184
+ entity_prompt.append(self.entity_token)
185
+ entity_prompt.append(entity_type)
186
+
187
+ entity_prompt.append(self.sep_token)
188
+
189
+ prompt_entity_length = len(entity_prompt)
190
+
191
+ # add prompt
192
+ tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
193
+ seq_length_p = x['seq_length'] + prompt_entity_length
194
+
195
+ out = self.token_rep_layer(tokens_p, seq_length_p)
196
+
197
+ word_rep_w_prompt = out["embeddings"]
198
+ mask_w_prompt = out["mask"]
199
+
200
+ # remove prompt
201
+ word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
202
+ mask = mask_w_prompt[:, prompt_entity_length:]
203
+
204
+ # get_entity_type_rep
205
+ entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
206
+ # extract [ENT] tokens (which are at even positions in entity_type_rep)
207
+ entity_type_rep = entity_type_rep[:, 0::2, :]
208
+
209
+ entity_type_rep = self.prompt_rep_layer(entity_type_rep) # (batch_size, len_types, hidden_size)
210
+
211
+ word_rep = self.rnn(word_rep, mask)
212
+
213
+ span_rep = self.span_rep_layer(word_rep, span_idx)
214
+
215
+ local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)
216
+
217
+ return local_scores
218
+
219
+ @torch.no_grad()
220
+ def predict(self, x, flat_ner=False, threshold=0.5):
221
+ self.eval()
222
+ local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
223
+ spans = []
224
+ for i, _ in enumerate(x["tokens"]):
225
+ local_i = local_scores[i]
226
+ wh_i = [i.tolist() for i in torch.where(torch.sigmoid(local_i) > threshold)]
227
+ span_i = []
228
+ for s, k, c in zip(*wh_i):
229
+ if s + k < len(x["tokens"][i]):
230
+ span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
231
+ span_i = greedy_search(span_i, flat_ner)
232
+ spans.append(span_i)
233
+ return spans
234
+
235
+ def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
236
+ tokens = []
237
+ start_token_idx_to_text_idx = []
238
+ end_token_idx_to_text_idx = []
239
+ for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
240
+ tokens.append(match.group())
241
+ start_token_idx_to_text_idx.append(match.start())
242
+ end_token_idx_to_text_idx.append(match.end())
243
+
244
+ input_x = {"tokenized_text": tokens, "ner": None}
245
+ x = self.collate_fn([input_x], labels)
246
+ output = self.predict(x, flat_ner=flat_ner, threshold=threshold)
247
+
248
+ entities = []
249
+ for start_token_idx, end_token_idx, ent_type in output[0]:
250
+ start_text_idx = start_token_idx_to_text_idx[start_token_idx]
251
+ end_text_idx = end_token_idx_to_text_idx[end_token_idx]
252
+ entities.append({
253
+ "start": start_token_idx_to_text_idx[start_token_idx],
254
+ "end": end_token_idx_to_text_idx[end_token_idx],
255
+ "text": text[start_text_idx:end_text_idx],
256
+ "label": ent_type,
257
+ })
258
+ return entities
259
+
260
+ def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
261
+ self.eval()
262
+ data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types, shuffle=False)
263
+ device = next(self.parameters()).device
264
+ all_preds = []
265
+ all_trues = []
266
+ for x in data_loader:
267
+ for k, v in x.items():
268
+ if isinstance(v, torch.Tensor):
269
+ x[k] = v.to(device)
270
+ batch_predictions = self.predict(x, flat_ner, threshold)
271
+ all_preds.extend(batch_predictions)
272
+ all_trues.extend(x["entities"])
273
+ evaluator = Evaluator(all_trues, all_preds)
274
+ out, f1 = evaluator.evaluate()
275
+ return out, f1
276
+
277
+ @classmethod
278
+ def _from_pretrained(
279
+ cls,
280
+ *,
281
+ model_id: str,
282
+ revision: Optional[str],
283
+ cache_dir: Optional[Union[str, Path]],
284
+ force_download: bool,
285
+ proxies: Optional[Dict],
286
+ resume_download: bool,
287
+ local_files_only: bool,
288
+ token: Union[str, bool, None],
289
+ map_location: str = "cpu",
290
+ strict: bool = False,
291
+ **model_kwargs,
292
+ ):
293
+ # 1. Backwards compatibility: Use "gliner_base.pt" and "gliner_multi.pt" with all data
294
+ filenames = ["gliner_base.pt", "gliner_multi.pt"]
295
+ for filename in filenames:
296
+ model_file = Path(model_id) / filename
297
+ if not model_file.exists():
298
+ try:
299
+ model_file = hf_hub_download(
300
+ repo_id=model_id,
301
+ filename=filename,
302
+ revision=revision,
303
+ cache_dir=cache_dir,
304
+ force_download=force_download,
305
+ proxies=proxies,
306
+ resume_download=resume_download,
307
+ token=token,
308
+ local_files_only=local_files_only,
309
+ )
310
+ except HfHubHTTPError:
311
+ continue
312
+ dict_load = torch.load(model_file, map_location=torch.device(map_location))
313
+ config = dict_load["config"]
314
+ state_dict = dict_load["model_weights"]
315
+ config.model_name = "microsoft/deberta-v3-base" if filename == "gliner_base.pt" else "microsoft/mdeberta-v3-base"
316
+ model = cls(config)
317
+ model.load_state_dict(state_dict, strict=strict, assign=True)
318
+ # Required to update flair's internals as well:
319
+ model.to(map_location)
320
+ return model
321
+
322
+ # 2. Newer format: Use "pytorch_model.bin" and "gliner_config.json"
323
+ from .train import load_config_as_namespace
324
+
325
+ model_file = Path(model_id) / "pytorch_model.bin"
326
+ if not model_file.exists():
327
+ model_file = hf_hub_download(
328
+ repo_id=model_id,
329
+ filename="pytorch_model.bin",
330
+ revision=revision,
331
+ cache_dir=cache_dir,
332
+ force_download=force_download,
333
+ proxies=proxies,
334
+ resume_download=resume_download,
335
+ token=token,
336
+ local_files_only=local_files_only,
337
+ )
338
+ config_file = Path(model_id) / "gliner_config.json"
339
+ if not config_file.exists():
340
+ config_file = hf_hub_download(
341
+ repo_id=model_id,
342
+ filename="gliner_config.json",
343
+ revision=revision,
344
+ cache_dir=cache_dir,
345
+ force_download=force_download,
346
+ proxies=proxies,
347
+ resume_download=resume_download,
348
+ token=token,
349
+ local_files_only=local_files_only,
350
+ )
351
+ config = load_config_as_namespace(config_file)
352
+ model = cls(config)
353
+ state_dict = torch.load(model_file, map_location=torch.device(map_location))
354
+ model.load_state_dict(state_dict, strict=strict, assign=True)
355
+ model.to(map_location)
356
+ return model
357
+
358
+ def save_pretrained(
359
+ self,
360
+ save_directory: Union[str, Path],
361
+ *,
362
+ config: Optional[Union[dict, "DataclassInstance"]] = None,
363
+ repo_id: Optional[str] = None,
364
+ push_to_hub: bool = False,
365
+ **push_to_hub_kwargs,
366
+ ) -> Optional[str]:
367
+ """
368
+ Save weights in local directory.
369
+
370
+ Args:
371
+ save_directory (`str` or `Path`):
372
+ Path to directory in which the model weights and configuration will be saved.
373
+ config (`dict` or `DataclassInstance`, *optional*):
374
+ Model configuration specified as a key/value dictionary or a dataclass instance.
375
+ push_to_hub (`bool`, *optional*, defaults to `False`):
376
+ Whether or not to push your model to the Huggingface Hub after saving it.
377
+ repo_id (`str`, *optional*):
378
+ ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if
379
+ not provided.
380
+ kwargs:
381
+ Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
382
+ """
383
+ save_directory = Path(save_directory)
384
+ save_directory.mkdir(parents=True, exist_ok=True)
385
+
386
+ # save model weights/files
387
+ torch.save(self.state_dict(), save_directory / "pytorch_model.bin")
388
+
389
+ # save config (if provided)
390
+ if config is None:
391
+ config = self.config
392
+ if config is not None:
393
+ if isinstance(config, argparse.Namespace):
394
+ config = vars(config)
395
+ (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))
396
+
397
+ # push to the Hub if required
398
+ if push_to_hub:
399
+ kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input
400
+ if config is not None: # kwarg for `push_to_hub`
401
+ kwargs["config"] = config
402
+ if repo_id is None:
403
+ repo_id = save_directory.name # Defaults to `save_directory` name
404
+ return self.push_to_hub(repo_id=repo_id, **kwargs)
405
+ return None
406
+
407
+ def to(self, device):
408
+ super().to(device)
409
+ import flair
410
+
411
+ flair.device = device
412
+ return self
backup/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ huggingface_hub
4
+ flair
5
+ seqeval
6
+ tqdm
backup/save_load.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import GLiNER
3
+
4
+
5
+ def save_model(current_model, path):
6
+ config = current_model.config
7
+ dict_save = {"model_weights": current_model.state_dict(), "config": config}
8
+ torch.save(dict_save, path)
9
+
10
+
11
+ def load_model(path, model_name=None, device=None):
12
+ dict_load = torch.load(path, map_location=torch.device('cpu'))
13
+ config = dict_load["config"]
14
+
15
+ if model_name is not None:
16
+ config.model_name = model_name
17
+
18
+ loaded_model = GLiNER(config)
19
+ loaded_model.load_state_dict(dict_load["model_weights"])
20
+ return loaded_model.to(device) if device is not None else loaded_model
backup/train.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ import yaml
6
+ from tqdm import tqdm
7
+ from transformers import get_cosine_schedule_with_warmup
8
+
9
+ # from model_nested import NerFilteredSemiCRF
10
+ from .model import GLiNER
11
+ from .modules.run_evaluation import get_for_all_path, sample_train_data
12
+ from save_load import save_model, load_model
13
+ import json
14
+
15
+
16
+ # train function
17
+ def train(model, optimizer, train_data, num_steps=1000, eval_every=100, log_dir="logs", warmup_ratio=0.1,
18
+ train_batch_size=8, device='cuda'):
19
+ model.train()
20
+
21
+ # initialize data loaders
22
+ train_loader = model.create_dataloader(train_data, batch_size=train_batch_size, shuffle=True)
23
+
24
+ pbar = tqdm(range(num_steps))
25
+
26
+ if warmup_ratio < 1:
27
+ num_warmup_steps = int(num_steps * warmup_ratio)
28
+ else:
29
+ num_warmup_steps = int(warmup_ratio)
30
+
31
+ scheduler = get_cosine_schedule_with_warmup(
32
+ optimizer,
33
+ num_warmup_steps=num_warmup_steps,
34
+ num_training_steps=num_steps
35
+ )
36
+
37
+ iter_train_loader = iter(train_loader)
38
+
39
+ for step in pbar:
40
+ try:
41
+ x = next(iter_train_loader)
42
+ except StopIteration:
43
+ iter_train_loader = iter(train_loader)
44
+ x = next(iter_train_loader)
45
+
46
+ for k, v in x.items():
47
+ if isinstance(v, torch.Tensor):
48
+ x[k] = v.to(device)
49
+
50
+ try:
51
+ loss = model(x) # Forward pass
52
+ except:
53
+ continue
54
+
55
+ # check if loss is nan
56
+ if torch.isnan(loss):
57
+ continue
58
+
59
+ loss.backward() # Compute gradients
60
+ optimizer.step() # Update parameters
61
+ scheduler.step() # Update learning rate schedule
62
+ optimizer.zero_grad() # Reset gradients
63
+
64
+ description = f"step: {step} | epoch: {step // len(train_loader)} | loss: {loss.item():.2f}"
65
+
66
+ if (step + 1) % eval_every == 0:
67
+ current_path = os.path.join(log_dir, f'model_{step + 1}')
68
+ save_model(model, current_path)
69
+ #val_data_dir = "/gpfswork/rech/ohy/upa43yu/NER_datasets" # can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
70
+ #get_for_all_path(model, step, log_dir, val_data_dir) # you can remove this comment if you want to evaluate the model
71
+
72
+ model.train()
73
+
74
+ pbar.set_description(description)
75
+
76
+
77
+ def create_parser():
78
+ parser = argparse.ArgumentParser(description="Span-based NER")
79
+ parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
80
+ parser.add_argument('--log_dir', type=str, default='logs', help='Path to the log directory')
81
+ return parser
82
+
83
+
84
+ def load_config_as_namespace(config_file):
85
+ with open(config_file, 'r') as f:
86
+ config_dict = yaml.safe_load(f)
87
+ return argparse.Namespace(**config_dict)
88
+
89
+
90
+ if __name__ == "__main__":
91
+ # parse args
92
+ parser = create_parser()
93
+ args = parser.parse_args()
94
+
95
+ # load config
96
+ config = load_config_as_namespace(args.config)
97
+
98
+ config.log_dir = args.log_dir
99
+
100
+ try:
101
+ with open(config.train_data, 'r') as f:
102
+ data = json.load(f)
103
+ except:
104
+ data = sample_train_data(config.train_data, 10000)
105
+
106
+ if config.prev_path != "none":
107
+ model = load_model(config.prev_path)
108
+ model.config = config
109
+ else:
110
+ model = GLiNER(config)
111
+
112
+ if torch.cuda.is_available():
113
+ model = model.cuda()
114
+
115
+ lr_encoder = float(config.lr_encoder)
116
+ lr_others = float(config.lr_others)
117
+
118
+ optimizer = torch.optim.AdamW([
119
+ # encoder
120
+ {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
121
+ {'params': model.rnn.parameters(), 'lr': lr_others},
122
+ # projection layers
123
+ {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
124
+ {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
125
+ ])
126
+
127
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
128
+
129
+ train(model, optimizer, data, num_steps=config.num_steps, eval_every=config.eval_every,
130
+ log_dir=config.log_dir, warmup_ratio=config.warmup_ratio, train_batch_size=config.train_batch_size,
131
+ device=device)