alexneakameni committed on
Commit
e0b8f03
·
verified ·
1 Parent(s): 3e365e5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Fine-tuning entry script for a GLiNER model.

Configures process-wide environment variables, then imports the training
stack. The script body below reads a YAML config, loads JSON datasets and
runs the GLiNER ``Trainer``.
"""

import argparse
import json
import os
import random
from glob import glob

# The environment is configured before the heavy third-party imports below --
# presumably so that tokenizers / CUDA / Weights & Biases observe these values
# when they initialise. Keep these assignments ahead of those imports.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_PROJECT"] = "gliner_finetuning"
# Optional toggle, disabled by default:
# os.environ["WANDB_LOG_MODEL"] = "true"
os.environ["WANDB_WATCH"] = "none"

from transformers import AutoTokenizer, EarlyStoppingCallback
import torch

from gliner import GLiNERConfig, GLiNER
from gliner.training import Trainer, TrainingArguments
from gliner.data_processing.collator import DataCollatorWithPadding, DataCollator
from gliner.utils import load_config_as_namespace
from gliner.data_processing import WordsSplitter, GLiNERDataset
from utils import GLiNERConfigArgs

def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"`` (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config/config.yaml")
    parser.add_argument("--log_dir", type=str, default="data/models/")
    # Boolean switches accept both the bare flag (``--compile_model``) and an
    # explicit value (``--compile_model False``), so existing invocations that
    # pass ``True`` keep working while ``False`` now really means False.
    for flag in ("--compile_model", "--freeze_language_model", "--new_data_schema"):
        parser.add_argument(flag, type=str2bool, nargs="?", const=True, default=False)
    return parser.parse_args(argv)


def load_json_records(pattern):
    """Load and concatenate the JSON lists stored in files matching ``pattern``.

    Args:
        pattern: A glob pattern (or plain path) selecting JSON files.

    Returns:
        One flat list holding the items of every matched file, in sorted file
        order. Returns an empty list when nothing matches.

    Raises:
        TypeError: If a matched file's top-level JSON value is not a list.
    """
    records = []
    for path in sorted(glob(pattern)):
        # The context manager closes the handle deterministically.
        with open(path, "r", encoding="utf-8") as handle:
            chunk = json.load(handle)
        if not isinstance(chunk, list):
            raise TypeError(
                f"{path}: expected a JSON list of examples, got {type(chunk).__name__}"
            )
        records.extend(chunk)
    return records


def compute_save_steps(num_examples, batch_size):
    """Return the checkpoint interval: roughly half an epoch, never below 1.

    Args:
        num_examples: Number of training examples.
        batch_size: Per-device training batch size.

    Returns:
        A positive integer step count.
    """
    return max(1, int(0.5 * num_examples // batch_size))


def main(argv=None):
    """Build the model and datasets from the config file and run training.

    Args:
        argv: Optional list of CLI arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If no training examples could be loaded.
    """
    args = parse_args(argv)

    config: GLiNERConfigArgs = load_config_as_namespace(args.config)
    config.log_dir = args.log_dir

    print("Start loading dataset...")

    train_data = load_json_records(config.train_data)
    test_data = load_json_records(config.val_data_dir)
    if not train_data:
        raise ValueError(f"No training examples found for {config.train_data!r}")

    random.shuffle(train_data)

    print("Dataset is splitted...", len(train_data), len(test_data))

    if config.prev_path is not None:
        # Continue from an existing checkpoint: the tokenizer, weights and
        # model config all come from ``prev_path``.
        tokenizer = AutoTokenizer.from_pretrained(config.prev_path)
        model = GLiNER.from_pretrained(config.prev_path)
        model_config = model.config
        # Needed by the ``--new_data_schema`` dataset below.
        # NOTE(review): assumes the stored config exposes
        # ``words_splitter_type`` like a fresh ``GLiNERConfig`` does -- confirm.
        words_splitter = WordsSplitter(model_config.words_splitter_type)
    else:
        # Build a fresh model from the YAML configuration.
        model_config = GLiNERConfig(**vars(config))
        tokenizer = AutoTokenizer.from_pretrained(model_config.model_name)

        words_splitter = WordsSplitter(model_config.words_splitter_type)

        model = GLiNER(model_config, tokenizer=tokenizer, words_splitter=words_splitter)

        if not config.labels_encoder:
            # The entity/separator marker tokens are appended to the tokenizer,
            # so the first new id equals the tokenizer's previous length.
            marker_tokens = [model_config.ent_token, model_config.sep_token]
            model_config.class_token_index = len(tokenizer)
            tokenizer.add_tokens(marker_tokens, special_tokens=True)
            model_config.vocab_size = len(tokenizer)
            model.resize_token_embeddings(
                marker_tokens,
                set_class_token_index=False,
                add_tokens_to_tokenizer=False,
            )

    if args.compile_model:
        device = (
            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        )
        torch.set_float32_matmul_precision("high")
        model.to(device)
        model.compile_for_training()

    # Freeze or unfreeze the backbone language model in one call.
    model.model.token_rep_layer.bert_layer.model.requires_grad_(
        not args.freeze_language_model
    )

    if args.new_data_schema:
        train_dataset = GLiNERDataset(
            train_data, model_config, tokenizer, words_splitter
        )
        test_dataset = GLiNERDataset(test_data, model_config, tokenizer, words_splitter)
        data_collator = DataCollatorWithPadding(model_config)
    else:
        # Raw example lists are handed to the trainer; the collator prepares
        # model inputs and labels via the model's own data processor.
        train_dataset = train_data
        test_dataset = test_data
        data_collator = DataCollator(
            model.config, data_processor=model.data_processor, prepare_labels=True
        )

    save_steps = compute_save_steps(len(train_dataset), config.train_batch_size)

    training_args = TrainingArguments(
        output_dir=config.log_dir,
        learning_rate=float(config.lr_encoder),
        weight_decay=float(config.weight_decay_encoder),
        others_lr=float(config.lr_others),
        others_weight_decay=float(config.weight_decay_other),
        lr_scheduler_type=config.scheduler_type,
        warmup_ratio=config.warmup_ratio,
        per_device_train_batch_size=config.train_batch_size,
        per_device_eval_batch_size=config.train_batch_size,
        max_grad_norm=config.max_grad_norm,
        max_steps=config.num_steps,
        evaluation_strategy=config.eval_strategy,
        save_strategy=config.save_strategy,
        save_steps=save_steps,
        # Log twice per checkpoint interval, but never with a step of 0.
        logging_steps=max(1, save_steps // 2),
        save_total_limit=config.save_total_limit,
        dataloader_num_workers=8,
        use_cpu=False,
        report_to="wandb",
        bf16=True,
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(3)],
    )
    trainer.train()


if __name__ == "__main__":
    main()