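# NOTE: the lines below appear to be file-viewer page header text captured
# together with the source; they are kept verbatim as comments.
# Scorpius_HF / DiseaseSpecific /edge_to_abstract.py
# yjwtheonly
# specific
# ac7c391
# raw
# history blame
# No virus
# 21.5 kB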
#%%
import torch
import numpy as np
from torch.autograd import Variable
# from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import sys
sys.path.append("..")
import Parameters
# ---- Command-line configuration --------------------------------------------
# Start from the project's shared parser, let it add the attack options, then
# register the two flags that are specific to this script.
# NOTE(review): the --mode help text mentions 'biogpt', but the dispatch further
# down only handles 'sentence' and 'finetune'; any other value raises.
parser = utils.add_attack_parameters(utils.get_argument_parser())
parser.add_argument(
    '--mode',
    type=str,
    default='sentence',
    help='sentence, biogpt or finetune',
)
parser.add_argument(
    '--ratio',
    type=str,
    default='',
    help='ratio of the number of changed words',
)
args = utils.set_hyperparams(parser.parse_args())

# ---- Runtime setup ----------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

# ---- Input locations --------------------------------------------------------
# Both file names embed the same 'exists:<flag>' fragment, so build it once.
_existed_tag = 'exists:' + str(args.target_existed)
_reasonable_tag = str(args.reasonable_rate)

data_path = os.path.join('processed_data', args.data)

_target_name = 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(
    args.model,
    args.data,
    args.target_split,
    args.target_size,
    _existed_tag,
    args.attack_goal,
)
target_path = os.path.join(data_path, _target_name)

_attack_name = 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}.txt'.format(
    args.model,
    args.target_split,
    args.target_size,
    _existed_tag,
    args.neighbor_num,
    args.candidate_mode,
    args.attack_goal,
    _reasonable_tag,
)
attack_path = os.path.join('attack_results', args.data, _attack_name)

# Only the attack file is loaded; the target file path is computed but its
# loading and the shape check against it are disabled:
# target_data = utils.load_data(target_path)
# assert target_data.shape == attack_data.shape
# The rest of the script iterates / indexes attack_data as (s, r, o) triples.
attack_data = utils.load_data(attack_path, drop=False)
#%%
# ---- Lookup tables shared by both modes -------------------------------------
# id_to_meshid is indexed with the s / o ids of an attacked edge; its values are
# then used as keys into entity_raw_name.
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)

# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only point Parameters.GNBRfile at files you produced or otherwise trust.
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)

# ---- One-off cache of entity surface forms ----------------------------------
# Every whitespace-separated token that contains '_' is recorded with the
# underscores turned into spaces.
# The file is now written to the same path that the existence check looks at;
# previously it was written to './valid_entity.json', so the check never
# succeeded and the whole corpus was rescanned on every run.
valid_entity_path = 'generate_abstract/valid_entity.json'
if not os.path.exists(valid_entity_path):
    valid_entity = set()
    for paper in raw_text_sen.values():
        for sen in paper.values():
            for token in sen['text'].split(' '):
                if '_' in token:
                    valid_entity.add(token.replace('_', ' '))
    # Make sure the output directory exists before writing into it.
    os.makedirs(os.path.dirname(valid_entity_path), exist_ok=True)
    with open(valid_entity_path, 'w') as fl:
        json.dump(list(valid_entity), fl, indent=4)
    print('Valid entity saved!!')
if args.mode == 'sentence':
    # Sentence mode: for every attacked edge (s, r, o) gather corpus sentences
    # stored for relation r, substitute the names of s and o for the sentence's
    # original entity mentions, score each rewritten sentence with BioGPT and
    # keep the one with the lowest mean token loss.
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import AutoTokenizer
    from transformers import BioGptForCausalLM

    # reduction="none" keeps one loss value per token so that padding can be
    # masked out before averaging.
    criterion = CrossEntropyLoss(reduction="none")
    print('Generating GPT input ...')
    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()

    GPT_batch_size = 32
    # Bracket placeholder tokens found in the corpus text and their literal form.
    bracket_tokens = (
        ('-LRB-', '('),
        ('-RRB-', ')'),
        ('-LSB-', '['),
        ('-RSB-', ']'),
        ('-LCB-', '{'),
        ('-RCB-', '}'),
    )

    single_sentence = {}
    test_text = []
    test_dp = []
    # Collected for parity with test_text / test_dp; not written to disk here.
    test_parse = []

    for i, (s, r, o) in enumerate(tqdm(attack_data)):
        if int(s) == -1:
            # Skipped edge: keep a placeholder so keys stay aligned with attack_data.
            single_sentence[f'{s}_{r}_{o}_{i}'] = ''
            continue

        dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
        if not dependency_sen_dict:
            # Previously this surfaced as a ZeroDivisionError from `500 // 0`.
            # Emitting an empty sentence instead would desynchronise the path
            # file, which holds exactly one line per non-skipped edge.
            raise ValueError(f'No dependency paths stored for relation {r} (edge {s}, {r}, {o})')

        # Aim for roughly 500 candidates overall, but at least one per path.
        bound = max(1, 500 // len(dependency_sen_dict))
        candidate_sen = []
        Dp_path = []
        for dp_path, sen_list in dependency_sen_dict.items():
            if len(sen_list) > bound:
                # Same sampling call as before so seeded runs stay reproducible.
                index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
                sen_list = [sen_list[aa] for aa in index]
            candidate_sen += sen_list
            Dp_path += [dp_path] * len(sen_list)
        if not candidate_sen:
            # The original asserted this only after the tokenizer had already
            # been handed the empty list.
            raise ValueError(f'No candidate sentences for relation {r} (edge {s}, {r}, {o})')

        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]

        candidate_text_sen = []
        candidate_ori_sen = []
        candidate_parse_sen = []
        for paper_id, sen_id in candidate_sen:
            sen = raw_text_sen[paper_id][sen_id]
            text = sen['text']
            candidate_ori_sen.append(text)
            ss = sen['start_formatted']
            oo = sen['end_formatted']
            for placeholder, literal in bracket_tokens:
                text = text.replace(placeholder, literal)
            # Parse variant: entity names keep '_' so each stays a single token.
            parse_text = text.replace(ss, text_s.replace(' ', '_'))
            parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
            # Plain variant: substitute the names, then drop every underscore.
            text = text.replace(ss, text_s)
            text = text.replace(oo, text_o)
            text = text.replace('_', ' ')
            candidate_text_sen.append(text)
            candidate_parse_sen.append(parse_text)

        tokens = tokenizer(candidate_text_sen,
                           truncation=True,
                           padding=True,
                           max_length=300,
                           return_tensors="pt")
        target_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)

        num_candidates = len(candidate_text_sen)
        ret_log_L = []
        # Pure inference: disabling autograd avoids building a graph per batch.
        with torch.no_grad():
            for lo in range(0, num_candidates, GPT_batch_size):
                hi = min(num_candidates, lo + GPT_batch_size)
                target = target_ids[lo:hi, :]
                attention = attention_mask[lo:hi, :]
                outputs = model(input_ids=target,
                                attention_mask=attention,
                                labels=target)
                logits = outputs.logits
                # Position t predicts token t+1, hence the one-step shift.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                # Ratio of the two means == mean loss over unmasked positions.
                log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
                ret_log_L.append(log_Loss)
        ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())

        sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
        sen_score.sort(key=lambda x: x[1])
        best = sen_score[0]
        test_text.append(best[2])
        test_dp.append(best[3])
        test_parse.append(best[4])
        single_sentence[f'{s}_{r}_{o}_{i}'] = best[0]

    # Creates 'generate_abstract' as well as its 'path' sub-directory.
    os.makedirs('generate_abstract/path', exist_ok=True)
    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    # Despite the .json suffix the next two files are newline-separated text;
    # the finetune mode reads the path file back with readlines().
    with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))
elif args.mode == 'finetune':
# Finetune mode, part 1: load the inputs that both the candidate-generation
# stage and the scoring stage below rely on.
# NOTE(review): pprint and AutoModel are imported but not referenced in the
# visible code.
import spacy
import pprint
from transformers import AutoModel, AutoTokenizer,BartForConditionalGeneration
print('Finetuning ...')
# Draft file: a JSON object whose values are later read through their 'in' and
# 'out' fields. It is not produced anywhere in the visible code.
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_chat.json', 'r') as fl:
draft = json.load(fl)
# Newline-separated dependency paths written by the 'sentence' mode: one line
# per edge whose s is not -1.
with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'r') as fl:
dpath = fl.readlines()
# spaCy pipeline; the code below uses it for sentence splitting (doc.sents).
nlp = spacy.load("en_core_web_sm")
# Candidate cache: if a candidates file for this split / rate / ratio already
# exists it is loaded, otherwise the else-branch regenerates and saves it.
if os.path.exists(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json'):
with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
ret_candidates = json.load(fl)
# if False:
# pass
else:
def find_mini_span(vec, words, check_set):
    """Locate the shortest flagged window that mentions every findable term.

    ``vec`` holds one boolean per entry of ``words``. A window ``[start, end)``
    qualifies when both its first and last word are flagged in ``vec`` and the
    space-joined window contains (as substrings) as many ``check_set`` terms as
    the whole sentence does. For each flagged start only the nearest qualifying
    end is considered; the first shortest window wins.

    Returns ``(vec, span)``: ``vec`` is the same list object, modified in place
    so that every position inside the winning window is ``True``; ``span`` is
    the window's text, or ``''`` when no window qualifies.
    """
    def count_hits(text):
        # Number of check_set terms occurring somewhere in `text`.
        return sum(1 for term in check_set if term in text)

    wanted = count_hits(' '.join(words))
    size = len(vec)
    shortest = 10000000
    span = ''
    window = None

    for start in range(size):
        if not vec[start] == True:
            continue
        end = -1
        for stop in range(start + 1, size + 1):
            if vec[stop - 1] == True and count_hits(' '.join(words[start:stop])) == wanted:
                end = stop
                break
        if end > 0 and end - start < shortest:
            shortest = end - start
            span = ' '.join(words[start:end])
            window = (start, end)

    if window is not None:
        first, last = window
        for idx in range(first, last):
            vec[idx] = True
    return vec, span
def mask_func(tokenized_sen):
# Randomly replace stretches of tokens with '<mask>' placeholders.
#
# tokenized_sen: a sequence of objects exposing a `.text` string; the caller
#   in this file passes slices of spaCy sentence spans.
# Returns [] for an empty input, otherwise a one-element list holding the
#   space-joined, partially masked text.
# Reads the module-level `args.ratio` and draws from NumPy's global random
# state, so the output depends on the seed.
if len(tokenized_sen) == 0:
return []
token_list = []
# for sen in tokenized_sen:
# for token in sen:
# token_list.append(token)
# Tokens are obtained by splitting each sentence's text on single spaces.
for sen in tokenized_sen:
token_list += sen.text.split(' ')
# Masking probability: 0.3 unless --ratio supplies a value.
if args.ratio == '':
P = 0.3
else:
P = float(args.ratio)
ret_list = []
i = 0
# mask_num counts masks emitted since the last token that was kept verbatim.
mask_num = 0
while i < len(token_list):
t = token_list[i]
# Tokens containing '.', '(', ')', '[' or ']' are always kept as they are.
if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
ret_list.append(t)
i += 1
mask_num = 0
else:
# Draw a span length from Poisson(3); with probability P (and a non-zero
# length) the span is masked instead of copied.
length = np.random.poisson(3)
if np.random.rand() < P and length > 0:
# At most 8 masks are emitted in a row.
if mask_num < 8:
ret_list.append('<mask>')
mask_num += 1
# NOTE(review): indentation was lost in this copy, so it cannot be told from
# here whether the skip below runs only when a mask was appended or whenever
# the masking branch is taken — confirm against the upstream file.
i += length
else:
ret_list.append(t)
i += 1
mask_num = 0
return [' '.join(ret_list)]
# Candidate generation: for every draft, build masked variants of the abstract
# and let BioBART fill the masks; everything is collected in ret_candidates and
# saved at the end.
model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
model.eval()
model.to(device)
tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
ret_candidates = {}
# Index into dpath; advanced only for edges that are not skipped, matching how
# the path file is written.
dpath_i = 0
for i,(k, v) in enumerate(tqdm(draft.items())):
# NOTE(review): `input` shadows the Python builtin of the same name.
input = v['in'].replace('\n', '')
output = v['out'].replace('\n', '')
s, r, o = attack_data[i]
# Skipped edge: store an empty record so indices stay aligned with the drafts.
if int(s) == -1:
ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
continue
path_text = dpath[dpath_i].replace('\n', '')
dpath_i += 1
text_s = entity_raw_name[id_to_meshid[s]]
text_o = entity_raw_name[id_to_meshid[o]]
doc = nlp(output)
words= input.split(' ')
tokenized_sens = [sen for sen in doc.sents]
sens = np.array([sen.text for sen in doc.sents])
# checkset = strings that must survive masking: both entity names plus every
# first/third field of the '|'-separated path items, except the
# 'start_entity' / 'end_entity' placeholders.
checkset = set([text_s, text_o])
e_entity = set(['start_entity', 'end_entity'])
for path in path_text.split(' '):
a, b, c = path.split('|')
if a not in e_entity:
checkset.add(a)
if c not in e_entity:
checkset.add(c)
# vec[n] is True when words[n] belongs to a phrase found in checkset. For each
# position the longest matching phrase is tried first (j runs downwards).
vec = []
l = 0
while(l < len(words)):
bo =False
for j in range(len(words), l, -1): # reversing is important !!!
cc = ' '.join(words[l:j])
if (cc in checkset):
vec += [True] * (j-l)
l = j
bo = True
break
if not bo:
vec.append(False)
l += 1
# Additionally keep the shortest window that covers all findable checkset terms.
vec, span = find_mini_span(vec, words, checkset)
# vec = np.vectorize(lambda x: x in checkset)(words)
# The last word is always kept, which also flushes any pending masks before it.
vec[-1] = True
# Build the prompt: kept words are copied, and each run of dropped words
# becomes as many '<mask>' tokens as the run is long, capped at 8.
prompt = []
mask_num = 0
for j, bo in enumerate(vec):
if not bo:
mask_num += 1
else:
if mask_num > 0:
# mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
# NOTE(review): with the division above disabled, max(mask_num, 1) cannot
# change mask_num because this branch requires mask_num > 0.
mask_num = max(mask_num, 1)
mask_num= min(8, mask_num)
prompt += ['<mask>'] * mask_num
prompt.append(words[j])
mask_num = 0
prompt = ' '.join(prompt)
# Text  : inputs handed to BioBART.
# Assist: the same abstract with sentence j replaced by the unmasked `input`.
Text = []
Assist = []
# Variant A: sentence j of the abstract is replaced by the masked prompt.
for j in range(len(sens)):
Bart_input = list(sens[:j]) + [prompt] +list(sens[j+1:])
assist = list(sens[:j]) + [input] +list(sens[j+1:])
Text.append(' '.join(Bart_input))
Assist.append(' '.join(assist))
# Variant B: sentence j is replaced by `input` and the surrounding sentences
# are randomly masked by mask_func.
for j in range(len(sens)):
Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j+1:])
assist = list(sens[:j]) + [input] +list(sens[j+1:])
Text.append(' '.join(Bart_input))
Assist.append(' '.join(assist))
# Text holds 2 * len(sens) entries, so the two slices below are variant A and
# variant B respectively.
# NOTE(review): if the abstract yields no sentences, Text is empty and the
# tokenizer is called on an empty list — confirm this cannot happen.
batch_size = len(Text) // 2
Outs = []
for l in range(2):
A = tokenizer(Text[batch_size * l:batch_size * (l+1)],
truncation = True,
padding = True,
max_length = 1024,
return_tensors="pt")
input_ids = A['input_ids'].to(device)
attention_mask = A['attention_mask'].to(device)
aaid = model.generate(input_ids, attention_mask = attention_mask, num_beams = 5, max_length = 1024)
outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
Outs += outs
ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}
# Persist the candidates at the same path the cache check looks for.
with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
json.dump(ret_candidates, fl, indent = 4)
# Scoring stage: BioGPT replaces BioBART in `model` / `tokenizer` and is used
# to rank the generated candidates by their mean token loss.
from torch.nn.modules.loss import CrossEntropyLoss
from transformers import BioGptForCausalLM
# Per-token losses (no reduction) so padded positions can be masked out later.
criterion = CrossEntropyLoss(reduction="none")
tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
tokenizer.pad_token = tokenizer.eos_token
model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
model.to(device)
model.eval()
# scored: per-draft record of every scored text; ret: the single chosen text.
scored = {}
ret = {}
# Same skip-aware index into dpath as in the generation stage.
dpath_i = 0
for i,(k, v) in enumerate(tqdm(draft.items())):
# Candidates are keyed by the draft's position (as a string), not by its key.
span = ret_candidates[str(i)]['span']
prompt = ret_candidates[str(i)]['prompt']
sen_list = ret_candidates[str(i)]['out']
BART_in = ret_candidates[str(i)]['in']
Assist = ret_candidates[str(i)]['assist']
s, r, o = attack_data[i]
# Skipped edge: emit an empty result and move on.
if int(s) == -1:
ret[k] = {'prompt': '', 'in':'', 'out': ''}
continue
text_s = entity_raw_name[id_to_meshid[s]]
text_o = entity_raw_name[id_to_meshid[o]]
def process(text):
    """Insert a space after any '.' that is directly followed by A-Z.

    For example 'end.Next' becomes 'end. Next'; periods followed by anything
    other than an ASCII capital letter are left untouched.
    """
    for letter in map(chr, range(ord('A'), ord('Z') + 1)):
        text = text.replace(f'.{letter}', f'. {letter}')
    return text
# Repair '.X' joins in the generated texts so sentence splitting can see them.
sen_list = [process(text) for text in sen_list]
path_text = dpath[dpath_i].replace('\n', '')
dpath_i += 1
# Same checkset construction as in the generation stage: entity names plus the
# non-placeholder first/third fields of each '|'-separated path item.
# NOTE(review): checkset is rebuilt here but not read again in the visible
# remainder of the loop.
checkset = set([text_s, text_o])
e_entity = set(['start_entity', 'end_entity'])
for path in path_text.split(' '):
a, b, c = path.split('|')
if a not in e_entity:
checkset.add(a)
if c not in e_entity:
checkset.add(c)
input = v['in'].replace('\n', '')
output = v['out'].replace('\n', '')
doc = nlp(output)
gpt_sens = [sen.text for sen in doc.sents]
# The candidate list is expected to hold two entries per abstract sentence
# (the two variants built during generation).
assert len(gpt_sens) == len(sen_list) // 2
# One set of space-separated words per sentence of the unedited abstract.
word_sets = []
for sen in gpt_sens:
word_sets.append(set(sen.split(' ')))
def sen_align(word_sets, modified_word_sets):
    """Find where a rewritten text departs from, and rejoins, the original.

    Both arguments are lists with one set of words per sentence. An original
    sentence "matches" a rewritten one when more than 80% of the original's
    words also occur in the rewritten sentence.

    Returns ``(l1, r1, l2, r2)``:

    * ``(-1, -1, -1, -1)`` when every rewritten sentence matches the original
      sentence at the same index (nothing to splice);
    * otherwise ``l1 == l2`` is the index of the first non-matching position,
      and ``r1`` / ``r2`` are the indices (in the original / rewritten list) of
      the first matching pair found after that position, or the respective
      list lengths when no such pair exists.

    The prefix scan is bounded by both lists. Previously it was bounded only by
    ``modified_word_sets`` and raised ``IndexError`` when the rewritten text had
    more sentences than the original and all original sentences matched.
    """
    def matches(original, modified):
        return len(original.intersection(modified)) > len(original) * 0.8

    n_orig = len(word_sets)
    n_mod = len(modified_word_sets)

    # Length of the common matching prefix.
    left = 0
    while (left < n_mod and left < n_orig
           and matches(word_sets[left], modified_word_sets[left])):
        left += 1
    if left == n_mod:
        return -1, -1, -1, -1

    # First matching pair strictly after the divergence point; original
    # sentences are tried in order, each against the rewritten ones in order.
    right_orig = None
    right_mod = None
    for pos1 in range(left + 1, n_orig):
        for pos2 in range(left + 1, n_mod):
            if matches(word_sets[pos1], modified_word_sets[pos2]):
                right_orig = pos1
                right_mod = pos2
                break
        if right_orig is not None:
            break
    if right_orig is None:
        # No re-synchronisation point: the differing region runs to the end.
        right_orig = n_orig
        right_mod = n_mod
    return left, right_orig, left, right_mod
# Splice step: for each candidate in the first half of sen_list, keep only the
# region where it differs from the unedited abstract and embed that region
# back into the original sentences.
replace_sen_list = []
boundary = []
assert len(sen_list) % 2 == 0
for j in range(len(sen_list) // 2):
doc = nlp(sen_list[j])
sens = [sen.text for sen in doc.sents]
modified_word_sets = [set(sen.split(' ')) for sen in sens]
l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
boundary.append((l1, r1, l2, r2))
# -1 means sen_align found no differing region; keep the candidate unchanged.
if l1 == -1:
replace_sen_list.append(sen_list[j])
continue
check_text = ' '.join(sens[l2: r2])
replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
# Layout of sen_list from here on:
#   [0, old_L)   spliced first-half candidates followed by the untouched second half
#   [old_L]      the unedited abstract (`output`)
#   (old_L, ...] the Assist texts
sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
old_L = len(sen_list)
sen_list.append(output)
sen_list += Assist
tokens = tokenizer( sen_list,
truncation = True,
padding = True,
max_length = 1024,
return_tensors="pt")
target_ids = tokens['input_ids'].to(device)
attention_mask = tokens['attention_mask'].to(device)
L = len(sen_list)
ret_log_L = []
# Score in batches of 5 texts.
# NOTE(review): the forward passes are not wrapped in torch.no_grad(); only the
# detached losses are used, so disabling autograd here looks safe — confirm.
for l in range(0, L, 5):
R = min(L, l + 5)
target = target_ids[l:R, :]
attention = attention_mask[l:R, :]
outputs = model(input_ids = target,
attention_mask = attention,
labels = target)
logits = outputs.logits
# Position t predicts token t+1, hence the one-step shift of logits vs labels.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = target[..., 1:].contiguous()
Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
Loss = Loss.view(-1, shift_logits.shape[1])
attention = attention[..., 1:].contiguous()
# Ratio of the two means == mean loss over positions with attention 1.
log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
ret_log_L.append(log_Loss.detach())
log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
# real_log_Loss keeps every score; log_Loss is narrowed to the candidates only.
real_log_Loss = log_Loss.copy()
log_Loss = log_Loss[:old_L]
# p = index of the candidate with the lowest loss.
p = np.argmin(log_Loss)
content = []
# NOTE(review): this inner loop reuses the name `i` of the enclosing enumerate;
# `i` is not read again before the next outer iteration rebinds it.
for i in range(len(real_log_Loss)):
content.append([sen_list[i], str(real_log_Loss[i])])
scored[k] = {'path':path_text, 'prompt': prompt, 'in':input, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}
# NOTE(review): p_p is assigned below but not read again in the visible code.
p_p = p
# print('Old_L:', old_L)
# Index p+1+old_L is the Assist entry at position p.
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
p_p = p+1+old_L
# Fall back to that Assist text only when the best candidate scores worse than
# both the unedited abstract (index old_L) and the Assist text itself.
if real_log_Loss[p] > real_log_Loss[old_L]:
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
p = p+1+old_L
ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p]}
# Final outputs: the chosen text per draft, and the full score table.
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
json.dump(ret, fl, indent=4)
with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
json.dump(scored, fl, indent=4)
# Any --mode other than 'sentence' or 'finetune' is rejected.
else:
raise Exception('Wrong mode !!')