# tgpt-opt-nano / train.py
# Uploaded by lixiangchun ("initial upload", commit 56377c9).
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from dataclasses import dataclass, field
import pathlib
from typing import Dict, Optional, Sequence
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import json
# Label value the loss skips; the collator pads `labels` with it. -100 matches
# the default `ignore_index` of torch.nn.CrossEntropyLoss.
IGNORE_INDEX: int = -100
@dataclass
class ModelArguments:
    """Command-line options that select the model.

    NOTE(review): train() builds the model from a local ``config.json`` and does
    not read this field, so it currently has no effect — confirm that is intended.
    """

    # Hub id or local path of a pretrained checkpoint.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
@dataclass
class DataArguments:
    """Command-line options that describe the training data."""

    # Path to a text file with one whitespace-separated example per line.
    # The default is None, so the annotation is Optional (it was `str`, which
    # contradicted the default); a real path must be supplied to train.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Hugging Face training options plus a few project-specific fields."""

    # Cache directory option. NOTE(review): not read by the active code in train().
    cache_dir: Optional[str] = None
    # Overrides the parent's `optim` default.
    optim: str = "adamw_torch"
    # NOTE(review): nothing visible reads this; SupervisedDataset applies its own
    # per-example word cap instead.
    model_max_length: int = field(
        default=8192,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
# Process rank consulted by rank0_print(); None until a caller assigns it
# (e.g. from TrainingArguments.local_rank).
local_rank: Optional[int] = None


def rank0_print(*args):
    """Print ``args`` only from the main process.

    The main process is local rank 0 in a distributed launch, or rank -1, which
    is what Hugging Face ``TrainingArguments.local_rank`` holds when training is
    not distributed. Previously only rank 0 printed, so single-process runs were
    silent. While ``local_rank`` is unset (None), nothing is printed.
    """
    if local_rank in (0, -1):
        print(*args)
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning on whitespace-separated token lines.

    Each line of ``data_path`` is split on whitespace. Lines with fewer than
    ``max_length`` words are skipped, and longer lines are truncated to their
    first ``max_length`` words. Every kept example therefore has exactly
    ``max_length`` words before tokenization.
    """

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizerFast, max_length: int = 64):
        """
        Args:
            data_path: UTF-8 text file with one example per line.
            tokenizer: tokenizer applied to each pre-split example.
            max_length: number of leading words kept per line (max number of genes).
                Shorter lines are dropped. Defaults to 64, the previously hard-coded value.

        Raises:
            ValueError: if ``max_length`` is not positive.
        """
        super().__init__()
        if max_length <= 0:
            raise ValueError(f"max_length must be positive, got {max_length}")
        logging.warning("Loading data...")
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.list_data = []
        skipped = 0
        with open(data_path, encoding="utf-8") as f:
            for line in f:
                words = line.split()  # split once per line (was split twice)
                if len(words) >= max_length:
                    self.list_data.append(words[:max_length])
                else:
                    skipped += 1
        logging.warning(
            "Loaded %d examples from %s (skipped %d lines with fewer than %d words).",
            len(self.list_data), data_path, skipped, max_length,
        )
        # Memoized tokenization, keyed by example index.
        # NOTE(review): with multi-process data loading each worker presumably
        # keeps its own copy of this dict.
        self.cached_input_ids = {}

    def __len__(self):
        return len(self.list_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return ``input_ids`` and ``labels`` (the same 1-D tensor) for example ``i``."""
        input_ids = self.cached_input_ids.get(i)
        if input_ids is None:
            # The tokenizer output is used as-is; no tokens (EOS or otherwise) are stripped.
            input_ids = torch.tensor(self.tokenizer(self.list_data[i], is_split_into_words=True)["input_ids"])
            self.cached_input_ids[i] = input_ids
        # Labels equal the inputs; the causal-LM model presumably shifts them internally.
        return dict(input_ids=input_ids, labels=input_ids)
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Right-pads ``input_ids`` with the tokenizer's pad id and ``labels`` with
    IGNORE_INDEX. The attention mask is built from each example's true length.
    """

    tokenizer: transformers.PreTrainedTokenizerFast

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Stack a list of ``{"input_ids", "labels"}`` 1-D tensors into padded batch tensors.

        Raises:
            ValueError: if the tokenizer has no pad token.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("Tokenizer has no pad token; set tokenizer.pad_token before collating.")
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        lengths = torch.tensor([len(ids) for ids in input_ids])
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        # Mask from lengths rather than `input_ids != pad_id`: if the pad id also
        # occurs as a real token (e.g. pad shares an id with EOS), comparing
        # against it would wrongly mask that real token.
        attention_mask = (torch.arange(input_ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)).long()
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizerFast, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Args:
        tokenizer: tokenizer shared by the dataset and the collator.
        data_args: object with a ``data_path`` attribute (e.g. DataArguments).

    Returns:
        Keyword arguments for ``Trainer``: ``train_dataset``, ``eval_dataset``
        (None, meaning no evaluation set) and ``data_collator``.

    Raises:
        ValueError: if ``data_args.data_path`` is unset. Without this check the
            failure surfaced later as an opaque ``open(None)`` TypeError.
    """
    if not data_args.data_path:
        raise ValueError("data_path is required (pass --data_path <file>).")
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def train():
    """Parse CLI arguments, build an OPT model from ``config.json`` and train it.

    Training resumes from a ``checkpoint-*`` directory in ``output_dir`` when one
    exists. The trainer state and final model are then saved to ``output_dir``.
    """
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # rank0_print() reads this module-level global. Previously it was never
    # assigned, so every rank0_print() call was a no-op.
    local_rank = training_args.local_rank

    # The model is initialised from scratch from a local config file (no
    # pretrained weights); model_args.model_name_or_path is not used.
    config = transformers.AutoConfig.from_pretrained("config.json")
    model = transformers.OPTForCausalLM(config)
    # Number of trainable parameters, in millions (not megabytes).
    n_params_millions = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    rank0_print(model)
    rank0_print(f"model_size: {n_params_millions:.3f} M parameters")

    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained("tokenizer")
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        # resume_from_checkpoint=True makes Trainer load the last checkpoint in output_dir.
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()