StarCycle committed on
Commit
32bafc1
1 Parent(s): 7ff4b3a
iter_11628.pth/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80fa8b0a2d18a086edc4427abe3235ce83ac62eec9fba40b4d363f554a87b3b6
3
+ size 19676784
iter_11628.pth/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fee544a3b533b90133bdb531442b0424b6fc1df8ed938035368397ba5614129
3
+ size 19676912
iter_11628.pth/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d054aaf321ebccde7b7abf24f88f781ad9712d41fdd65b68c4c205251131bccb
3
+ size 19676784
iter_11628.pth/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:935528913d084969a5cc9ea0294e7e34ad48e60dc6c85ab014cf1da9030f8ed7
3
+ size 19676848
iter_11628.pth/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:260b0e1988cd523cc76ef05f76ba172fd1de395a83d565fc8f2625bb960b153e
3
+ size 13988972
pretrain.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Copyright (c) OpenMMLab. All rights reserved.
#
# XTuner/MMEngine config: LLaVA-style pretraining stage.
# Trains only the vision->language projector between a frozen SigLIP vision
# encoder and a frozen, 4-bit-quantized InternLM2-1.8B LLM, using the
# LLaVA-Pretrain (blip_laion_cc_sbu_558k) caption dataset.
# Every module-level name below is part of the config's interface: mmengine's
# Config.fromfile collects all top-level variables, so names must not change.
import torch
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from mmengine.visualization import TensorboardVisBackend, Visualizer
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, SiglipImageProcessor,
                          SiglipVisionModel)

from xtuner.dataset import LLaVADataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
from xtuner.engine.runner import TrainLoop
from xtuner.model import LLaVAModel
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'internlm/internlm2-1_8b'
visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'

# Data
data_root = './'
data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
image_folder = data_root + 'LLaVA-Pretrain/images'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
# Text token budget: context length minus the image patch token count.
# NOTE(review): (336 / 14)**2 = 576 matches CLIP-ViT-L-336; the encoder
# configured above is SigLIP patch14-384 — confirm this budget is intended.
max_length = int(2048 - (336 / 14)**2)

# Scheduler & Optimizer
batch_size = 12  # per_device
accumulative_counts = 5  # effective batch = batch_size * counts * world_size
dataloader_num_workers = 12
max_epochs = 1
optim_type = AdamW
lr = 1e-3
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03

# Save
save_steps = 2000
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 2000
SYSTEM = ''
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
# Lazily-built (dict-style) tokenizer; instantiated by the runner.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

image_processor = dict(
    type=SiglipImageProcessor.from_pretrained,
    pretrained_model_name_or_path=visual_encoder_name_or_path,
    trust_remote_code=True)

# Both towers frozen: only the projector receives gradients in this stage.
model = dict(
    type=LLaVAModel,
    freeze_llm=True,
    freeze_visual_encoder=True,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        # QLoRA-style 4-bit NF4 quantization with double quantization to
        # fit the frozen LLM in memory.
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    visual_encoder=dict(
        type=SiglipVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))

#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################
llava_dataset = dict(
    type=LLaVADataset,
    data_path=data_path,
    image_folder=image_folder,
    tokenizer=tokenizer,
    image_processor=image_processor,
    dataset_map_fn=llava_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=False)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=llava_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer: AMP wrapper (fp16, dynamic loss scaling) with grad accumulation.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy: linear warmup for the first `warmup_ratio` of training,
# then cosine decay to zero.
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per `save_steps`.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer (TensorBoard backend; import hoisted to the top of the file)
visualizer = dict(
    type=Visualizer,
    vis_backends=[dict(type=TensorboardVisBackend)]
)

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

# set log processor
log_processor = dict(by_epoch=False)