IC4T committed on
Commit
6997035
1 Parent(s): ea27bb1
training/__init__.py ADDED
File without changes
training/consts.py ADDED
@@ -0,0 +1,74 @@
+ DEFAULT_INPUT_MODEL = "EleutherAI/pythia-6.9b"
+ SUGGESTED_INPUT_MODELS = [
+     "EleutherAI/pythia-2.8b",
+     "EleutherAI/pythia-6.9b",
+     "EleutherAI/pythia-12b",
+     "EleutherAI/gpt-j-6B",
+ ]
+ INTRO_BLURB = (
+     "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+ )
+ INSTRUCTION_KEY = "### Instruction:"
+ INPUT_KEY = "Input:"
+ RESPONSE_KEY = "### Response:"
+ END_KEY = "### End"
+ RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
+ DEFAULT_SEED = 42
+
+ # This is a training prompt that does not contain an input string. The instruction by itself has enough information
+ # to respond. For example, the instruction might ask for the year a historic figure was born.
+ PROMPT_NO_INPUT_FORMAT = """{intro}
+
+ {instruction_key}
+ {instruction}
+
+ {response_key}
+ {response}
+
+ {end_key}""".format(
+     intro=INTRO_BLURB,
+     instruction_key=INSTRUCTION_KEY,
+     instruction="{instruction}",
+     response_key=RESPONSE_KEY,
+     response="{response}",
+     end_key=END_KEY,
+ )
+
+ # This is a training prompt that contains an input string that serves as context for the instruction. For example,
+ # the input might be a passage from Wikipedia and the instruction is to extract some information from it.
+ PROMPT_WITH_INPUT_FORMAT = """{intro}
+
+ {instruction_key}
+ {instruction}
+
+ {input_key}
+ {input}
+
+ {response_key}
+ {response}
+
+ {end_key}""".format(
+     intro=INTRO_BLURB,
+     instruction_key=INSTRUCTION_KEY,
+     instruction="{instruction}",
+     input_key=INPUT_KEY,
+     input="{input}",
+     response_key=RESPONSE_KEY,
+     response="{response}",
+     end_key=END_KEY,
+ )
+
+ # This is the prompt that is used for generating responses using an already trained model. It ends with the response
+ # key, where the job of the model is to provide the completion that follows it (i.e. the response itself).
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
+
+ {instruction_key}
+ {instruction}
+
+ {response_key}
+ """.format(
+     intro=INTRO_BLURB,
+     instruction_key=INSTRUCTION_KEY,
+     instruction="{instruction}",
+     response_key=RESPONSE_KEY,
+ )
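
For reference, a minimal sketch of how these prompt templates are filled at training time (the instruction, context, and response values below are illustrative examples, not part of the commit):

from training.consts import PROMPT_NO_INPUT_FORMAT, PROMPT_WITH_INPUT_FORMAT

# Hypothetical Dolly-style record with context: the templates keep {instruction},
# {input}, and {response} as placeholders, so .format() fills them in here.
prompt = PROMPT_WITH_INPUT_FORMAT.format(
    instruction="What year was Databricks founded?",
    input="Databricks was founded in 2013 by the original creators of Apache Spark.",
    response="Databricks was founded in 2013.",
)

# Records without context use the shorter template.
prompt_no_context = PROMPT_NO_INPUT_FORMAT.format(
    instruction="Name three primary colors.",
    response="Red, yellow, and blue.",
)

print(prompt)  # intro blurb, then "### Instruction:", "Input:", "### Response:", and "### End"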
training/generate.py ADDED
@@ -0,0 +1,239 @@
+ import logging
+ import re
+ from typing import List, Tuple
+
+ import numpy as np
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     Pipeline,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+ )
+
+ from transformers.utils import is_tf_available
+
+ if is_tf_available():
+     import tensorflow as tf
+
+ from .consts import END_KEY, PROMPT_FOR_GENERATION_FORMAT, RESPONSE_KEY
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_model_tokenizer_for_generate(
+     pretrained_model_name_or_path: str,
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+     """Loads the model and tokenizer so that they can be used for generating responses.
+
+     Args:
+         pretrained_model_name_or_path (str): name or path for model
+
+     Returns:
+         Tuple[PreTrainedModel, PreTrainedTokenizer]: model and tokenizer
+     """
+     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left", cache_dir="/media/siiva/DataStore/LLMs/cache/dollyV2")
+     model = AutoModelForCausalLM.from_pretrained(
+         pretrained_model_name_or_path, device_map="auto", trust_remote_code=True, cache_dir="/media/siiva/DataStore/LLMs/cache/dollyV2"
+     )
+     return model, tokenizer
+
+
+ def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
+     """Gets the token ID for a given string that has been added to the tokenizer as a special token.
+
+     When training, we configure the tokenizer so that sequences like "### Instruction:" and "### End" are
+     treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
+
+     Args:
+         tokenizer (PreTrainedTokenizer): the tokenizer
+         key (str): the key to convert to a single token
+
+     Raises:
+         ValueError: if more than one ID was generated
+
+     Returns:
+         int: the token ID for the given key
+     """
+     token_ids = tokenizer.encode(key)
+     if len(token_ids) > 1:
+         raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
+     return token_ids[0]
+
+
+ class InstructionTextGenerationPipeline(Pipeline):
+     def __init__(
+         self, *args, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs
+     ):
+         """Initialize the pipeline.
+
+         Args:
+             do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
+             max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 256.
+             top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with
+                 probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
+             top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                 Defaults to 0.
+         """
+         super().__init__(*args, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k,
+                          **kwargs)
+
+     def _sanitize_parameters(self,
+                              return_full_text: bool = None,
+                              **generate_kwargs):
+         preprocess_params = {}
+
+         # Newer versions of the tokenizer configure the response key as a special token, and may also
+         # append a newline to yield a single token. Find whatever token is configured for the response key.
+         tokenizer_response_key = next(
+             (token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None
+         )
+
+         response_key_token_id = None
+         end_key_token_id = None
+         if tokenizer_response_key:
+             try:
+                 response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
+                 end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
+
+                 # Ensure generation stops once it generates "### End"
+                 generate_kwargs["eos_token_id"] = end_key_token_id
+             except ValueError:
+                 pass
+
+         forward_params = generate_kwargs
+         postprocess_params = {
+             "response_key_token_id": response_key_token_id,
+             "end_key_token_id": end_key_token_id,
+         }
+
+         if return_full_text is not None:
+             postprocess_params["return_full_text"] = return_full_text
+
+         return preprocess_params, forward_params, postprocess_params
+
+     def preprocess(self, instruction_text, **generate_kwargs):
+         prompt_text = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text)
+         inputs = self.tokenizer(
+             prompt_text,
+             return_tensors="pt",
+         )
+         inputs["prompt_text"] = prompt_text
+         inputs["instruction_text"] = instruction_text
+         return inputs
+
+     def _forward(self, model_inputs, **generate_kwargs):
+         input_ids = model_inputs["input_ids"]
+         attention_mask = model_inputs.get("attention_mask", None)
+
+         if input_ids.shape[1] == 0:
+             input_ids = None
+             attention_mask = None
+             in_b = 1
+         else:
+             in_b = input_ids.shape[0]
+
+         generated_sequence = self.model.generate(
+             input_ids=input_ids.to(self.model.device) if input_ids is not None else None,
+             attention_mask=attention_mask,
+             pad_token_id=self.tokenizer.pad_token_id,
+             **generate_kwargs,
+         )
+
+         out_b = generated_sequence.shape[0]
+         if self.framework == "pt":
+             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
+         elif self.framework == "tf":
+             generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
+
+         instruction_text = model_inputs.pop("instruction_text")
+         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
+
+     def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_full_text: bool = False):
+
+         generated_sequence = model_outputs["generated_sequence"][0]
+         instruction_text = model_outputs["instruction_text"]
+
+         generated_sequence: List[List[int]] = generated_sequence.cpu().numpy().tolist()
+         records = []
+         for sequence in generated_sequence:
+
+             # The response will be set to this variable if we can identify it.
+             decoded = None
+
+             # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
+             if response_key_token_id and end_key_token_id:
+                 # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
+                 # prompt, we should definitely find it. We will return the tokens found after this token.
+                 try:
+                     response_pos = sequence.index(response_key_token_id)
+                 except ValueError:
+                     logger.warning(f"Could not find response key {response_key_token_id} in: {sequence}")
+                     response_pos = None
+
+                 if response_pos:
+                     # Next find where "### End" is located. The model has been trained to end its responses with this
+                     # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
+                     # this token, as the response could be truncated. If we don't find it then just return everything
+                     # to the end. Note that even though we set eos_token_id, we still see this token at the end.
+                     try:
+                         end_pos = sequence.index(end_key_token_id)
+                     except ValueError:
+                         end_pos = None
+
+                     decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
+
+             if not decoded:
+                 # Otherwise we'll decode everything and use a regex to find the response and end.
+
+                 fully_decoded = self.tokenizer.decode(sequence)
+
+                 # The response appears after "### Response:". The model has been trained to append "### End" at the
+                 # end.
+                 m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+
+                 if m:
+                     decoded = m.group(1).strip()
+                 else:
+                     # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
+                     # return everything after "### Response:".
+                     m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                     if m:
+                         decoded = m.group(1).strip()
+                     else:
+                         logger.warning(f"Failed to find response in:\n{fully_decoded}")
+
+             # If the full text is requested, then append the decoded text to the original instruction.
+             # This technically isn't the full text, as we format the instruction in the prompt the model has been
+             # trained on, but to the client it will appear to be the full text.
+             if return_full_text:
+                 decoded = f"{instruction_text}\n{decoded}"
+
+             rec = {"generated_text": decoded}
+
+             records.append(rec)
+
+         return records
+
+
+ def generate_response(
+     instruction: str,
+     *,
+     model: PreTrainedModel,
+     tokenizer: PreTrainedTokenizer,
+     **kwargs,
+ ) -> str:
+     """Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in
+     the instruction format that the model was fine-tuned on.
+
+     Args:
+         instruction (str): the instruction to generate a response to
+         model (PreTrainedModel): the model to use
+         tokenizer (PreTrainedTokenizer): the tokenizer to use
+
+     Returns:
+         str: response
+     """
+
+     generation_pipeline = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer, **kwargs)
+     return generation_pipeline(instruction)[0]["generated_text"]
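
A minimal usage sketch for this module, assuming a model already fine-tuned with trainer.py; the checkpoint path below is a placeholder, and note that load_model_tokenizer_for_generate above hard-codes a local cache_dir:

from training.generate import generate_response, load_model_tokenizer_for_generate

# Placeholder path: point this at the local_output_dir written by trainer.py,
# or at one of the SUGGESTED_INPUT_MODELS for a quick smoke test.
model, tokenizer = load_model_tokenizer_for_generate("/path/to/local_output_dir")

response = generate_response(
    "Write a haiku about large language models.",
    model=model,
    tokenizer=tokenizer,
)
print(response)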
training/trainer.py ADDED
@@ -0,0 +1,330 @@
+ # Copyright 2023 Databricks, Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from functools import partial
+ from pathlib import Path
+ from typing import Any, Dict, List, Tuple, Union
+
+ import click
+ import numpy as np
+ from datasets import Dataset, load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     DataCollatorForLanguageModeling,
+     PreTrainedTokenizer,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+
+ from .consts import (
+     DEFAULT_INPUT_MODEL,
+     DEFAULT_SEED,
+     PROMPT_WITH_INPUT_FORMAT,
+     PROMPT_NO_INPUT_FORMAT,
+     END_KEY,
+     INSTRUCTION_KEY,
+     RESPONSE_KEY_NL,
+ )
+
+ logger = logging.getLogger(__name__)
+ ROOT_PATH = Path(__file__).parent.parent
+ DATABRICKS_DOLLY_15K_PATH = ROOT_PATH / "data" / "databricks-dolly-15k.jsonl"
+
+
+ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
+     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         batch = super().torch_call(examples)
+
+         # The prompt ends with the response key plus a newline. We encode this and then try to find it in the
+         # sequence of tokens. This should just be a single token.
+         response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
+
+         labels = batch["labels"].clone()
+
+         for i in range(len(examples)):
+
+             response_token_ids_start_idx = None
+             for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
+                 response_token_ids_start_idx = idx
+                 break
+
+             if response_token_ids_start_idx is None:
+                 raise RuntimeError(
+                     f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
+                 )
+
+             response_token_ids_end_idx = response_token_ids_start_idx + 1
+
+             # Make the PyTorch loss function ignore all tokens up through the end of the response key
+             labels[i, :response_token_ids_end_idx] = -100
+
+         batch["labels"] = labels
+
+         return batch
+
+
+ def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_length: int) -> dict:
+     return tokenizer(
+         batch["text"],
+         max_length=max_length,
+         truncation=True,
+     )
+
+
+ def load_training_dataset() -> Dataset:
+     logger.info(f"Loading dataset from {DATABRICKS_DOLLY_15K_PATH}")
+     dataset = load_dataset("json", data_files=str(DATABRICKS_DOLLY_15K_PATH))["train"]
+     logger.info("Found %d rows", dataset.num_rows)
+
+     def _add_text(rec):
+         instruction = rec["instruction"]
+         response = rec["response"]
+         context = rec.get("context")
+
+         if not instruction:
+             raise ValueError(f"Expected an instruction in: {rec}")
+
+         if not response:
+             raise ValueError(f"Expected a response in: {rec}")
+
+         # For some instructions there is an input that goes along with the instruction, providing context for the
+         # instruction. For example, the input might be a passage from Wikipedia and the instruction says to extract
+         # some piece of information from it. The response is that information to extract. In other cases there is
+         # no input. For example, the instruction might be open QA such as asking what year some historic figure was
+         # born.
+         if context:
+             rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context)
+         else:
+             rec["text"] = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)
+         return rec
+
+     dataset = dataset.map(_add_text)
+
+     return dataset
+
+
+ def load_tokenizer(pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL) -> PreTrainedTokenizer:
+     logger.info(f"Loading tokenizer for {pretrained_model_name_or_path}")
+     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
+     return tokenizer
+
+
+ def load_model(
+     pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
+ ) -> AutoModelForCausalLM:
+     logger.info(f"Loading model for {pretrained_model_name_or_path}")
+     model = AutoModelForCausalLM.from_pretrained(
+         pretrained_model_name_or_path, trust_remote_code=True, use_cache=False if gradient_checkpointing else True
+     )
+     return model
+
+
+ def get_model_tokenizer(
+     pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
+ ) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
+     tokenizer = load_tokenizer(pretrained_model_name_or_path)
+     model = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)
+     model.resize_token_embeddings(len(tokenizer))
+
+     return model, tokenizer
+
+
+ def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED) -> Dataset:
+     """Loads the training dataset and tokenizes it so it is ready for training.
+
+     Args:
+         tokenizer (AutoTokenizer): Tokenizer tied to the model.
+         max_length (int): Maximum number of tokens to emit from tokenizer.
+
+     Returns:
+         Dataset: HuggingFace dataset
+     """
+
+     dataset = load_training_dataset()
+
+     logger.info("Preprocessing dataset")
+     _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
+     dataset = dataset.map(
+         _preprocessing_function,
+         batched=True,
+         remove_columns=["instruction", "context", "response", "text", "category"],
+     )
+
+     # Make sure we don't have any truncated records, as this would mean the end keyword is missing.
+     logger.info("Processed dataset has %d rows", dataset.num_rows)
+     dataset = dataset.filter(lambda rec: len(rec["input_ids"]) < max_length)
+     logger.info("Processed dataset has %d rows after filtering for truncated records", dataset.num_rows)
+
+     logger.info("Shuffling dataset")
+     dataset = dataset.shuffle(seed=seed)
+
+     logger.info("Done preprocessing")
+
+     return dataset
+
+
+ def train(
+     *,
+     input_model: str,
+     local_output_dir: str,
+     dbfs_output_dir: str,
+     epochs: int,
+     per_device_train_batch_size: int,
+     per_device_eval_batch_size: int,
+     lr: float,
+     seed: int,
+     deepspeed: str,
+     gradient_checkpointing: bool,
+     local_rank: str,
+     bf16: bool,
+     logging_steps: int,
+     save_steps: int,
+     eval_steps: int,
+     test_size: Union[float, int],
+     save_total_limit: int,
+     warmup_steps: int,
+ ):
+     set_seed(seed)
+
+     model, tokenizer = get_model_tokenizer(
+         pretrained_model_name_or_path=input_model, gradient_checkpointing=gradient_checkpointing
+     )
+
+     # Use the same max length that the model supports. Fall back to 1024 if the setting can't be found.
+     # The configuration for the length can be stored under different names depending on the model. Here we attempt
+     # a few possible names we've encountered.
+     conf = model.config
+     max_length = None
+     for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
+         max_length = getattr(conf, length_setting, None)
+         if max_length:
+             logger.info(f"Found max length: {max_length}")
+             break
+     if not max_length:
+         max_length = 1024
+         logger.info(f"Using default max length: {max_length}")
+
+     processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)
+
+     split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)
+
+     logger.info("Train data size: %d", split_dataset["train"].num_rows)
+     logger.info("Test data size: %d", split_dataset["test"].num_rows)
+
+     data_collator = DataCollatorForCompletionOnlyLM(
+         tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
+     )
+
+     if not dbfs_output_dir:
+         logger.warning("Will NOT save to DBFS")
+
+     training_args = TrainingArguments(
+         output_dir=local_output_dir,
+         per_device_train_batch_size=per_device_train_batch_size,
+         per_device_eval_batch_size=per_device_eval_batch_size,
+         fp16=False,
+         bf16=bf16,
+         learning_rate=lr,
+         num_train_epochs=epochs,
+         deepspeed=deepspeed,
+         gradient_checkpointing=gradient_checkpointing,
+         logging_dir=f"{local_output_dir}/runs",
+         logging_strategy="steps",
+         logging_steps=logging_steps,
+         evaluation_strategy="steps",
+         eval_steps=eval_steps,
+         save_strategy="steps",
+         save_steps=save_steps,
+         save_total_limit=save_total_limit,
+         load_best_model_at_end=False,
+         report_to="tensorboard",
+         disable_tqdm=True,
+         remove_unused_columns=False,
+         local_rank=local_rank,
+         warmup_steps=warmup_steps,
+     )
+
+     logger.info("Instantiating Trainer")
+
+     trainer = Trainer(
+         model=model,
+         tokenizer=tokenizer,
+         args=training_args,
+         train_dataset=split_dataset["train"],
+         eval_dataset=split_dataset["test"],
+         data_collator=data_collator,
+     )
+
+     logger.info("Training")
+     trainer.train()
+
+     logger.info(f"Saving Model to {local_output_dir}")
+     trainer.save_model(output_dir=local_output_dir)
+
+     if dbfs_output_dir:
+         logger.info(f"Saving Model to {dbfs_output_dir}")
+         trainer.save_model(output_dir=dbfs_output_dir)
+
+     logger.info("Done.")
+
+
+ @click.command()
+ @click.option("--input-model", type=str, help="Input model to fine tune", default=DEFAULT_INPUT_MODEL)
+ @click.option("--local-output-dir", type=str, help="Write directly to this local path", required=True)
+ @click.option("--dbfs-output-dir", type=str, help="Sync data to this path on DBFS")
+ @click.option("--epochs", type=int, default=3, help="Number of epochs to train for.")
+ @click.option("--per-device-train-batch-size", type=int, default=8, help="Batch size to use for training.")
+ @click.option("--per-device-eval-batch-size", type=int, default=8, help="Batch size to use for evaluation.")
+ @click.option(
+     "--test-size", type=int, default=1000, help="Number of test records for evaluation, or ratio of test records."
+ )
+ @click.option("--warmup-steps", type=int, default=None, help="Number of steps to warm up to learning rate")
+ @click.option("--logging-steps", type=int, default=10, help="How often to log")
+ @click.option("--eval-steps", type=int, default=50, help="How often to run evaluation on test records")
+ @click.option("--save-steps", type=int, default=400, help="How often to checkpoint the model")
+ @click.option("--save-total-limit", type=int, default=10, help="Maximum number of checkpoints to keep on disk")
+ @click.option("--lr", type=float, default=1e-5, help="Learning rate to use for training.")
+ @click.option("--seed", type=int, default=DEFAULT_SEED, help="Seed to use for training.")
+ @click.option("--deepspeed", type=str, default=None, help="Path to deepspeed config file.")
+ @click.option(
+     "--gradient-checkpointing/--no-gradient-checkpointing",
+     is_flag=True,
+     default=True,
+     help="Use gradient checkpointing?",
+ )
+ @click.option(
+     "--local_rank",
+     type=str,
+     default=True,
+     help="Provided by deepspeed to identify which instance this process is when performing multi-GPU training.",
+ )
+ @click.option("--bf16", type=bool, default=True, help="Whether to use bf16 (preferred on A100's).")
+ def main(**kwargs):
+     train(**kwargs)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(
+         format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+     )
+     try:
+         main()
+     except Exception:
+         logger.exception("main failed")
+         raise
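
For context, a small sketch of what DataCollatorForCompletionOnlyLM does to a batch: everything up to and including the response key is masked to -100 so the loss is computed only on the response tokens. It assumes data/databricks-dolly-15k.jsonl is present in the repo and uses one of the smaller SUGGESTED_INPUT_MODELS purely to keep the tokenizer download light; the indices and printed count are illustrative:

from training.trainer import DataCollatorForCompletionOnlyLM, load_tokenizer, preprocess_dataset

# Tokenizer with the "### End", "### Instruction:" and "### Response:\n" special tokens added.
tokenizer = load_tokenizer("EleutherAI/pythia-2.8b")

# Tokenized Dolly records (requires data/databricks-dolly-15k.jsonl under the repo root).
dataset = preprocess_dataset(tokenizer=tokenizer, max_length=1024)

collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)
batch = collator([dataset[0], dataset[1]])

# Prompt tokens (intro, instruction, optional context, response key) are ignored in the loss.
print((batch["labels"][0] == -100).sum().item(), "tokens masked out of the loss for the first record")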