#%%
# Build and refine poisoning sentences for the attack triples in attack_results/.
#
#   --mode sentence : for each attack triple (s, r, o) with s != -1, sample GNBR sentences for
#                     relation r, substitute the raw names of s and o for the sentence's start/end
#                     entities, and keep the candidate with the lowest BioGPT mean token NLL.
#   --mode finetune : read the drafts in generate_abstract/*_chat.json, build BioBART infilling
#                     inputs around them, generate candidates (cached in generate_abstract/bioBART/),
#                     re-score everything with BioGPT and keep the lowest-NLL text per draft.
import torch
import numpy as np
from torch.autograd import Variable
# from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import sys
sys.path.append("..")
import Parameters

parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = os.path.join('processed_data', args.data)
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}.txt'.format(args.model, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.neighbor_num, args.candidate_mode, args.attack_goal, str(args.reasonable_rate)))
# target_data = utils.load_data(target_path)
attack_data = utils.load_data(attack_path, drop=False)
# assert target_data.shape == attack_data.shape
#%%

with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)

VALID_ENTITY_PATH = 'generate_abstract/valid_entity.json'
if not os.path.exists(VALID_ENTITY_PATH):
    # Words containing '_' in the raw sentence text are collected as (space-joined) entity names.
    valid_entity = set()
    for paper in raw_text_sen.values():
        for sen in paper.values():
            for word in sen['text'].split(' '):
                if '_' in word:
                    valid_entity.add(word.replace('_', ' '))
    # Write to the same path that is checked above (the old code wrote './valid_entity.json',
    # so the existence check never succeeded). Sorted for run-to-run reproducibility.
    os.makedirs(os.path.dirname(VALID_ENTITY_PATH), exist_ok=True)
    with open(VALID_ENTITY_PATH, 'w') as fl:
        json.dump(sorted(valid_entity), fl, indent=4)
    print('Valid entity saved!!')

# PTB-style bracket escapes found in the raw sentence text, replaced in this order.
PTB_BRACKETS = (('-LRB-', '('), ('-RRB-', ')'), ('-LSB-', '['), ('-RSB-', ']'), ('-LCB-', '{'), ('-RCB-', '}'))
# Placeholder node names in a dependency path that stand for the triple's own entities.
END_ENTITIES = {'start_entity', 'end_entity'}


def mean_token_nll(lm, loss_fn, input_ids, attention_mask, batch_size):
    """Return a 1-D numpy array with each row's mean next-token cross-entropy under ``lm``.

    ``loss_fn`` must be a token-level loss with ``reduction="none"``. Padding positions
    (attention_mask == 0) are excluded from the mean. Rows are scored in chunks of
    ``batch_size`` under ``torch.no_grad()`` so no autograd graph is kept.
    """
    scores = []
    with torch.no_grad():
        for start in range(0, input_ids.shape[0], batch_size):
            ids = input_ids[start:start + batch_size, :]
            mask = attention_mask[start:start + batch_size, :]
            logits = lm(input_ids=ids, attention_mask=mask).logits
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = ids[..., 1:].contiguous()
            token_loss = loss_fn(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
            token_loss = token_loss.view(-1, shift_logits.shape[1])
            shift_mask = mask[..., 1:].contiguous().float()
            scores.append(torch.mean(token_loss * shift_mask, dim=1) / torch.mean(shift_mask, dim=1))
    return torch.cat(scores, -1).cpu().numpy()


def build_checkset(path_text, text_s, text_o):
    """Return {text_s, text_o} plus every non-placeholder node of ``path_text``.

    ``path_text`` is a space-separated list of ``a|b|c`` triples; the first and last
    fields are added unless they are 'start_entity' / 'end_entity'.
    """
    checkset = {text_s, text_o}
    for path in path_text.split(' '):
        a, _, c = path.split('|')
        if a not in END_ENTITIES:
            checkset.add(a)
        if c not in END_ENTITIES:
            checkset.add(c)
    return checkset


def mark_entity_words(words, checkset):
    """Return one bool per word: True if the word is part of a phrase found in ``checkset``.

    Scans left to right and, at each position, tries the longest phrase first.
    """
    vec = []
    start = 0
    while start < len(words):
        # Reversing is important: longer phrases must win over their own sub-phrases.
        for end in range(len(words), start, -1):
            if ' '.join(words[start:end]) in checkset:
                vec += [True] * (end - start)
                start = end
                break
        else:
            vec.append(False)
            start += 1
    return vec


def find_mini_span(vec, words, check_set):
    """Find the shortest window of ``words`` that contains as many ``check_set`` strings
    (substring match) as the whole text does.

    Windows must start and end on words flagged True in ``vec``. The chosen window is
    flagged True in ``vec`` in place. Returns ``(vec, span)``; ``span`` is '' if no window
    qualifies.
    """
    def cal(text):
        return sum(1 for term in check_set if term in text)

    max_add = cal(' '.join(words))
    minn = 10000000
    span = ''
    rc = None
    for i in range(len(vec)):
        if not vec[i]:
            continue
        p = -1
        for j in range(i + 1, len(vec) + 1):
            if vec[j - 1] and cal(' '.join(words[i:j])) == max_add:
                p = j
                break
        if p > 0 and (p - i) < minn:
            minn = p - i
            span = ' '.join(words[i:p])
            rc = (i, p)
    if rc:
        for idx in range(rc[0], rc[1]):
            vec[idx] = True
    return vec, span


def build_prompt(vec, words, mask_token):
    """Keep words flagged in ``vec`` and replace each run of unflagged words by
    ``min(8, run_length)`` copies of ``mask_token``. Returns the space-joined prompt.

    A trailing unflagged run is dropped; the caller sets ``vec[-1] = True`` beforehand.
    """
    prompt = []
    run = 0
    for word, keep in zip(words, vec):
        if not keep:
            run += 1
            continue
        if run > 0:
            prompt += [mask_token] * min(8, run)
        prompt.append(word)
        run = 0
    return ' '.join(prompt)


def mask_func(tokenized_sen, ratio, mask_token):
    """Randomly mask word spans of spaCy sentences.

    Each step draws a span length from Poisson(3); with probability ``ratio`` (and a
    non-zero length) the span is skipped and a single ``mask_token`` is emitted, but no
    more than 8 in a row. Words containing '.', '(', ')', '[' or ']' are never masked.
    Returns [] for no sentences, otherwise a one-element list with the masked text.
    """
    if len(tokenized_sen) == 0:
        return []
    token_list = []
    for sen in tokenized_sen:
        token_list += sen.text.split(' ')

    ret_list = []
    i = 0
    mask_num = 0
    while i < len(token_list):
        t = token_list[i]
        if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
            ret_list.append(t)
            i += 1
            mask_num = 0
            continue
        length = np.random.poisson(3)
        if np.random.rand() < ratio and length > 0:
            if mask_num < 8:
                ret_list.append(mask_token)
                mask_num += 1
            i += length
        else:
            ret_list.append(t)
            i += 1
            mask_num = 0
    return [' '.join(ret_list)]


def space_after_period(text):
    """Insert a space after '.' wherever it is directly followed by an uppercase A-Z letter."""
    for code in range(ord('A'), ord('Z') + 1):
        text = text.replace(f'.{chr(code)}', f'. {chr(code)}')
    return text


def sen_align(word_sets, modified_word_sets):
    """Locate where a regenerated text diverges from the draft, sentence by sentence.

    Sentences are compared as word sets; draft sentence ``a`` matches ``b`` when more than
    80% of ``a``'s words occur in ``b``. Returns ``(l, r1, l, r2)``:

    - ``l`` is the first index where the two texts stop matching.
    - ``(r1, r2)`` is the first later pair (draft, regenerated) that matches again,
      or the ends of both lists.

    Returns ``(-1, -1, -1, -1)`` when every regenerated sentence matches the draft prefix.
    """
    def matches(a, b):
        return len(a.intersection(b)) > len(a) * 0.8

    l = 0
    # Bounded by BOTH lists: the regenerated text may have more sentences than the draft
    # (previously this indexed past the end of word_sets).
    while l < len(modified_word_sets) and l < len(word_sets) and matches(word_sets[l], modified_word_sets[l]):
        l += 1
    if l == len(modified_word_sets):
        return -1, -1, -1, -1
    for pos1 in range(l + 1, len(word_sets)):
        for pos2 in range(l + 1, len(modified_word_sets)):
            if matches(word_sets[pos1], modified_word_sets[pos2]):
                return l, pos1, l, pos2
    return l, len(word_sets), l, len(modified_word_sets)


if args.mode == 'sentence':
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import AutoTokenizer
    from transformers import BioGptForCausalLM
    criterion = CrossEntropyLoss(reduction="none")

    print('Generating GPT input ...')

    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()
    GPT_batch_size = 32
    single_sentence = {}
    test_text = []
    test_dp = []
    test_parse = []
    for i, (s, r, o) in enumerate(tqdm(attack_data)):
        key = f'{s}_{r}_{o}_{i}'
        if int(s) == -1:
            single_sentence[key] = ''
            continue

        dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
        # Per-path sample size: roughly 500 candidates in total, at least one per path.
        bound = max(1, 500 // len(dependency_sen_dict))
        candidate_sen = []
        Dp_path = []
        for dp_path, sen_list in dependency_sen_dict.items():
            if len(sen_list) > bound:
                index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
                sen_list = [sen_list[aa] for aa in index]
            candidate_sen += sen_list
            Dp_path += [dp_path] * len(sen_list)

        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        candidate_text_sen = []
        candidate_ori_sen = []
        candidate_parse_sen = []
        for paper_id, sen_id in candidate_sen:
            sen = raw_text_sen[paper_id][sen_id]
            text = sen['text']
            candidate_ori_sen.append(text)
            ss = sen['start_formatted']
            oo = sen['end_formatted']
            for escaped, bracket in PTB_BRACKETS:
                text = text.replace(escaped, bracket)
            # parse_text keeps multi-word names '_'-joined; text is plain natural language.
            parse_text = text.replace(ss, text_s.replace(' ', '_')).replace(oo, text_o.replace(' ', '_'))
            text = text.replace(ss, text_s).replace(oo, text_o).replace('_', ' ')
            candidate_text_sen.append(text)
            candidate_parse_sen.append(parse_text)

        assert len(candidate_text_sen) > 0
        tokens = tokenizer(candidate_text_sen, truncation=True, padding=True, max_length=300, return_tensors="pt")
        ret_log_L = list(mean_token_nll(model, criterion, tokens['input_ids'].to(device),
                                        tokens['attention_mask'].to(device), GPT_batch_size))

        sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
        sen_score.sort(key=lambda x: x[1])
        best_text, _, best_ori, best_dp, best_parse = sen_score[0]
        test_text.append(best_ori)
        test_dp.append(best_dp)
        test_parse.append(best_parse)
        single_sentence[key] = best_text

    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    os.makedirs('generate_abstract/path', exist_ok=True)
    # One line per non-skipped triple; the finetune mode reads these back in the same order.
    with open(f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open(f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))

elif args.mode == 'finetune':
    import spacy
    from transformers import AutoTokenizer, BartForConditionalGeneration

    print('Finetuning ...')

    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_chat.json', 'r') as fl:
        draft = json.load(fl)
    with open(f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'r') as fl:
        dpath = fl.readlines()

    nlp = spacy.load("en_core_web_sm")
    candidates_path = f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json'
    if os.path.exists(candidates_path):
        with open(candidates_path, 'r') as fl:
            ret_candidates = json.load(fl)
    else:
        mask_ratio = 0.3 if args.ratio == '' else float(args.ratio)
        model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
        model.eval()
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
        # Use the model's own mask token; empty-string placeholders just deleted the words.
        mask_token = tokenizer.mask_token

        ret_candidates = {}
        dpath_i = 0
        for i, (k, v) in enumerate(tqdm(draft.items())):
            draft_in = v['in'].replace('\n', '')
            draft_out = v['out'].replace('\n', '')
            s, r, o = attack_data[i]
            if int(s) == -1:
                ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
                continue

            path_text = dpath[dpath_i].replace('\n', '')
            dpath_i += 1
            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]

            doc = nlp(draft_out)
            words = draft_in.split(' ')
            tokenized_sens = list(doc.sents)
            sens = [sen.text for sen in tokenized_sens]

            checkset = build_checkset(path_text, text_s, text_o)
            vec = mark_entity_words(words, checkset)
            vec, span = find_mini_span(vec, words, checkset)
            vec[-1] = True
            prompt = build_prompt(vec, words, mask_token)

            Text = []
            Assist = []
            # First half: draft sentence j replaced by the masked prompt built from draft_in.
            for j in range(len(sens)):
                Text.append(' '.join(sens[:j] + [prompt] + sens[j + 1:]))
                Assist.append(' '.join(sens[:j] + [draft_in] + sens[j + 1:]))
            # Second half: draft_in kept verbatim as sentence j, surrounding sentences randomly masked.
            for j in range(len(sens)):
                Bart_input = (mask_func(tokenized_sens[:j], mask_ratio, mask_token) + [draft_in]
                              + mask_func(tokenized_sens[j + 1:], mask_ratio, mask_token))
                Text.append(' '.join(Bart_input))
                Assist.append(' '.join(sens[:j] + [draft_in] + sens[j + 1:]))

            half = len(Text) // 2
            Outs = []
            for start in (0, half):
                A = tokenizer(Text[start:start + half], truncation=True, padding=True, max_length=1024, return_tensors="pt")
                aaid = model.generate(A['input_ids'].to(device), attention_mask=A['attention_mask'].to(device),
                                      num_beams=5, max_length=1024)
                Outs += tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}

        os.makedirs(os.path.dirname(candidates_path), exist_ok=True)
        with open(candidates_path, 'w') as fl:
            json.dump(ret_candidates, fl, indent = 4)

    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import BioGptForCausalLM
    criterion = CrossEntropyLoss(reduction="none")

    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()

    scored = {}
    ret = {}
    dpath_i = 0
    for i, (k, v) in enumerate(tqdm(draft.items())):
        candidate = ret_candidates[str(i)]
        prompt = candidate['prompt']
        sen_list = candidate['out']
        Assist = candidate['assist']

        s, r, o = attack_data[i]
        if int(s) == -1:
            ret[k] = {'prompt': '', 'in':'', 'out': ''}
            continue

        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        sen_list = [space_after_period(text) for text in sen_list]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1

        draft_in = v['in'].replace('\n', '')
        draft_out = v['out'].replace('\n', '')
        gpt_sens = [sen.text for sen in nlp(draft_out).sents]
        assert len(gpt_sens) == len(sen_list) // 2
        word_sets = [set(sen.split(' ')) for sen in gpt_sens]

        # For first-half (prompt-infilled) outputs, splice only the changed region back into the
        # draft; second-half outputs are scored as generated.
        replace_sen_list = []
        boundary = []
        assert len(sen_list) % 2 == 0
        n_half = len(sen_list) // 2
        for j in range(n_half):
            sens = [sen.text for sen in nlp(sen_list[j]).sents]
            modified_word_sets = [set(sen.split(' ')) for sen in sens]
            l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
            boundary.append((l1, r1, l2, r2))
            if l1 == -1:
                replace_sen_list.append(sen_list[j])
                continue
            check_text = ' '.join(sens[l2:r2])
            replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
        sen_list = replace_sen_list + sen_list[n_half:]

        # Layout scored below: [0, old_L) BART candidates, old_L the original draft,
        # old_L + 1 + j the 'assist' text paired with candidate j.
        old_L = len(sen_list)
        sen_list.append(draft_out)
        sen_list += Assist
        tokens = tokenizer(sen_list, truncation=True, padding=True, max_length=1024, return_tensors="pt")
        real_log_Loss = mean_token_nll(model, criterion, tokens['input_ids'].to(device),
                                       tokens['attention_mask'].to(device), 5)

        content = [[sen, str(loss)] for sen, loss in zip(sen_list, real_log_Loss)]
        scored[k] = {'path':path_text, 'prompt': prompt, 'in':draft_in, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}

        # Best BART candidate; fall back to its assist text only if the candidate scores worse
        # than both the original draft and that assist text.
        p = int(np.argmin(real_log_Loss[:old_L]))
        assist_idx = p + 1 + old_L
        if real_log_Loss[p] > real_log_Loss[old_L] and real_log_Loss[p] > real_log_Loss[assist_idx]:
            p = assist_idx
        ret[k] = {'prompt': prompt, 'in':draft_in, 'out': sen_list[p]}

    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
        json.dump(ret, fl, indent=4)
    os.makedirs('generate_abstract/bioBART', exist_ok=True)
    with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
        json.dump(scored, fl, indent=4)
else:
    raise Exception('Wrong mode !!')