Spaces:
Running
Running
#%% | |
import logging | |
from symbol import parameters | |
from textwrap import indent | |
import os | |
import tempfile | |
import sys | |
from matplotlib import collections | |
import pandas as pd | |
import json | |
from glob import glob | |
from tqdm import tqdm | |
import numpy as np | |
from pprint import pprint | |
import torch | |
import pickle as pkl | |
from collections import Counter | |
# print(dir(collections)) | |
import networkx as nx | |
from collections import Counter | |
import utils | |
from torch.nn import functional as F | |
sys.path.append("..") | |
import Parameters | |
from DiseaseSpecific.attack import calculate_edge_bound, get_model_loss_without_softmax | |
#%% | |
def load_data(file_name):
    """Read a headerless tab-separated file and return its distinct rows.

    Args:
        file_name: Path of the TSV file (anything ``pandas.read_csv`` accepts).

    Returns:
        The de-duplicated rows as a 2-D array (``DataFrame.values``). Columns
        are parsed with ``dtype=str``, so numeric-looking cells stay strings.
    """
    frame = pd.read_csv(file_name, sep='\t', header=None, dtype=str)
    return frame.drop_duplicates().values
# Extend the shared project argument parser with the options specific to this
# evaluation script, then resolve hyper-parameters and seed the RNGs.
parser = utils.get_argument_parser()
parser.add_argument(
    '--reasonable-rate',
    type=float,
    default=0.7,
    help="The added edge's existance rank prob greater than this rate",
)
parser.add_argument(
    '--mode',
    type=str,
    default='',
    help=' "" or chat or bioBART',
)
parser.add_argument(
    '--init-mode',
    type=str,
    default='random',
    help='How to select target nodes',
)  # 'single' for case study
parser.add_argument(
    '--added-edge-num',
    type=str,
    default='',
    help='Added edge num',
)

args = parser.parse_args()
args = utils.set_hyperparams(args)
utils.seed_all(args.seed)
# Input locations. The target/attack pickle names are derived from the CLI
# options so that each configuration reads its own files.
graph_edge_path = '../DiseaseSpecific/processed_data/GNBR/all.txt'
idtomeshid_path = '../DiseaseSpecific/processed_data/GNBR/entities_reverse_dict.json'
model_path = f'../DiseaseSpecific/saved_models/GNBR_{args.model}_128_0.2_0.3_0.3.model'
data_path = '../DiseaseSpecific/processed_data/GNBR'
target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl'
attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl'

# NOTE(review): pickle.load executes arbitrary code from the file; only run
# this against pickles produced by this project.
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
    full_entity_raw_name = pkl.load(fl)

# Prefer the GPU when one is available; the choice is also stored on `args`
# because the model loader below receives it from there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device

n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
graph_edge = utils.load_data(graph_edge_path)
with open(idtomeshid_path, 'r') as fl:
    idtomeshid = json.load(fl)
print(graph_edge.shape, len(idtomeshid))

# Statistics of the model loss over the existing edges; check_reasonable()
# uses them to standardise the loss of a candidate edge.
divide_bound, data_mean, data_std = calculate_edge_bound(graph_edge, model, args.device, n_ent)
print('Defender ...')
print(divide_bound, data_mean, data_std)
# Count entities per type. The type is the part of the identifier before the
# first underscore; an identifier with any other prefix raises KeyError.
meshids = list(idtomeshid.values())
cal = {
    'chemical': 0,
    'disease': 0,
    'gene': 0,
}
for meshid in meshids:
    entity_type = meshid.split('_')[0]
    cal[entity_type] += 1
# pprint(cal)
def check_reasonable(s, r, o):
    """Decide whether the triple ``(s, r, o)`` is plausible under the model.

    The model loss of the triple is standardised with the module-level
    ``data_mean`` / ``data_std`` and passed through a logistic centred on
    ``divide_bound``, so a lower loss yields a probability closer to 1.

    Args:
        s: Subject entity id (must be convertible to int64).
        r: Relation id (must be convertible to int64).
        o: Object entity id (must be convertible to int64).

    Returns:
        A pair ``(accepted, prob)``: ``prob`` is the logistic score and
        ``accepted`` is true when ``prob > 1 - args.reasonable_rate``.
    """
    triple = torch.from_numpy(np.asarray([[s, r, o]]).astype('int64')).to(device)
    raw_loss = get_model_loss_without_softmax(triple, model, device).squeeze().item()

    standardised_loss = (raw_loss - data_mean) / data_std
    edge_losses_prob = 1 / (1 + np.exp(standardised_loss - divide_bound))

    bound = 1 - args.reasonable_rate
    return (edge_losses_prob > bound), edge_losses_prob
# Flatten Parameters.edge_type_to_id / Parameters.reverse_mask into two
# lookups keyed by the relation id as a string:
#   edgeid_to_edgetype[id]    -> edge-type name the id belongs to
#   edgeid_to_reversemask[id] -> the mask entry paired with that id
edgeid_to_edgetype = {}
edgeid_to_reversemask = {}
for edge_type, id_list in Parameters.edge_type_to_id.items():
    for iid, mask in zip(id_list, Parameters.reverse_mask[edge_type]):
        edge_key = str(iid)
        edgeid_to_edgetype[edge_key] = edge_type
        edgeid_to_reversemask[edge_key] = mask
# Load the target nodes and their attack edges, plus the name tables used to
# identify drugs.
# NOTE(review): pickle.load executes arbitrary code from the file; only run
# this against pickles produced by this project.
with open(target_path, 'rb') as fl:
    Target_node_list = pkl.load(fl)
with open(attack_path, 'rb') as fl:
    Attack_edge_list = pkl.load(fl)
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)

# Chemical entities whose lower-cased raw name appears in the drug term table.
drug_meshid = {
    meshid
    for meshid, nm in entity_raw_name.items()
    if nm.lower() in drug_term and meshid.split('_')[0] == 'chemical'
}
# Case-study mode only: write the raw name of every target node, one per line.
if args.init_mode == 'single':
    name_list = [
        entity_raw_name[idtomeshid[str(target)]]
        for target in Target_node_list
    ]
    with open(f'results/name_list_{args.reasonable_rate}{args.init_mode}.txt', 'w') as fl:
        fl.write('\n'.join(name_list))
# print(Target_node_list) | |
# # print(Attack_edge_list) | |
# addset = set() | |
# if args.added_edge_num == 1: | |
# for edge in Attack_edge_list: | |
# addset.add(edge[2]) | |
# else: | |
# for edge_list in Attack_edge_list: | |
# for edge in edge_list: | |
# addset.add(edge[2]) | |
# print(addset) | |
# print(len(addset)) | |
# typeset = set() | |
# for iid in addset: | |
# typeset.add(idtomeshid[str(iid)].split('_')[0]) | |
# print(typeset) | |
# raise Exception('done') | |
# Group targets (and their attacks) into batches. Every batch is evaluated on
# one shared graph, so 'single' mode isolates each target in its own batch
# while the other modes evaluate BATCH_SIZE targets together.
BATCH_SIZE = 50

if args.init_mode == 'single':
    Target_node_list = [[target] for target in Target_node_list]
    Attack_edge_list = [[attack] for attack in Attack_edge_list]
else:
    print(len(Attack_edge_list), len(Target_node_list))
    n_targets = len(Target_node_list)
    # Both lists are cut at the same (start, stop) offsets, which are derived
    # from the number of targets.
    batch_bounds = [
        (start, min(start + BATCH_SIZE, n_targets))
        for start in range(0, n_targets, BATCH_SIZE)
    ]
    tmp_target_node_list = [Target_node_list[lo:hi] for lo, hi in batch_bounds]
    tmp_attack_edge_list = [Attack_edge_list[lo:hi] for lo, hi in batch_bounds]
    Target_node_list = tmp_target_node_list
    Attack_edge_list = tmp_attack_edge_list
# for i, init_p in enumerate([0.1, 0.3, 0.5, 0.7, 0.9]): | |
# target_node_list = Target_node_list[i] | |
# attack_edge_list = Attack_edge_list[i] | |
# Rank of every target among chemical nodes (1 = highest PageRank), measured
# on the clean graph (Init) and on the graph with accepted attack edges (After).
Init = []
After = []

# True when each attack entry is a single (s, r, o) triple (with s == -1
# meaning "no edge") instead of a list of triples.
single_edge_mode = args.mode == '' and (
    args.added_edge_num == '' or int(args.added_edge_num) == 1
)


def _add_oriented_edge(graph, s, r, o):
    """Insert the edge of triple (s, r, o), flipped when the relation's reverse mask is 1."""
    if edgeid_to_reversemask[str(r)] == 1:
        graph.add_edge(int(o), int(s))
    else:
        graph.add_edge(int(s), int(o))


def _build_clean_graph():
    """Build the directed graph of all existing edges in graph_edge."""
    graph = nx.DiGraph()
    for s, r, o in graph_edge:
        subject_type = idtomeshid[s].split('_')[0]
        expected_type = edgeid_to_edgetype[r].split('-')[0]
        # Explicit check instead of `assert`, which is stripped under -O.
        if subject_type != expected_type:
            raise ValueError(
                f'subject {s} has type {subject_type!r} but relation {r} expects {expected_type!r}'
            )
        _add_oriented_edge(graph, s, r, o)
    return graph


def _chemical_ranks(pagerank_scores, targets):
    """Return, for each target, its 1-based rank among chemical nodes by descending score."""
    ascending = sorted(pagerank_scores.items(), key=lambda item: item[1])
    chemicals = [
        iid for iid, _ in ascending
        if idtomeshid[str(iid)].split('_')[0] == 'chemical'
    ]
    n_chemicals = len(chemicals)
    return [n_chemicals - chemicals.index(target) for target in targets]


# NOTE(review): block nesting was reconstructed from a flattened source; the
# post-attack PageRank is taken once per batch, after all of the batch's
# accepted edges have been inserted - confirm against the upstream file.
clean_pagerank = None
for target_node_list, attack_edge_list in zip(Target_node_list, Attack_edge_list):
    # Every batch starts from a fresh copy of the clean graph.
    G = _build_clean_graph()
    if clean_pagerank is None:
        # The clean graph is identical for every batch, so score it only once.
        clean_pagerank = nx.pagerank(G, max_iter=200, tol=1.0e-7)

    batch = list(zip(target_node_list, attack_edge_list))
    batch_targets = [target for target, _ in batch]
    Init.extend(_chemical_ranks(clean_pagerank, batch_targets))

    for _, attack_list in tqdm(batch):
        if single_edge_mode:
            attack_list = [] if int(attack_list[0]) == -1 else [attack_list]
        for s, r, o in attack_list:
            # Only edges the defender model considers plausible are injected.
            accepted, _ = check_reasonable(s, r, o)
            if accepted:
                _add_oriented_edge(G, s, r, o)

    pagerank_value_1 = nx.pagerank(G, max_iter=200, tol=1.0e-7)
    After.extend(_chemical_ranks(pagerank_value_1, batch_targets))
# Persist both rank lists (file names encode the run configuration) and print
# mean / standard deviation of the ranks before and after the attack.
with open(f'results/Init_{args.reasonable_rate}{args.init_mode}.pkl', 'wb') as fl:
    pkl.dump(Init, fl)
with open(f'results/After_{args.model}_{args.reasonable_rate}{args.init_mode}{args.added_edge_num}{args.mode}.pkl', 'wb') as fl:
    pkl.dump(After, fl)

print(np.mean(Init), np.std(Init))
print(np.mean(After), np.std(After))