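"""Train a small OPT causal language model from scratch.

The model architecture is read from a local config.json and the tokenizer from a local
"tokenizer" directory. The training data is a plain-text file with one whitespace-tokenised
sequence per line; lines with fewer than 64 tokens are skipped and longer lines are
truncated to 64 tokens (see SupervisedDataset).

An illustrative launch command (the script name, paths, and hyperparameters below are
placeholders, not fixed by this file):

    python train.py \
        --data_path data/train.txt \
        --output_dir checkpoints/opt-small \
        --per_device_train_batch_size 8 \
        --num_train_epochs 1
"""
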
import logging
from dataclasses import dataclass, field
import pathlib
from typing import Dict, Optional, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

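# Label value ignored by the cross-entropy loss; the collator uses it to mask padded label positions.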
IGNORE_INDEX = -100


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=8192,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


local_rank = None


def rank0_print(*args):
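    """Print only from the rank-0 process so distributed runs do not duplicate output."""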
    if local_rank == 0:
        print(*args)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizerFast):
        super().__init__()
        logging.warning("Loading data...")
        self.tokenizer = tokenizer
        self.max_length = 64
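        # Keep only lines with at least `max_length` whitespace-separated tokens,
        # truncating each kept line to exactly `max_length` tokens.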
        with open(data_path) as f:
            self.list_data = [
                line.split()[: self.max_length] for line in f if len(line.split()) >= self.max_length
            ]

        # Cache tokenised examples so each line is tokenised at most once across epochs.
        self.cached_input_ids = {}

    def __len__(self):
        return len(self.list_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_input_ids:
            input_ids = self.cached_input_ids[i]
        else:
            input_ids = self.tokenizer(self.list_data[i], is_split_into_words=True)["input_ids"]
            input_ids = torch.tensor(input_ids)
            self.cached_input_ids[i] = input_ids

        # Labels equal the input ids; the causal-LM model shifts them internally when computing the loss.
        return dict(input_ids=input_ids, labels=input_ids)


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizerFast

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # Pad labels with IGNORE_INDEX so padded positions do not contribute to the loss.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=(input_ids.ne(self.tokenizer.pad_token_id)).long(),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizerFast, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    # Record the distributed rank so rank0_print only logs from the main process.
    local_rank = training_args.local_rank

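    # The model is built from a local config.json and trained from scratch with randomly
    # initialised weights; model_args.model_name_or_path is parsed but not used here.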
    config = transformers.AutoConfig.from_pretrained("config.json")
    model = transformers.OPTForCausalLM(config)

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    rank0_print(model)
    rank0_print(f"model size: {model_size:.3f}M parameters")

    # The tokenizer is loaded from a local "tokenizer" directory; it must define a pad token,
    # since the data collator pads batches with tokenizer.pad_token_id.
    tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained("tokenizer")

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

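    # Resume from the most recent checkpoint if the output directory already contains one.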
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()