# Hugging Face Space (status: Running)
#%% | |
import gradio as gr | |
import time | |
import sys | |
import os | |
import torch | |
import torch.backends.cudnn as cudnn | |
import numpy as np | |
import json | |
import networkx as nx | |
import spacy | |
import pickle as pkl | |
#%% | |
from torch.nn.modules.loss import CrossEntropyLoss | |
from transformers import AutoTokenizer | |
from transformers import BioGptForCausalLM, BartForConditionalGeneration | |
import server_utils | |
sys.path.append("..") | |
import Parameters | |
from Openai.chat import generate_abstract | |
sys.path.append("../DiseaseSpecific") | |
import utils, attack | |
from attack import calculate_edge_bound, get_model_loss_without_softmax | |
# Placeholder for the knowledge-graph link-prediction model; it is assigned
# by utils.load_model(...) further down in this script.
specific_model = None


def capitalize_the_first_letter(s):
    """Return *s* with only its first character upper-cased.

    Unlike ``str.capitalize`` the remainder of the string is left untouched,
    so mixed-case names (e.g. ``"mRNA"``) keep their inner casing.  An empty
    string is returned unchanged instead of raising ``IndexError``.
    """
    return s[:1].upper() + s[1:]
# ---- Command-line configuration ---------------------------------------------
# '--init-mode single' selects a single target node (used for the case study).
parser = utils.add_attack_parameters(utils.get_argument_parser())
parser.add_argument('--init-mode', type = str, default='single', help = 'How to select target nodes')
args = utils.set_hyperparams(parser.parse_args())

# ---- Device placement --------------------------------------------------------
# args.device hosts the filter masks and the knowledge-graph model;
# args.device1 hosts BioBART / BioGPT while they run.  With two or more GPUs
# they are split across cards, otherwise both share `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() >= 2:
    args.device, args.device1 = "cuda:0", "cuda:1"
else:
    args.device = args.device1 = device

# ---- Reproducibility / printing ---------------------------------------------
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False
# ---- Knowledge-graph data and lookup tables ---------------------------------
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = '../DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('../DiseaseSpecific/processed_data', args.data)


def _load_pickle(path):
    """Unpickle and return the object stored at *path*.

    NOTE: unpickling can execute arbitrary code; only use this on the
    project's own preprocessed data files, never on user-supplied paths.
    """
    with open(path, 'rb') as fl:
        return pkl.load(fl)


def _load_json(path):
    """Parse the JSON file at *path*, always decoding it as UTF-8.

    An explicit encoding keeps decoding independent of the host locale.
    """
    with open(path, 'r', encoding='utf-8') as fl:
        return json.load(fl)


data = utils.load_data(os.path.join(data_path, 'all.txt'))
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
filters = _load_pickle(os.path.join(data_path, 'filter.pickle'))
entityid_to_nodetype = _load_json(os.path.join(data_path, 'entityid_to_nodetype.json'))
edge_nghbrs = _load_pickle(os.path.join(data_path, 'edge_nghbrs.pickle'))
disease_meshid = _load_pickle(os.path.join(data_path, 'disease_meshid.pickle'))
# entity meshid -> entity id
entity_to_id = _load_json(os.path.join(data_path, 'entities_dict.json'))
# entity meshid -> human-readable name
entity_raw_name = _load_pickle(Parameters.GNBRfile+'entity_raw_name')
# entity id (as str) -> entity meshid
id_to_entity = _load_json(os.path.join(data_path, 'entities_reverse_dict.json'))
id_to_meshid = id_to_entity.copy()
retieve_sentence_through_edgetype = _load_pickle(Parameters.GNBRfile+'retieve_sentence_through_edgetype')
raw_text_sen = _load_pickle(Parameters.GNBRfile+'raw_text_of_each_sentence')
drug_term = _load_pickle(Parameters.UMLSfile+'drug_term')
# ---- Display-name -> meshid tables ------------------------------------------
# Keys of entity_raw_name look like 'chemical_mesh:c050048'; the text before
# the first '_' is the node type.  Display names of <= 2 characters are
# skipped, and when several meshids share a display name the last one wins.
drug_dict = {}
disease_dict = {}
_name_tables = {'chemical': drug_dict, 'disease': disease_dict}
for meshid, raw_name in entity_raw_name.items():
    node_type = meshid.split('_')[0]
    display_name = capitalize_the_first_letter(raw_name)
    if len(display_name) <= 2:
        continue
    name_table = _name_tables.get(node_type)
    if name_table is not None:
        name_table[display_name] = meshid
disease_list = sorted(disease_dict)
drug_list = sorted(drug_dict)
# ---- Candidate filters -> per-entity masks -----------------------------------
# Every id list filters[k][kk] is replaced in place by a length-n_ent
# torch.ByteTensor on args.device that is 1 at the listed ids, 0 elsewhere.
init_mask = np.zeros(n_ent, dtype=bool)
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        index = np.asarray(vv)
        if index.size == 0:
            # np.asarray([]) is float64 and numpy refuses float index arrays
            # even when empty; an empty id list must simply select nothing.
            index = index.astype(np.int64)
        tmp[index] = True
        t = torch.ByteTensor(tmp).to(args.device)
        filters[k][kk] = t
# ---- Language models, link predictor and NLP pipeline -----------------------
# BioGPT and BioBART are moved to args.device1 only while a request uses them.
_BIOGPT_CHECKPOINT = 'microsoft/biogpt'
_BIOBART_CHECKPOINT = 'GanjinZero/biobart-large'

gpt_tokenizer = AutoTokenizer.from_pretrained(_BIOGPT_CHECKPOINT)
# Batched scoring pads with the EOS token (no separate pad token is set up).
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained(_BIOGPT_CHECKPOINT, pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()

specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
# Statistics used by check_reasonable() to turn raw losses into probabilities.
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)

nlp = spacy.load("en_core_web_sm")

bart_model = BartForConditionalGeneration.from_pretrained(_BIOBART_CHECKPOINT)
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained(_BIOBART_CHECKPOINT)
def tune_chatgpt(draft, attack_data, dpath):
    """Build BioBART infilling inputs from a ChatGPT draft and run BioBART on them.

    For the i-th item ``v`` of *draft* (a dict with 'in' = template sentence
    and 'out' = ChatGPT abstract), ``attack_data[i]`` gives the (s, r, o)
    triple and ``dpath[i]`` the dependency path 'a|b|c a|b|c ...'.  Entity
    mentions (subject, object and non-placeholder path nodes) are kept in the
    template and every run of other words becomes 1-8 ``<mask>`` tokens.
    Two families of BioBART inputs are built per abstract sentence j:

    1. the abstract with sentence j replaced by the masked template;
    2. the template inserted at j, other sentences passed through
       ``server_utils.mask_func``.

    Returns:
        ``(span, prompt, Outs, Text, Assist)`` for the last draft item: the
        span from ``server_utils.find_mini_span``, the masked template, the
        BioBART outputs, the BioBART inputs and the matching unmasked texts
        (template spliced into the abstract).

    Raises:
        ValueError: if *draft* is empty (there would be nothing to return).

    BioBART is moved to ``args.device1`` for the call and always moved back
    to the CPU, even if generation fails.
    """
    if len(draft) == 0:
        raise ValueError('tune_chatgpt() needs at least one draft')
    bart_model.to(args.device1)
    try:
        for i, v in enumerate(draft):
            template = v['in'].replace('\n', '')
            output = v['out'].replace('\n', '')
            s, _, o = attack_data[i]
            path_text = dpath[i].replace('\n', '')
            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]
            doc = nlp(output)
            words = template.split(' ')
            tokenized_sens = list(doc.sents)
            sens = np.array([sen.text for sen in doc.sents])

            # Entity mentions that must survive masking.
            checkset = {text_s, text_o}
            e_entity = {'start_entity', 'end_entity'}
            for path in path_text.split(' '):
                a, _, c = path.split('|')
                if a not in e_entity:
                    checkset.add(a)
                if c not in e_entity:
                    checkset.add(c)

            # Greedy longest match: try the longest word n-gram first so that
            # multi-word entity names are matched as a whole.
            vec = []
            l = 0
            while l < len(words):
                for j in range(len(words), l, -1):
                    if ' '.join(words[l:j]) in checkset:
                        vec += [True] * (j - l)
                        l = j
                        break
                else:
                    vec.append(False)
                    l += 1
            vec, span = server_utils.find_mini_span(vec, words, checkset)
            vec[-1] = True  # always keep the last word

            # Collapse every run of dropped words into 1-8 <mask> tokens.
            prompt = []
            mask_num = 0
            for j, keep in enumerate(vec):
                if not keep:
                    mask_num += 1
                    continue
                if mask_num > 0:
                    prompt += ['<mask>'] * min(8, mask_num)
                prompt.append(words[j])
                mask_num = 0
            prompt = ' '.join(prompt)

            Text = []
            Assist = []
            for j in range(len(sens)):
                before, after = list(sens[:j]), list(sens[j + 1:])
                Text.append(' '.join(before + [prompt] + after))
                Assist.append(' '.join(before + [template] + after))
            for j in range(len(sens)):
                masked = (server_utils.mask_func(tokenized_sens[:j]) + [template]
                          + server_utils.mask_func(tokenized_sens[j + 1:]))
                Text.append(' '.join(masked))
                Assist.append(' '.join(list(sens[:j]) + [template] + list(sens[j + 1:])))

            batch_size = 8
            Outs = []
            for start in range(0, len(Text), batch_size):
                batch = bart_tokenizer(Text[start:start + batch_size],
                                       truncation=True,
                                       padding=True,
                                       max_length=1024,
                                       return_tensors="pt")
                input_ids = batch['input_ids'].to(args.device1)
                attention_mask = batch['attention_mask'].to(args.device1)
                generated = bart_model.generate(input_ids, attention_mask=attention_mask, num_beams=5, max_length=1024)
                Outs += bart_tokenizer.batch_decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    finally:
        bart_model.to('cpu')
    return span, prompt, Outs, Text, Assist
def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
    """Return the candidate abstract that BioGPT finds most fluent.

    *sen_list* holds the BioBART outputs of ``tune_chatgpt`` (an even number:
    one sentence-replacement variant and one context-masked variant per draft
    sentence).  Each first-half variant is re-aligned with the ChatGPT draft
    ``v['out']`` so that only the rewritten sentence span is taken from
    BioBART.  The draft itself is added as a candidate, and the candidate with
    the lowest mean per-token BioGPT negative log-likelihood is returned.

    ``s``, ``r``, ``o``, ``span``, ``prompt``, ``BART_in``, ``Assist`` and
    ``dpath`` are accepted for interface compatibility but are not used.

    BioGPT is moved to ``args.device1`` for scoring and is always moved back
    to the CPU, even if scoring fails.
    """
    criterion = CrossEntropyLoss(reduction="none")
    sen_list = [server_utils.process(text) for text in sen_list]
    output = v['out'].replace('\n', '')
    gpt_sens = [sen.text for sen in nlp(output).sents]
    assert len(gpt_sens) == len(sen_list) // 2
    word_sets = [set(sen.split(' ')) for sen in gpt_sens]

    def sen_align(word_sets, modified_word_sets):
        """Find the sentence span where a rewrite departs from the draft.

        Two sentences match when more than 80% of the draft sentence's words
        also occur in the rewritten one.  Returns ``(l, r1, l, r2)``: draft
        sentences ``[l, r1)`` correspond to rewritten sentences ``[l, r2)``.
        Returns ``(-1, -1, -1, -1)`` when every rewritten sentence matches the
        draft sentence at the same position.
        """
        def matches(a, b):
            return len(a.intersection(b)) > len(a) * 0.8

        # Bounded by both lengths: a rewrite may contain more sentences than
        # the draft, and word_sets must not be indexed past its end.
        n_common = min(len(word_sets), len(modified_word_sets))
        l = 0
        while l < n_common and matches(word_sets[l], modified_word_sets[l]):
            l += 1
        if l == len(modified_word_sets):
            return -1, -1, -1, -1
        for pos1 in range(l + 1, len(word_sets)):
            for pos2 in range(l + 1, len(modified_word_sets)):
                if matches(word_sets[pos1], modified_word_sets[pos2]):
                    return l, pos1, l, pos2
        return l, len(word_sets), l, len(modified_word_sets)

    assert len(sen_list) % 2 == 0
    half = len(sen_list) // 2
    replace_sen_list = []
    for j in range(half):
        sens = [sen.text for sen in nlp(sen_list[j]).sents]
        modified_word_sets = [set(sen.split(' ')) for sen in sens]
        l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
        if l1 == -1:
            replace_sen_list.append(sen_list[j])
            continue
        check_text = ' '.join(sens[l2:r2])
        replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
    sen_list = replace_sen_list + sen_list[half:]

    gpt_model.to(args.device1)
    try:
        sen_list.append(output)
        tokens = gpt_tokenizer(sen_list,
                               truncation=True,
                               padding=True,
                               max_length=1024,
                               return_tensors="pt")
        target_ids = tokens['input_ids'].to(args.device1)
        attention_mask = tokens['attention_mask'].to(args.device1)
        batch_losses = []
        # Pure scoring: no autograd graph is needed.
        with torch.no_grad():
            for start in range(0, len(sen_list), 5):
                target = target_ids[start:start + 5, :]
                attention = attention_mask[start:start + 5, :]
                logits = gpt_model(input_ids=target, attention_mask=attention).logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                token_loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                token_loss = token_loss.view(-1, shift_logits.shape[1])
                shift_mask = attention[..., 1:].contiguous().float()
                # Mean NLL over real (non-padding) tokens of each sequence.
                batch_losses.append(torch.mean(token_loss * shift_mask, dim=1) / torch.mean(shift_mask, dim=1))
        mean_nll = torch.cat(batch_losses, -1).cpu().numpy()
    finally:
        gpt_model.to('cpu')
    # A candidate with no scorable token yields 0/0 = NaN; np.argmin would
    # pick it, so rank it last instead.
    mean_nll[np.isnan(mean_nll)] = np.inf
    return sen_list[int(np.argmin(mean_nll))]
def generate_template_for_triplet(attack_data):
    """Pick the most fluent corpus sentence that can express the triple attack_data[0].

    The manually curated sentences stored for relation ``r`` are sampled
    (about 500 in total, spread evenly over their dependency paths).  Sentences
    whose two entity placeholders coincide are dropped.  The placeholders are
    filled with the names of ``s`` and ``o``, and the sentence with the lowest
    mean per-token BioGPT loss wins.

    Returns:
        Four single-element lists: ``[filled sentence]``, ``[original raw
        sentence]``, ``[dependency path]``, ``[sentence with underscore-joined
        entity names]``.

    Raises:
        ValueError: if no usable template sentence exists for the relation.

    BioGPT is moved to ``args.device1`` for scoring and is always moved back
    to the CPU afterwards.
    """
    print('Generating template ...')
    criterion = CrossEntropyLoss(reduction="none")
    GPT_batch_size = 8
    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    if len(dependency_sen_dict) == 0:
        raise ValueError(f'No template sentences are stored for relation {r}')
    bound = max(500 // len(dependency_sen_dict), 1)

    candidate_sen = []
    Dp_path = []
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > bound:
            index = np.random.choice(len(sen_list), bound, replace=False)
            sen_list = [sen_list[aa] for aa in index]
        kept = []
        for item in sen_list:
            paper_id, sen_id = item
            sen = raw_text_sen[paper_id][sen_id]
            if sen['start_formatted'] != sen['end_formatted']:
                kept.append(item)
        candidate_sen += kept
        Dp_path += [dp_path] * len(kept)

    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    # Penn-Treebank style bracket escapes found in the raw corpus text.
    ptb_brackets = (('-LRB-', '('), ('-RRB-', ')'), ('-LSB-', '['),
                    ('-RSB-', ']'), ('-LCB-', '{'), ('-RCB-', '}'))
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []
    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        for escaped, bracket in ptb_brackets:
            text = text.replace(escaped, bracket)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        parse_text = text.replace(ss, text_s.replace(' ', '_')).replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s).replace(oo, text_o).replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)
    if not candidate_text_sen:
        raise ValueError(f'No usable template sentence for relation {r}')

    gpt_model.to(args.device1)
    try:
        tokens = gpt_tokenizer(candidate_text_sen,
                               truncation=True,
                               padding=True,
                               max_length=300,
                               return_tensors="pt")
        target_ids = tokens['input_ids'].to(args.device1)
        attention_mask = tokens['attention_mask'].to(args.device1)
        batch_losses = []
        # Pure scoring: no autograd graph is needed.
        with torch.no_grad():
            for start in range(0, len(candidate_text_sen), GPT_batch_size):
                target = target_ids[start:start + GPT_batch_size, :]
                attention = attention_mask[start:start + GPT_batch_size, :]
                logits = gpt_model(input_ids=target, attention_mask=attention).logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                token_loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                token_loss = token_loss.view(-1, shift_logits.shape[1])
                shift_mask = attention[..., 1:].contiguous().float()
                batch_losses.append(torch.mean(token_loss * shift_mask, dim=1) / torch.mean(shift_mask, dim=1))
        scores = torch.cat(batch_losses, -1).cpu().numpy()
    finally:
        gpt_model.to('cpu')
    # NaN (no scorable token) must never win; argmin keeps the first minimum.
    scores[np.isnan(scores)] = np.inf
    best = int(np.argmin(scores))
    return [candidate_text_sen[best]], [candidate_ori_sen[best]], [Dp_path[best]], [candidate_parse_sen[best]]
# Number of entities per node type; a meshid with any other type prefix
# raises KeyError.
meshids = list(id_to_meshid.values())
cal = dict.fromkeys(('chemical', 'disease', 'gene'), 0)
for node_type in (mid.split('_')[0] for mid in meshids):
    cal[node_type] += 1
def check_reasonable(s, r, o):
    """Judge whether the link predictor finds the triple (s, r, o) plausible.

    The raw model loss is standardised with ``data_mean`` / ``data_std``,
    shifted by ``divide_bound`` and passed through ``1 / (1 + exp(.))``,
    giving a score between 0 and 1 that falls as the loss grows.  The triple
    is accepted when that score exceeds ``1 - args.reasonable_rate``.

    Returns:
        ``(accepted, score)``.
    """
    triple = torch.from_numpy(np.asarray([[s, r, o]]).astype('int64')).to(device)
    raw_loss = get_model_loss_without_softmax(triple, specific_model, device).squeeze().item()
    standardised = (raw_loss - data_mean) / data_std
    plausibility = 1 / (1 + np.exp(standardised - divide_bound))
    return (plausibility > 1 - args.reasonable_rate), plausibility
# ---- Relation metadata -------------------------------------------------------
# Relation id (as str) -> edge type string '<head type>-<tail type>', and
# relation id -> reverse flag taken from Parameters.reverse_mask.
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for edge_type, relation_ids in Parameters.edge_type_to_id.items():
    for relation_id, flag in zip(relation_ids, Parameters.reverse_mask[edge_type]):
        relation_key = str(relation_id)
        edgeid_to_edgetype[relation_key] = edge_type
        edgeid_to_reversemask[relation_key] = flag

# ---- Directed knowledge graph -----------------------------------------------
# Reverse-flagged triples are drawn tail -> head, all others head -> tail.
reverse_tot = 0
G = nx.DiGraph()
for head, rel, tail in data:
    # Each head entity must have the edge type's head type.
    assert id_to_meshid[head].split('_')[0] == edgeid_to_edgetype[rel].split('-')[0]
    if edgeid_to_reversemask[rel] != 1:
        G.add_edge(int(head), int(tail))
        continue
    reverse_tot += 1
    G.add_edge(int(tail), int(head))
print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)

# Drugs offered in the UI: chemical entities whose lower-cased name appears
# in drug_term (loaded from the UMLS file).  Only names that also resolve
# through drug_dict are listed, because the callbacks look the selected name
# up there; names removed by drug_dict's <= 2-character filter would
# otherwise raise KeyError when chosen.
drug_meshid = set()
_drug_names = set()
for meshid, nm in entity_raw_name.items():
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical':
        drug_meshid.add(meshid)
        _drug_names.add(capitalize_the_first_letter(nm))
drug_list = sorted(name for name in _drug_names if name in drug_dict)

# Entities in ascending PageRank order, bucketed by node type.  The
# 'chemical' bucket only holds drugs; 'merged' holds every type and is then
# cut down to its highest-ranked quarter, the candidate pool for
# target-agnostic attacks.
pr = sorted(pagerank_value_1.items(), key = lambda x: x[1])
sorted_rank = { 'chemical' : [],
                'gene' : [],
                'disease': [],
                'merged' : []}
for iid, score in pr:
    tp = id_to_meshid[str(iid)].split('_')[0]
    if tp != 'chemical' or id_to_meshid[str(iid)] in drug_meshid:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
def generate_specific_attack_edge(start_entity, end_entity):
    """Find the malicious triple that best promotes a drug for a disease.

    Args:
        start_entity: drug display name (a key of ``drug_dict``).
        end_entity: disease display name (a key of ``disease_dict``).

    Returns:
        ``(s, r, o)``: the first triple proposed by ``attack.addition_attack``
        for the target triple (drug, relation '10', disease).

    The knowledge-graph model is moved to ``device`` for the search and is
    always moved back to the CPU, even if the attack raises.
    """
    start_meshid = drug_dict[start_entity]
    end_meshid = disease_dict[end_entity]
    # Relation '10' is presumably drug-treats-disease -- TODO confirm.
    target_data = np.array([[entity_to_id[start_meshid], '10', entity_to_id[end_meshid]]])
    specific_model.to(device)
    try:
        neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
        param_influence = [p for _, p in specific_model.named_parameters()]
        attack_trip, _ = attack.addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, specific_model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record = args.load_existed, divide_bound = divide_bound, data_mean = data_mean, data_std = data_std, cache_intermidiate = False)
    finally:
        specific_model.to('cpu')
    s, r, o = attack_trip[0]
    return s, r, o
def _agnostic_triple_loss(triple):
    """Return the detached link-predictor loss of one (s, r, o) triple as a 1-element tensor."""
    train_trip = torch.from_numpy(np.array([triple]).astype('int64')).to(device)
    return get_model_loss_without_softmax(train_trip, specific_model, device).squeeze().unsqueeze(0).detach()


def _agnostic_attack_for_target(target):
    """Return the best reasonable edge(s) to add for one target drug id (see generate_agnostic_attack_edge)."""
    single = args.added_edge_num == '' or int(args.added_edge_num) == 1
    candidate_list = []
    score_list = []
    loss_list = []
    for iid, score in sorted_rank['merged']:
        if G.number_of_edges(iid, target) != 0:
            continue
        out_degree = G.out_degree(iid) + 1
        tp = id_to_meshid[str(iid)].split('_')[0]
        triples = []
        losses = []
        for r in range(len(edgeid_to_edgetype)):
            r_tp = edgeid_to_edgetype[str(r)]
            head_type, tail_type = r_tp.split('-')[0], r_tp.split('-')[1]
            reverse = edgeid_to_reversemask[str(r)]
            # Score exactly the triple that would be inserted: for
            # reverse-flagged 'chemical-<tp>' relations the drug is the head.
            if reverse == 0 and head_type == tp and tail_type == 'chemical':
                triple = (iid, r, target)
            elif reverse == 1 and head_type == 'chemical' and tail_type == tp:
                triple = (target, r, iid)
            else:
                continue
            triples.append(triple)
            losses.append(_agnostic_triple_loss(triple))
        if not losses:
            continue
        best = int(torch.argmin(torch.cat(losses, dim = 0)))
        accepted, _ = check_reasonable(*triples[best])
        if accepted:
            candidate_list.append(triples[best])
            score_list.append(score / out_degree)
            loss_list.append(losses[best].item())

    if not candidate_list:
        return (-1, -1, -1) if single else []
    # Rank = normalised (PageRank / (out-degree + 1)) * softmax(-loss).
    norm_score = np.array(score_list) / np.sum(score_list)
    neg_loss_exp = np.exp(-np.array(loss_list))
    total_score = norm_score * (neg_loss_exp / np.sum(neg_loss_exp))
    if single:
        return candidate_list[int(np.argmax(total_score))]
    # Stable descending order (ties keep the earlier candidate first); never
    # ask for more edges than there are candidates.
    ranking = sorted(range(len(total_score)), key = lambda i: total_score[i], reverse = True)
    return [candidate_list[i] for i in ranking[:int(args.added_edge_num)]]


def generate_agnostic_attack_edge(targets):
    """Pick, per target drug, the edge to add that most raises its importance.

    Candidates are the entities in ``sorted_rank['merged']`` with no edge
    ``iid -> target`` in ``G``.  For each, the type-compatible relation with
    the lowest model loss is chosen, and the triple is kept only if
    ``check_reasonable`` accepts it.  Kept triples are ranked by normalised
    PageRank / (out-degree + 1) times softmax(-loss).

    Args:
        targets: entity ids of the drugs to promote.

    Returns:
        The result for ``targets[0]``: a triple ``(s, r, o)`` or
        ``(-1, -1, -1)`` when nothing is accepted; when
        ``args.added_edge_num`` > 1, a list of up to that many triples
        (empty when nothing is accepted).

    The knowledge-graph model is moved to ``device`` for the search and is
    always moved back to the CPU afterwards.
    """
    specific_model.to(device)
    try:
        attack_edge_list = [_agnostic_attack_for_target(target) for target in targets]
    finally:
        specific_model.to('cpu')
    return attack_edge_list[0]
def specific_func(start_entity, end_entity):
    """Gradio callback for the 'Target specific' tab.

    Generates a malicious link promoting the drug *start_entity* for the
    disease *end_entity*.  A template sentence for that link is taken from
    the cached sentence file when one mentions both entities, and generated
    otherwise.  The template is then expanded with ChatGPT, tuned with
    BioBART and the final variant selected by ``score_and_select``.

    Returns:
        ``('S - R - O', abstract)``, or two explanatory messages when no
        link survives the reasonableness filter.
    """
    no_link = ('All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated')
    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return no_link
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])

    with open(f'../DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r', encoding='utf-8') as fl:
        path_list = [line.replace('\n', '') for line in fl]
    with open(f'../DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r', encoding='utf-8') as fl:
        sentence_dict = json.load(fl)

    # Cache keys end with '_<path index>'.  Compare whole '_'-separated fields
    # just before that index, so that e.g. '1_10_5' does not match inside
    # '21_10_5_7'; an unrecognised key format just falls back to generation.
    triple_key = f'{s}_{r}_{o}'
    single_sentence = []
    dpath = []
    for k, v in sentence_dict.items():
        key_head = k.rsplit('_', 1)[0]
        if key_head == triple_key or key_head.endswith('_' + triple_key):
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    if not dpath or not (s_name in single_sentence[0] and o_name in single_sentence[0]):
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    pair = {'in': single_sentence[0], 'out': draft}
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([pair], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, pair)
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
def agnostic_func(agnostic_entity):
    """Gradio callback for the 'Target agnostic' tab.

    Finds the edge that most promotes the drug *agnostic_entity*, builds a
    template sentence for it, expands it with ChatGPT, tunes it with BioBART
    and selects the final text with ``score_and_select``.

    Returns:
        ``('S - R - O', abstract)``, or two explanatory messages when no
        link survives the reasonableness filter.
    """
    no_link = ('All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated')
    args.reasonable_rate = 0.7
    target_id = entity_to_id[drug_dict[agnostic_entity]]
    edge = generate_agnostic_attack_edge([int(target_id)])
    # With args.added_edge_num > 1 a ranked list of triples comes back;
    # use the best one.
    if isinstance(edge, list):
        if len(edge) == 0:
            return no_link
        edge = edge[0]
    if len(edge) == 0 or int(edge[0]) == -1:
        return no_link
    s, r, o = str(edge[0]), str(edge[1]), str(edge[2])
    s_name = entity_raw_name[id_to_entity[s]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[o]]
    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    pair = {'in': single_sentence[0], 'out': draft}
    span, prompt, sen_list, BART_in, Assist = tune_chatgpt([pair], attack_data, dpath)
    text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, pair)
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
#%%
# ---- Gradio user interface ---------------------------------------------------
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("Poison scientific knowledge with Scorpius")
        with gr.Row():
            # First column: choose what to poison.
            with gr.Column():
                gr.Markdown("Select your poison target")
                with gr.Tab('Target specific'):
                    with gr.Column():
                        with gr.Row():
                            start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                            end_entity = gr.Dropdown(disease_list, label="Target disease")
                        specific_generation_button = gr.Button('Poison!')
                with gr.Tab('Target agnostic'):
                    agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                    agnostic_generation_button = gr.Button('Poison!')
            # Second column: generated link and abstract.
            with gr.Column():
                gr.Markdown("Malicious link")
                malicious_link = gr.Textbox(lines=1, label="Malicious link")
                gr.Markdown("Malicious text")
                malicious_text = gr.Textbox(label="Malicious text", lines=5)
    specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicious_link, malicious_text])
    agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicious_link, malicious_text])
demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)