#%%
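# Gradio demo for Scorpius: pick a poisoning target, generate a malicious
# knowledge-graph link, and produce a supporting malicious abstract with ChatGPT + BioBART.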
import gradio as gr
import time
import sys
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import json
import networkx as nx
import spacy
# os.system("python -m spacy download en-core-web-sm")
import pickle as pkl
from tqdm import tqdm
#%%
# please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
from torch.nn.modules.loss import CrossEntropyLoss
from transformers import AutoTokenizer
from transformers import BioGptForCausalLM, BartForConditionalGeneration
from server import server_utils
import Parameters
from Openai.chat import generate_abstract
from DiseaseSpecific import utils, attack
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax
specific_model = None

def capitalize_the_first_letter(s):
    return s[0].upper() + s[1:]

parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--init-mode', type=str, default='single', help='How to select target nodes')  # 'single' for case study
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
args.device = device
args.device1 = device
if torch.cuda.device_count() >= 2:
    args.device = "cuda:0"
    args.device1 = "cuda:1"

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False
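# Locate the trained link-prediction model and load the processed GNBR data:
# triplets, entity/relation dictionaries, evaluation filters, raw sentences and drug terms.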
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = 'DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('DiseaseSpecific/processed_data', args.data)

data = utils.load_data(os.path.join(data_path, 'all.txt'))
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)

with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(os.path.join(data_path, 'edge_nghbrs.pickle'), 'rb') as fl:
    edge_nghbrs = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)
with open(Parameters.GNBRfile + 'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(os.path.join(data_path, 'entities_reverse_dict.json'), 'r') as fl:
    id_to_entity = json.load(fl)
id_to_meshid = id_to_entity.copy()
with open(Parameters.GNBRfile + 'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile + 'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
with open(Parameters.UMLSfile + 'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
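# Pre-computed gallery results: target triplets, malicious links and generated abstracts
# for the disease-specific and disease-agnostic settings.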
gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
gallery_agnostic_target_path = 'DiseaseAgnostic/processed_data/target_0.7random.pkl'
gallery_agnostic_link_path = 'DiseaseAgnostic/processed_data/attack_edge_distmult_0.7random.pkl'
gallery_agnostic_text_path = 'DiseaseAgnostic/generate_abstract/random0.7_bioBART_finetune.json'
gallery_specific_chat_path = 'DiseaseSpecific/generate_abstract/random_0.5_chat.json'
gallery_agnostic_chat_path = 'DiseaseAgnostic/generate_abstract/random0.7_chat.json'

gallery_specific_target = utils.load_data(gallery_specific_target_path, drop=False)
gallery_specific_link = utils.load_data(gallery_specific_link_path, drop=False)
with open(gallery_specific_text_path, 'r') as fl:
    gallery_specific_text = json.load(fl)
with open(gallery_agnostic_target_path, 'rb') as fl:
    gallery_agnostic_target = pkl.load(fl)
with open(gallery_agnostic_link_path, 'rb') as fl:
    gallery_agnostic_link = pkl.load(fl)
with open(gallery_agnostic_text_path, 'r') as fl:
    gallery_agnostic_text = json.load(fl)
with open(gallery_specific_chat_path, 'r') as fl:
    gallery_specific_chat = json.load(fl)
with open(gallery_agnostic_chat_path, 'r') as fl:
    gallery_agnostic_chat = json.load(fl)
gallery_specific_list = []
gallery_specific_target_dict = {}
for i, (s, r, o) in enumerate(gallery_specific_target):
    s = id_to_meshid[str(s)]
    o = id_to_meshid[str(o)]
    k = f'{gallery_specific_link[i][0]}_{gallery_specific_link[i][1]}_{gallery_specific_link[i][2]}_{i}'
    if 'sorry' in gallery_specific_text[k]['out'] or 'Sorry' in gallery_specific_text[k]['out']:
        continue
    target_name = f'{capitalize_the_first_letter(entity_raw_name[s])} - {capitalize_the_first_letter(entity_raw_name[o])}'
    if target_name not in gallery_specific_target_dict:
        gallery_specific_target_dict[target_name] = i
        gallery_specific_list.append(target_name)
gallery_specific_list.sort()

gallery_agnostic_list = []
gallery_agnostic_target_dict = {}
for i, iid in enumerate(gallery_agnostic_target):
    target_name = capitalize_the_first_letter(entity_raw_name[id_to_meshid[str(iid)]])
    k = f'{gallery_agnostic_link[i][0]}_{gallery_agnostic_link[i][1]}_{gallery_agnostic_link[i][2]}_{i}'
    if 'sorry' in gallery_agnostic_text[k]['out'] or 'Sorry' in gallery_agnostic_text[k]['out']:
        continue
    if target_name not in gallery_agnostic_target_dict:
        gallery_agnostic_target_dict[target_name] = i
        gallery_agnostic_list.append(target_name)
gallery_agnostic_list.sort()

drug_dict = {}
disease_dict = {}
for k, v in entity_raw_name.items():
    # chemical_mesh:c050048
    tp = k.split('_')[0]
    v = capitalize_the_first_letter(v)
    if len(v) <= 2:
        continue
    if tp == 'chemical':
        drug_dict[v] = k
    elif tp == 'disease':
        disease_dict[v] = k
drug_list = list(drug_dict.keys())
disease_list = list(disease_dict.keys())
drug_list.sort()
disease_list.sort()
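# Turn the per-(entity, relation) filter lists into boolean tensors for fast filtered scoring.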
init_mask = np.asarray([0] * n_ent).astype('int64')
init_mask = (init_mask == 1)
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        t = torch.ByteTensor(tmp).to(args.device)
        filters[k][kk] = t
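# Load the language models (BioGPT for scoring, BioBART for mask infilling), the trained
# knowledge-graph model, and the edge-loss statistics used by the defender check.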
gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()

specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)

nlp = spacy.load("en_core_web_sm")
bart_model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
def tune_chatgpt(draft, attack_data, dpath):
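    """Polish a ChatGPT draft with BioBART.

    Builds a masked prompt that keeps the target entities and dependency-path words,
    inserts it (or the original template sentence) at every sentence position of the
    draft, and lets BioBART fill the masks to produce candidate abstracts. Returns the
    masked span, the prompt, the BioBART outputs, and the raw BART inputs.
    """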
    dpath_i = 0
    bart_model.to(args.device1)
    for i, v in enumerate(draft):
        input = v['in'].replace('\n', '')
        output = v['out'].replace('\n', '')
        s, r, o = attack_data[i]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1

        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        doc = nlp(output)
        words = input.split(' ')
        tokenized_sens = [sen for sen in doc.sents]
        sens = np.array([sen.text for sen in doc.sents])

        checkset = set([text_s, text_o])
        e_entity = set(['start_entity', 'end_entity'])
        for path in path_text.split(' '):
            a, b, c = path.split('|')
            if a not in e_entity:
                checkset.add(a)
            if c not in e_entity:
                checkset.add(c)

        vec = []
        l = 0
        while l < len(words):
            bo = False
            for j in range(len(words), l, -1):  # reversing is important !!!
                cc = ' '.join(words[l:j])
                if cc in checkset:
                    vec += [True] * (j - l)
                    l = j
                    bo = True
                    break
            if not bo:
                vec.append(False)
                l += 1
        vec, span = server_utils.find_mini_span(vec, words, checkset)
        # vec = np.vectorize(lambda x: x in checkset)(words)
        vec[-1] = True
        prompt = []
        mask_num = 0
        for j, bo in enumerate(vec):
            if not bo:
                mask_num += 1
            else:
                if mask_num > 0:
                    # mask_num = mask_num // 3  # span length ~ Poisson distribution (lambda = 3)
                    mask_num = max(mask_num, 1)
                    mask_num = min(8, mask_num)
                    prompt += ['<mask>'] * mask_num
                prompt.append(words[j])
                mask_num = 0
        prompt = ' '.join(prompt)

        Text = []
        Assist = []
        for j in range(len(sens)):
            Bart_input = list(sens[:j]) + [prompt] + list(sens[j + 1:])
            assist = list(sens[:j]) + [input] + list(sens[j + 1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))
        for j in range(len(sens)):
            Bart_input = server_utils.mask_func(tokenized_sens[:j]) + [input] + server_utils.mask_func(tokenized_sens[j + 1:])
            assist = list(sens[:j]) + [input] + list(sens[j + 1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))

        batch_size = 8
        Outs = []
        for l in tqdm(range(0, len(Text), batch_size)):
            R = min(len(Text), l + batch_size)
            A = bart_tokenizer(Text[l:R],
                               truncation=True,
                               padding=True,
                               max_length=1024,
                               return_tensors="pt")
            input_ids = A['input_ids'].to(args.device1)
            attention_mask = A['attention_mask'].to(args.device1)
            aaid = bart_model.generate(input_ids, attention_mask=attention_mask, num_beams=5, max_length=1024)
            outs = bart_tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            Outs += outs
    bart_model.to('cpu')
    return span, prompt, Outs, Text, Assist
def score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, v):
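    """Pick the most fluent candidate abstract.

    Aligns each BioBART rewrite back onto the ChatGPT sentences, then scores every
    candidate (plus the original draft) with BioGPT's average token cross-entropy
    and returns the lowest-loss text.
    """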
    criterion = CrossEntropyLoss(reduction="none")
    text_s = entity_raw_name[id_to_meshid[str(s)]]
    text_o = entity_raw_name[id_to_meshid[str(o)]]
    sen_list = [server_utils.process(text) for text in sen_list]

    path_text = dpath[0].replace('\n', '')
    checkset = set([text_s, text_o])
    e_entity = set(['start_entity', 'end_entity'])
    for path in path_text.split(' '):
        a, b, c = path.split('|')
        if a not in e_entity:
            checkset.add(a)
        if c not in e_entity:
            checkset.add(c)

    input = v['in'].replace('\n', '')
    output = v['out'].replace('\n', '')
    doc = nlp(output)
    gpt_sens = [sen.text for sen in doc.sents]
    assert len(gpt_sens) == len(sen_list) // 2

    word_sets = []
    for sen in gpt_sens:
        word_sets.append(set(sen.split(' ')))

    def sen_align(word_sets, modified_word_sets):
        l = 0
        while l < len(modified_word_sets):
            if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
                l += 1
            else:
                break
        if l == len(modified_word_sets):
            return -1, -1, -1, -1
        r = l + 1
        r1 = None
        r2 = None
        for pos1 in range(r, len(word_sets)):
            for pos2 in range(r, len(modified_word_sets)):
                if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
                    r1 = pos1
                    r2 = pos2
                    break
            if r1 is not None:
                break
        if r1 is None:
            r1 = len(word_sets)
            r2 = len(modified_word_sets)
        return l, r1, l, r2

    replace_sen_list = []
    boundary = []
    assert len(sen_list) % 2 == 0
    for j in range(len(sen_list) // 2):
        doc = nlp(sen_list[j])
        sens = [sen.text for sen in doc.sents]
        modified_word_sets = [set(sen.split(' ')) for sen in sens]
        l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
        boundary.append((l1, r1, l2, r2))
        if l1 == -1:
            replace_sen_list.append(sen_list[j])
            continue
        check_text = ' '.join(sens[l2: r2])
        replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
    sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]

    gpt_model.to(args.device1)
    sen_list.append(output)
    tokens = gpt_tokenizer(sen_list,
                           truncation=True,
                           padding=True,
                           max_length=1024,
                           return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)

    L = len(sen_list)
    ret_log_L = []
    for l in tqdm(range(0, L, 5)):
        R = min(L, l + 5)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids=target,
                            attention_mask=attention,
                            labels=target)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
        ret_log_L.append(log_Loss.detach())
    log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
    gpt_model.to('cpu')
    p = np.argmin(log_Loss)
    return sen_list[p]
def generate_template_for_triplet(attack_data):
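    """Build a template sentence for the attack triplet.

    Retrieves GNBR sentences whose dependency path matches the relation, substitutes
    the subject/object names, scores the candidates with BioGPT, and samples one
    sentence from the lowest-loss decile. Returns the template sentence together
    with its original text, dependency path, and parse-ready form.
    """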
    criterion = CrossEntropyLoss(reduction="none")
    gpt_model.to(args.device1)
    print('Generating template ...')

    GPT_batch_size = 8
    single_sentence = []
    test_text = []
    test_dp = []
    test_parse = []

    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    candidate_sen = []
    Dp_path = []
    L = len(dependency_sen_dict.keys())
    bound = 500 // L
    if bound == 0:
        bound = 1
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > bound:
            index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
            sen_list = [sen_list[aa] for aa in index]
        ssen_list = []
        for aa in range(len(sen_list)):
            paper_id, sen_id = sen_list[aa]
            if raw_text_sen[paper_id][sen_id]['start_formatted'] == raw_text_sen[paper_id][sen_id]['end_formatted']:
                continue
            ssen_list.append(sen_list[aa])
        sen_list = ssen_list
        candidate_sen += sen_list
        Dp_path += [dp_path] * len(sen_list)

    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []
    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        text = text.replace('-LRB-', '(')
        text = text.replace('-RRB-', ')')
        text = text.replace('-LSB-', '[')
        text = text.replace('-RSB-', ']')
        text = text.replace('-LCB-', '{')
        text = text.replace('-RCB-', '}')
        parse_text = text
        parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
        parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s)
        text = text.replace(oo, text_o)
        text = text.replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)

    tokens = gpt_tokenizer(candidate_text_sen,
                           truncation=True,
                           padding=True,
                           max_length=300,
                           return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)
    L = len(candidate_text_sen)
    assert L > 0
    ret_log_L = []
    for l in tqdm(range(0, L, GPT_batch_size)):
        R = min(L, l + GPT_batch_size)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids=target,
                            attention_mask=attention,
                            labels=target)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        log_Loss = (torch.mean(Loss * attention.float(), dim=1) / torch.mean(attention.float(), dim=1))
        ret_log_L.append(log_Loss.detach())
    ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())

    sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
    sen_score.sort(key=lambda x: x[1])
    Len = len(sen_score)
    p = 0
    if Len > 10:
        p = np.random.choice(np.array(range(Len // 10)), 1)[0]
    test_text.append(sen_score[p][2])
    test_dp.append(sen_score[p][3])
    test_parse.append(sen_score[p][4])
    single_sentence.append(sen_score[p][0])
    gpt_model.to('cpu')
    return single_sentence, test_text, test_dp, test_parse
meshids = list(id_to_meshid.values())
cal = {
    'chemical': 0,
    'disease': 0,
    'gene': 0
}
for meshid in meshids:
    cal[meshid.split('_')[0]] += 1
def check_reasonable(s, r, o):
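    """Check whether a candidate edge would survive the defender.

    Normalizes the link-prediction loss of (s, r, o), converts it into a probability
    with the precomputed decision bound, and compares it against 1 - args.reasonable_rate.
    Returns (passes_check, probability).
    """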
    train_trip = np.asarray([[s, r, o]])
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
    # edge_losse_log_prob = torch.log(F.softmax(-edge_loss, dim = -1))
    edge_loss = edge_loss.item()
    edge_loss = (edge_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    bound = 1 - args.reasonable_rate
    return (edge_losses_prob > bound), edge_losses_prob
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for k, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[k]):
        edgeid_to_edgetype[str(iid)] = k
        edgeid_to_reversemask[str(iid)] = mask
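# Build a directed graph over all triplets (flipping edges marked as reverse) and compute
# PageRank; the top-ranked entities later serve as candidate sources for the agnostic attack.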
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in data:
    assert id_to_meshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))
print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)
drug_meshid = []
drug_list = []
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.append(meshid)
        drug_list.append(capitalize_the_first_letter(nm))
drug_list = list(set(drug_list))
drug_list.sort()
drug_meshid = set(drug_meshid)

pr = list(pagerank_value_1.items())
pr.sort(key=lambda x: x[1])
sorted_rank = {'chemical': [],
               'gene': [],
               'disease': [],
               'merged': []}
for iid, score in pr:
    tp = id_to_meshid[str(iid)].split('_')[0]
    if tp == 'chemical':
        if id_to_meshid[str(iid)] in drug_meshid:
            sorted_rank[tp].append((iid, score))
    else:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4:]
def generate_specific_attack_edge(start_entity, end_entity):
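    """Generate the malicious link for a specific drug-disease target.

    On CPU the target link itself is returned directly (the full search is too slow);
    otherwise the DiseaseSpecific addition attack searches the neighborhood of the
    target triplet and returns the highest-scoring malicious triplet (s, r, o).
    """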
    if device == torch.device('cpu'):
        print('On CPU we simply set the malicious link equal to the target link, since generating a malicious link is too slow on CPU.')
        return entity_to_id[drug_dict[start_entity]], '10', entity_to_id[disease_dict[end_entity]]
    global specific_model
    specific_model.to(device)

    start_meshid = drug_dict[start_entity]
    end_meshid = disease_dict[end_entity]
    start_entity = entity_to_id[start_meshid]
    end_entity = entity_to_id[end_meshid]
    target_data = np.array([[start_entity, '10', end_entity]])
    neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
    ret = f'Generating malicious link for {start_meshid}_treatment_{end_meshid}', 'Generating malicious text ...'

    param_optimizer = list(specific_model.named_parameters())
    param_influence = []
    for n, p in param_optimizer:
        param_influence.append(p)
    len_list = []
    for v in neighbors.values():
        len_list.append(len(v))
    mean_len = np.mean(len_list)

    attack_trip, score_record = attack.addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, specific_model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record=args.load_existed, divide_bound=divide_bound, data_mean=data_mean, data_std=data_std, cache_intermidiate=False)
    s, r, o = attack_trip[0]
    specific_model.to('cpu')
    return s, r, o
def generate_agnostic_attack_edge(targets):
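    """Generate a malicious link that promotes a drug without a fixed disease target.

    Scans PageRank-ranked entities not yet connected to the target, picks the relation
    with the lowest link-prediction loss, keeps candidates that pass check_reasonable,
    and ranks them by PageRank weight combined with a softmax over the losses.
    """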
    specific_model.to(device)
    attack_edge_list = []
    for target in targets:
        candidate_list = []
        score_list = []
        loss_list = []
        main_dict = {}
        for iid, score in sorted_rank['merged']:
            a = G.number_of_edges(iid, target) + 1
            if a != 1:
                continue
            b = G.out_degree(iid) + 1
            tp = id_to_meshid[str(iid)].split('_')[0]
            edge_losses = []
            r_list = []
            for r in range(len(edgeid_to_edgetype)):
                r_tp = edgeid_to_edgetype[str(r)]
                if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                    train_trip = np.array([[iid, r, target]])
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                elif (edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                    train_trip = np.array([[iid, r, target]])  # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
            if len(edge_losses) == 0:
                continue
            min_index = torch.argmin(torch.cat(edge_losses, dim=0))
            r = r_list[min_index]
            r_tp = edgeid_to_edgetype[str(r)]
            old_len = len(candidate_list)

            if (edgeid_to_reversemask[str(r)] == 0):
                bo, prob = check_reasonable(iid, r, target)
                if bo:
                    candidate_list.append((iid, r, target))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
            if (edgeid_to_reversemask[str(r)] == 1):
                bo, prob = check_reasonable(target, r, iid)
                if bo:
                    candidate_list.append((target, r, iid))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())

        if len(candidate_list) == 0:
            if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                attack_edge_list.append((-1, -1, -1))
            else:
                attack_edge_list.append([])
            continue
        norm_score = np.array(score_list) / np.sum(score_list)
        norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))
        total_score = norm_score * norm_loss
        total_score_index = list(zip(range(len(total_score)), total_score))
        total_score_index.sort(key=lambda x: x[1], reverse=True)
        total_index = np.argsort(total_score)[::-1]
        assert total_index[0] == total_score_index[0][0]
        # find the rank of the main index
        max_index = np.argmax(total_score)
        assert max_index == total_score_index[0][0]
        tmp_add = []
        add_num = 1
        if args.added_edge_num == '' or int(args.added_edge_num) == 1:
            attack_edge_list.append(candidate_list[max_index])
        else:
            add_num = int(args.added_edge_num)
            for i in range(add_num):
                tmp_add.append(candidate_list[total_score_index[i][0]])
            attack_edge_list.append(tmp_add)
    specific_model.to('cpu')
    return attack_edge_list[0]
def specific_func(start_entity, end_entity):
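    """Demo entry point for the target-specific attack: generate the malicious link,
    find or build a template sentence, draft an abstract with ChatGPT, polish it with
    BioBART, and return the link plus the selected abstract."""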
    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])

    path_list = []
    with open(f'DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
        for line in fl.readlines():
            line = line.replace('\n', '')
            path_list.append(line)
    with open(f'DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
        sentence_dict = json.load(fl)
    dpath = []
    for k, v in sentence_dict.items():
        if f'{s}_{r}_{o}' in k:
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    if len(dpath) == 0:
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    elif not (s_name in single_sentence[0] and o_name in single_sentence[0]):
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    if 'sorry' in draft or 'Sorry' in draft:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    print('Using BioBART for tuning...')
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([{'in': single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, {'in': single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
    # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
def agnostic_func(agnostic_entity):
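    """Demo entry point for the target-agnostic attack: pick the malicious link for the
    promoted drug, then generate and polish the supporting abstract as in specific_func."""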
    args.reasonable_rate = 0.7
    target_id = entity_to_id[drug_dict[agnostic_entity]]
    s = generate_agnostic_attack_edge([int(target_id)])
    if len(s) == 0:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    if int(s[0]) == -1:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s, r, o = str(s[0]), str(s[1]), str(s[2])
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]

    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    if 'sorry' in draft or 'Sorry' in draft:
        return 'All candidate links are filtered out by the defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    print('Using BioBART for tuning...')
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([{'in': single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, {'in': single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
def gallery_specific_func(specific_target):
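    """Look up a pre-computed target-specific result; fall back to the raw ChatGPT
    abstract if the BioBART-tuned text still contains the template sentence verbatim."""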
    index = gallery_specific_target_dict[specific_target]
    s, r, o = gallery_specific_link[index]
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    k = f'{s}_{r}_{o}_{index}'
    inn = gallery_specific_text[k]['in']
    text = gallery_specific_text[k]['out']
    if inn in text:
        text = gallery_specific_chat[k]['out']
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
def gallery_agnostic_func(agnostic_target):
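    """Look up a pre-computed target-agnostic result; fall back to the raw ChatGPT
    abstract if the BioBART-tuned text still contains the template sentence verbatim."""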
    index = gallery_agnostic_target_dict[agnostic_target]
    s, r, o = gallery_agnostic_link[index]
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    k = f'{s}_{r}_{o}_{index}'
    inn = gallery_agnostic_text[k]['in']
    text = gallery_agnostic_text[k]['out']
    if inn in text:
        text = gallery_agnostic_chat[k]['out']
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
#%%
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("Poison scientific knowledge with Scorpius")
        # with gr.Column():
        with gr.Row():
            # Center
            with gr.Column():
                gr.Markdown("Select your poisoning target")
                with gr.Tab('Gallery'):
                    with gr.Tab('Target specific'):
                        specific_target = gr.Dropdown(gallery_specific_list, label="Promoting drug and target disease")
                        gallery_specific_generation_button = gr.Button('Poison!')
                    with gr.Tab('Target agnostic'):
                        agnostic_target = gr.Dropdown(gallery_agnostic_list, label="Promoting drug")
                        gallery_agnostic_generation_button = gr.Button('Poison!')
                with gr.Tab('Poison'):
                    with gr.Tab('Target specific'):
                        with gr.Column():
                            with gr.Row():
                                start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                                end_entity = gr.Dropdown(disease_list, label="Target disease")
                            if device == torch.device('cpu'):
                                gr.Markdown("Since the demo is currently running on CPU, we treat the malicious link as identical to the poisoning target to speed up generation.")
                            specific_generation_button = gr.Button('Poison!')
                    with gr.Tab('Target agnostic'):
                        agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                        agnostic_generation_button = gr.Button('Poison!')
            with gr.Column():
                gr.Markdown("Generation")
                malicious_link = gr.Textbox(lines=1, label="Malicious link")
                # gr.Markdown("Malicious text")
                malicious_text = gr.Textbox(label="Malicious text", lines=5)

    specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicious_link, malicious_text])
    agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicious_link, malicious_text])
    gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicious_link, malicious_text])
    gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicious_link, malicious_text])

# demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)
demo.launch()