# coding=utf8
import ast
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer


class GPT2QADataset(Dataset):
    '''
    Dataset used for the Yuyuan medical QA task.
    Only small datasets are supported; loading a large dataset this way may
    be slow. For large datasets, use a memory-mapped dataset (in progress).
    '''

    def __init__(self, data_path, name, args):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path)
        if self.tokenizer.pad_token is None:
            # GPT-2 has no pad token; reuse '<|endoftext|>' for padding.
            self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
        # File size in GB, used to decide between eager and streaming reads.
        self.data_size = os.path.getsize(data_path) / 1024 / 1024 / 1024
        self.data_type_name = name
        self.data = self.load_data(data_path)
        self.max_seq_length = args.max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path):
        # Load with a progress bar. Files up to 5 GB are read eagerly so the
        # bar can show a total; larger files are streamed line by line.
        if self.data_size <= 5:
            with open(data_path, 'rt', encoding='utf8') as f:
                lines = f.readlines()
            total_num = len(lines)
            data_gen = lines
        else:
            data_gen = open(data_path, 'rt', encoding='utf8')
            total_num = None

        data = []
        with tqdm(total=total_num, desc=f'{self.data_type_name} loading',
                  mininterval=0.3) as bar:
            for line in data_gen:
                data.append(self.data_parse(line))
                bar.update()

        if self.data_size > 5:
            data_gen.close()
        return data

    def data_parse(self, line):
        """
        Parse one raw line (a Python dict literal) into a dict.
        """
        # ast.literal_eval is a safe replacement for eval() here: it parses
        # literals only and cannot execute arbitrary code from the data file.
        dic = ast.literal_eval(line.strip())
        return dic

    def encode(self, item):
        """
        Convert one sample into model inputs for training.
        """
        inputs_dict = self.tokenizer.encode_plus(
            item['Question'] + item['answer'],
            max_length=self.max_seq_length, padding='max_length',
            truncation=True, return_tensors='pt')
        target = inputs_dict['input_ids']
        labels = target.clone().detach()
        # Mask padded positions so the causal-LM loss ignores them.
        labels[target == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs_dict['input_ids'].squeeze(),
            "attention_mask": inputs_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "question": item['Question'],
            "answer": item['answer']
        }
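# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original pipeline: a minimal,
# self-contained demo of the -100 label-masking scheme that encode() applies.
# The token ids below are hypothetical stand-ins for real tokenizer output.
# Note that because the pad token reuses GPT-2's '<|endoftext|>' id, a
# genuine end-of-text token in the sequence would be masked as well.
# ---------------------------------------------------------------------------
def _label_masking_demo():
    import torch
    pad_id = 50256  # GPT-2's '<|endoftext|>' id, reused above as the pad id
    input_ids = torch.tensor([15496, 2159, pad_id, pad_id])  # 2 tokens + 2 pads
    labels = input_ids.clone().detach()
    labels[input_ids == pad_id] = -100  # -100 positions are ignored by the loss
    return labels  # tensor([15496, 2159, -100, -100])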
class GPT2QADataModel(pl.LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('GPT2QADataModel')
        parser.add_argument('--data_dir', type=str, required=True)
        parser.add_argument('--num_workers', default=2, type=int)
        parser.add_argument('--train_data', default='train.txt', type=str)
        parser.add_argument('--valid_data', default='valid.txt', type=str)
        parser.add_argument('--test_data', default='test.txt', type=str)
        parser.add_argument('--train_batchsize', type=int, required=True)
        parser.add_argument('--valid_batchsize', type=int, required=True)
        parser.add_argument('--max_seq_length', default=1024, type=int)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        if not args.do_eval_only:
            self.train_data = GPT2QADataset(os.path.join(
                args.data_dir, args.train_data), 'train set', args)
            self.valid_data = GPT2QADataset(os.path.join(
                args.data_dir, args.valid_data), 'validation set', args)
            self.test_data = GPT2QADataset(os.path.join(
                args.data_dir, args.test_data), 'test set', args)

    def train_dataloader(self):
        return DataLoader(
            self.train_data, shuffle=True,
            batch_size=self.train_batchsize,
            pin_memory=False, num_workers=self.args.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_data, shuffle=False,
                          batch_size=self.valid_batchsize,
                          pin_memory=False,
                          num_workers=self.args.num_workers)

    def predict_dataloader(self):
        return DataLoader(self.test_data, shuffle=False,
                          batch_size=self.valid_batchsize, pin_memory=False,
                          num_workers=self.args.num_workers)


if __name__ == '__main__':
    import argparse
    modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2'
    datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt'
    parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False)
    group = parser.add_argument_group(title='test args')
    group.add_argument('--pretrained-model-path', type=str, default=modelfile,
                       help='Path to the pretrained HF model.')
    group.add_argument('--max-seq-length', type=int, default=1024)
    args = parser.parse_args()

    testml = GPT2QADataset(datafile, 'medical_qa', args=args)

    print(testml[10])
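# ---------------------------------------------------------------------------
# Usage sketch, not part of the original smoke test: how the DataModule's
# argument group composes with a training script's parser. The
# '--pretrained_model_path' and '--do_eval_only' flags are assumptions here;
# in the full pipeline they would be registered by the model/trainer groups.
# ---------------------------------------------------------------------------
def _build_datamodule_demo(argv=None):
    import argparse
    demo_parser = argparse.ArgumentParser(description='GPT2QADataModel demo')
    demo_parser = GPT2QADataModel.add_data_specific_args(demo_parser)
    demo_parser.add_argument('--pretrained_model_path', type=str, required=True)
    demo_parser.add_argument('--do_eval_only', action='store_true')
    demo_args = demo_parser.parse_args(argv)
    # With do_eval_only unset, this builds train/valid/test datasets eagerly.
    return GPT2QADataModel(demo_args)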