from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()

import json

import numpy as np
import torch
from accelerate import Accelerator
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW  # transformers' AdamW is deprecated; use torch's
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

IGNORE_TOKEN_ID = -100


class MixData(Dataset):
    """Mixture of several datasets, sampled according to `ratio`."""

    def __init__(self, dataset, ratio, tokenizer):
        super().__init__()
        self.dataset = dataset
        self.data_size = [len(c) for c in self.dataset]
        # Integer entries are taken as-is; anything else falls back to the corpus size.
        ratio = [r if isinstance(r, int) else s for r, s in zip(ratio, self.data_size)]
        self.ratio = ratio
        self.tokenizer = tokenizer
        # Number of samples drawn per corpus, scaled relative to the first corpus.
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    @staticmethod
    def rounder(number):
        # Stochastic rounding: round up with probability equal to the fractional part.
        rand = np.random.rand()
        if rand < number - int(number):
            return int(number) + 1
        else:
            return int(number)

    @staticmethod
    def choice_index(number, sample_size):
        # Map a global index to (corpus id, offset inside that corpus's sample budget).
        for i in range(len(sample_size)):
            if number < sum(sample_size[:i + 1]):
                return i, number - sum(sample_size[:i])

    def __getitem__(self, index):
        corpus_id, index = self.choice_index(index, self.sample_size)
        rand = np.random.rand()
        # Rescale the offset from the sample budget onto the real corpus size.
        index = self.rounder((index + rand) / self.sample_size[corpus_id] * self.data_size[corpus_id])
        index = min(index, len(self.dataset[corpus_id]) - 1)
        return self.dataset[corpus_id][index]

    def __len__(self):
        return sum(self.sample_size)

    def set_ratio(self, ratio):
        self.ratio = ratio
        self.data_size = [len(c) for c in self.dataset]
        self.sample_size = [int(self.data_size[0] / self.ratio[0] * r) for r in self.ratio]
        print(self.data_size, self.sample_size,
              [c1 / c2 for c1, c2 in zip(self.sample_size, self.data_size)])

    def collate_fn(self, data):
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features


def last_index(lst, value):
    # Index of the last element that differs from `value`, or -1 if none does.
    return next((len(lst) - i - 1 for i, x in enumerate(lst[::-1]) if x != value), -1)


def safe_ids(ids, max_value, pad_id):
    # Clamp out-of-vocabulary ids to `pad_id` to avoid embedding lookup errors.
    return [i if i < max_value else pad_id for i in ids]


def tokenize(messages, tokenizer):
    roles = {"user": "USER", "assistant": "ASSISTANT"}
    input_ids = []
    labels = []
    system = "A chat between a curious user and an artificial intelligence assistant. " \
             "The assistant gives helpful, detailed, and polite answers to the user's questions."
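    # Vicuna-style layout: system prompt followed by alternating "USER:" /
    # "ASSISTANT:" turns. Labels mirror input_ids, but every position except
    # the assistant replies is set to IGNORE_TOKEN_ID, so the loss is computed
    # on assistant tokens only.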
    system_ids = tokenizer.encode(system, add_special_tokens=False)
    input_ids += system_ids
    labels += [IGNORE_TOKEN_ID] * len(system_ids)
    for i, turn in enumerate(messages):
        role = roles.get(turn['role'], 'USER')
        content = turn['content'].strip()
        if role == 'ASSISTANT':
            content += '</s>'  # end each assistant reply with EOS so the model learns to stop
        role_ids = tokenizer.encode(role + ":", add_special_tokens=False)
        content_ids = tokenizer.encode(content, add_special_tokens=False,
                                       truncation=True, max_length=2048)
        input_ids += role_ids + content_ids
        if role == 'ASSISTANT':
            # The "ASSISTANT:" prefix is masked; the reply itself is supervised.
            labels += [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
        else:
            labels += [IGNORE_TOKEN_ID] * (len(role_ids) + len(content_ids))

    input_ids = input_ids[:4096]
    labels = labels[:4096]

    # Drop a fully-masked tail (e.g. a user turn whose reply was truncated away).
    trunc_id = last_index(labels, IGNORE_TOKEN_ID) + 1
    input_ids = input_ids[:trunc_id]
    labels = labels[:trunc_id]
    if len(labels) == 0:
        input_ids, labels = [0, 0], [IGNORE_TOKEN_ID, IGNORE_TOKEN_ID]
    input_ids = safe_ids(input_ids, 64000, 0)
    labels = safe_ids(labels, 64000, IGNORE_TOKEN_ID)
    return input_ids, labels


class VicunaData(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        item = self.data[item]
        input_ids, labels = tokenize(item, self.tokenizer)
        return torch.tensor(input_ids), torch.tensor(labels)

    def collate_fn(self, data):
        input_ids, labels = zip(*data)
        input_ids = pad_sequence(input_ids, batch_first=True,
                                 padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_TOKEN_ID)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        features = {
            'input_ids': input_ids.long(),
            'labels': labels.long(),
            'attention_mask': attention_mask.long(),
        }
        return features


def main():
    accelerator = Accelerator(gradient_accumulation_steps=8)
    batch_size = 4

    save_path = 'out/baichuan-vicuna-7b'
    model_name = './models/baichuan-llama-7b'

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False,
                                              padding_side="right", model_max_length=4096)
    tokenizer.pad_token = tokenizer.unk_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    share_gpt = VicunaData(json.load(open('data/new/share_gpt-90k.json')), tokenizer)
    instruction = VicunaData(json.load(open('data/new/cot-75k.json')), tokenizer)
    code = VicunaData(json.load(open('data/new/leet-9k.json')), tokenizer)

    dataset = MixData([share_gpt, instruction, code],
                      [len(share_gpt), len(instruction), len(code)], tokenizer)
    print(len(dataset))

    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn,
                                              batch_size=batch_size, num_workers=0,
                                              shuffle=True)
    optimizer = AdamW(model.parameters(), lr=2e-5)

    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    for epoch in range(10):
        accelerator.print(f'Training {save_path} {epoch}')
        accelerator.wait_for_everyone()
        model.train()
        tk0 = tqdm(data_loader, total=len(data_loader))
        loss_report = []
        for batch in tk0:
            with accelerator.accumulate(model):
                try:
                    out = model(**batch)
                    loss = out.loss
                except Exception:
                    # A failed forward pass (e.g. OOM on an unusually long batch)
                    # is replaced by a zero loss so training can continue.
                    loss = torch.tensor(0., device=model.device, requires_grad=True)
                if loss.isnan():
                    print(loss)
                    print(batch)
                    loss = torch.tensor(0., device=model.device, requires_grad=True)

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.)
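                # Under Accelerate's gradient accumulation, the wrapped optimizer
                # only applies an update on the final micro-batch of each
                # 8-step accumulation window; intermediate step()/zero_grad()
                # calls are no-ops.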
                optimizer.step()
                optimizer.zero_grad()

            loss_report.append(accelerator.gather(loss).mean().item())
            tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))
        accelerator.wait_for_everyone()
        # save_checkpoint comes from the DeepSpeed engine that wraps the model
        # after accelerator.prepare (requires a DeepSpeed-enabled Accelerate config).
        model.save_checkpoint(f'{save_path}/{epoch}')


if __name__ == '__main__':
    main()
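
# The JSON files above are lists of conversations in the format `tokenize`
# expects, e.g. (illustrative sample, not taken from the actual data):
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello! How can I help you today?"}]
# Typical multi-GPU launch (script name illustrative):
#   accelerate launch train_vicuna.py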