import torch
import numpy as np
from torch.autograd import Variable
from sklearn import metrics
import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import sys

sys.path.append("..")
import Parameters

logger = logging.getLogger(__name__)
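
# This script evaluates data-poisoning attacks on a biomedical knowledge
# graph: it loads a trained embedding model (distmult or conve), injects the
# generated adversarial edges into the original triples, retrains the model
# via main_multiprocess.py, and records how the ranks of the target triples
# change relative to the clean model.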


def get_model_loss_without_softmax(batch, model, device=None):
    """Return the per-triple negative score (pre-softmax) for a batch of (s, r, o) ids."""
    with torch.no_grad():
        s, r, o = batch[:, 0], batch[:, 1], batch[:, 2]
        emb_s = model.emb_e(s).squeeze(dim=1)
        emb_r = model.emb_rel(r).squeeze(dim=1)
        pred = model.forward(emb_s, emb_r)
        return -pred[range(o.shape[0]), o]


def check(trip, model, reasonable_rate, device,
          data_mean=-4.008113861083984, data_std=5.153779983520508,
          divide_bound=0.05440050354114886):
    """Decide whether a candidate adversarial triple is plausible enough to keep.

    The default statistics are for the distmult model; conve statistics are
    substituted below.
    """
    if args.model == 'distmult':
        pass
    elif args.model == 'conve':
        data_mean = 13.890259742
        data_std = 12.396190643
        divide_bound = -0.1986345871
    else:
        raise Exception('Wrong model!!')
    trip = np.array(trip)
    train_trip = trip[None, :]
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze().item()
    bound = 1 - reasonable_rate
    # standardize the loss, then squash it into a pseudo-probability
    edge_loss = (edge_loss - data_mean) / data_std
    edge_loss_prob = 1 / (1 + np.exp(edge_loss - divide_bound))
    return edge_loss_prob > bound
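
# For example, with reasonable_rate = 0.7 the bound is 0.3, so a candidate
# edge passes whenever its calibrated probability exceeds 0.3; a higher
# reasonable_rate therefore admits more low-probability edges.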


def get_ranking(model, queries,
                valid_filters: Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                device, batch_size, entityid_to_nodetype, exists_edge):
    """
    Ranking for target generation.
    """
    ranks = []
    total_nums = []
    for b_begin in range(0, len(queries), 1):
        b_queries = queries[b_begin : b_begin + 1]
        s, r, o = b_queries[:, 0], b_queries[:, 1], b_queries[:, 2]
        r_rev = r
        lhs_score = model.score_or(o, r_rev, sigmoid=False)  # raw scores, not probabilities
        # print(b_queries.shape)
        for i, query in enumerate(b_queries):
            tp2 = entityid_to_nodetype[str(query[2].item())]
            if not args.target_existed:
                filt = valid_filters['lhs'][(tp2, query[1].item())].clone()
                filt[exists_edge['lhs'][str(query[2].item())]] = False
                filt = (filt == False)
            else:
                filt = valid_filters['lhs'][(tp2, query[1].item())]
                filt = (filt == False)
            # if str(query[2].item()) == '16566':
            #     print('16566', filt.sum(), valid_filters['lhs'][(tp2, query[1].item())].sum(), tp2, query[1].item())
            #     raise Exception('??')
            score = lhs_score
            # target_value = rhs_score[i, query[0].item()].item()
            # mask all known cases with a large score (these are not
            # interesting); this corresponds to the filtered setting
            score[i][filt] = 1e6
            total_nums.append(n_ent - filt.sum().item())
            # write back the saved values
            # if b_begin < len(queries) // 2:
            #     score[i][query[2].item()] = target_value
            # else:
            #     score[i][query[0].item()] = target_value
        # sort and rank: low scores get low-numbered positions
        min_values, sort_v = torch.sort(score, dim=1, descending=False)
        sort_v = sort_v.cpu().numpy()
        for i, query in enumerate(b_queries):
            # 0-based position of the target entity in the ascending sort;
            # evaluation() converts this into a rank counted from the top
            rank = np.where(sort_v[i] == query[0].item())[0][0]
            ranks.append(rank)
    # logger.info('Ranking done for all queries')
    return ranks, total_nums


def evaluation(model, queries,
               valid_filters: Dict[str, Dict[Tuple[str, int], torch.Tensor]],
               device, batch_size, entityid_to_nodetype, exists_edge,
               eval_type='', attack_data=None, ori_ranks=None, ori_totals=None):
    # get rankings
    ranks, total_nums = get_ranking(model, queries, valid_filters, device,
                                    batch_size, entityid_to_nodetype, exists_edge)
    ranks, total_nums = np.array(ranks), np.array(total_nums)
    # print(ranks)
    # print(total_nums)
    # convert ascending-sort positions into ranks counted from the top,
    # so the best possible rank is 1
    ranks = total_nums - ranks
    if attack_data is not None:
        # for targets where no adversarial edge was generated (marked by a
        # leading -1, or by an empty list in sentence/bioBART mode), fall
        # back to the original ranks
        for i, tri in enumerate(attack_data):
            if args.mode == '':
                if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                    if int(tri[0]) == -1:
                        ranks[i] = ori_ranks[i]
                        total_nums[i] = ori_totals[i]
                else:
                    if int(tri[0][0]) == -1:
                        ranks[i] = ori_ranks[i]
                        total_nums[i] = ori_totals[i]
            else:
                if len(tri) == 0:
                    ranks[i] = ori_ranks[i]
                    total_nums[i] = ori_totals[i]
    mean = (ranks / total_nums).mean()
    std = (ranks / total_nums).std()
    # final logging
    hits_at = np.arange(1, 11)
    hits_at_both = list(map(lambda x: np.mean((ranks <= x), dtype=np.float64).item(),
                            hits_at))
    mr = np.mean(ranks, dtype=np.float64).item()
    mrr = np.mean(1. / ranks, dtype=np.float64).item()
    logger.info('')
    logger.info('-' * 50)
    # logger.info(split + '_' + save_name)
    logger.info('')
    if eval_type:
        logger.info(eval_type)
    else:
        logger.info('after attack')
    for i in hits_at:
        logger.info('Hits @{0}: {1}'.format(i, hits_at_both[i - 1]))
    logger.info('Mean rank: {0}'.format(mr))
    logger.info('Mean reciprocal rank lhs: {0}'.format(mrr))
    logger.info('Mean proportion: {0}'.format(mean))
    logger.info('Std proportion: {0}'.format(std))
    logger.info('Mean candidate num: {0}'.format(np.mean(total_nums)))
    results = {}
    for i in hits_at:
        results['hits @{}'.format(i)] = hits_at_both[i - 1]
    results['mrr'] = mrr
    results['mr'] = mr
    results['proportion'] = mean
    results['std'] = std
    return results, list(ranks), list(total_nums)


parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser = utils.add_eval_parameters(parser)
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = os.path.join('processed_data', args.data)
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(
    args.model, args.data, args.target_split, args.target_size,
    'exists:' + str(args.target_existed), args.attack_goal))
log_path = 'logs/evaluation_logs/cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}'.format(
    args.model,
    args.target_split,
    args.target_size,
    'exists:' + str(args.target_existed),
    args.neighbor_num,
    args.candidate_mode,
    args.attack_goal,
    str(args.reasonable_rate),
    args.mode)
record_path = 'eval_record/{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}{9}{10}'.format(
    args.model,
    args.target_split,
    args.target_size,
    'exists:' + str(args.target_existed),
    args.neighbor_num,
    args.candidate_mode,
    args.attack_goal,
    str(args.reasonable_rate),
    args.mode,
    str(args.added_edge_num),
    args.mask_ratio)
init_record_path = 'eval_record/{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}'.format(
    args.model,
    args.target_split,
    args.target_size,
    'exists:' + str(args.target_existed),
    args.neighbor_num,
    args.candidate_mode,
    args.attack_goal,
    str(args.reasonable_rate),
    'init')
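
# init_record_path stores the metrics of the clean (pre-attack) model,
# dumped right after the first evaluation below.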

if args.seperate:
    record_path += '_seperate'
    log_path += '_seperate'
else:
    record_path += '_batch'
if args.direct:
    log_path += '_direct'
    record_path += '_direct'
else:
    log_path += '_nodirect'
    record_path += '_nodirect'
disturbed_path_pre = os.path.join(data_path, 'evaluation')

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO,
                    filename=log_path)
logger = logging.getLogger(__name__)
logger.info(vars(args))

n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim,
                                          args.input_drop, args.hidden_drop, args.feat_drop)
model_path = 'saved_models/{0}_{1}.model'.format(args.data, model_name)
model = utils.load_model(model_path, args, n_ent, n_rel, device)

ori_data = utils.load_data(os.path.join(data_path, 'all.txt'))
target_data = utils.load_data(target_path)
# shuffle the targets (reproducible: the RNG is seeded via utils.seed_all above)
index = range(len(target_data))
index = np.random.permutation(index)
target_data = target_data[index]

if args.direct:
    assert args.attack_goal == 'single'
    raise Exception('This option is abandoned in this version.')
    # disturbed_data = list(ori_data) + list(target_data)
else:
    attack_path = os.path.join('attack_results', args.data,
                               'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}{9}{10}.txt'.format(
                                   args.model,
                                   args.target_split,
                                   args.target_size,
                                   'exists:' + str(args.target_existed),
                                   args.neighbor_num,
                                   args.candidate_mode,
                                   args.attack_goal,
                                   str(args.reasonable_rate),
                                   args.mode,
                                   str(args.added_edge_num),
                                   args.mask_ratio))
    if args.mode == '':
        attack_data = utils.load_data(attack_path, drop=False)
        if not (args.added_edge_num == '' or int(args.added_edge_num) == 1):
            assert int(args.added_edge_num) * len(target_data) == len(attack_data)
            attack_data = attack_data.reshape((len(target_data), int(args.added_edge_num), 3))
            attack_data = attack_data[index]
        else:
            assert len(target_data) == len(attack_data)
            attack_data = attack_data[index]
        # if not args.seperate:
        #     disturbed_data = list(ori_data) + list(attack_data)
    else:
        with open(attack_path, 'rb') as fl:
            attack_data = pkl.load(fl)
        # keep only the candidate edges that pass the plausibility check
        tmp_attack_data = []
        for vv in attack_data:
            a_attack = []
            for v in vv:
                if check(v, model, args.reasonable_rate, device):
                    a_attack.append(v)
            tmp_attack_data.append(a_attack)
        attack_data = tmp_attack_data
        attack_data = [attack_data[i] for i in index]
        # if not args.seperate:
        #     disturbed_data = list(ori_data)
        #     if args.mode == '':
        #         for aa in list(attack_data):
        #             if int(aa[0]) != -1:
        #                 disturbed_data.append(aa)
        #     else:
        #         for vv in attack_data:
        #             for v in vv:
        #                 disturbed_data.append(v)

with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    valid_filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(Parameters.GNBRfile + 'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)

if args.attack_goal == 'global':
    raise Exception('Please refer to pagerank method in global setting.')
    # target_disease = []
    # tid = 1
    # bound = 50
    # while True:
    #     meshid = disease_meshid[tid][0]
    #     fre = disease_meshid[tid][1]
    #     if len(entity_raw_name[meshid]) > 4:
    #         target_disease.append(entity_to_id[meshid])
    #         bound -= 1
    #         if bound == 0:
    #             break
    #     tid += 1
    # s_set = set()
    # for s, r, o in target_data:
    #     s_set.add(s)
    # target_data = list(s_set)
    # target_data.sort()
    # target_list = []
    # for s in target_data:
    #     for o in target_disease:
    #         target_list.append([str(s), str(10), str(o)])
    # target_data = np.array(target_list, dtype=str)

# convert the per-(node-type, relation) filters into boolean masks on device
init_mask = np.zeros(n_ent, dtype=bool)
for k, v in valid_filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        valid_filters[k][kk] = torch.from_numpy(tmp).to(device)
# print('what??', valid_filters['lhs'][('disease', 10)].sum())

# exists_edge maps each entity to its known neighbors in the original graph
exists_edge = {'lhs': {}, 'rhs': {}}
for s, r, o in ori_data:
    if s not in exists_edge['rhs']:
        exists_edge['rhs'][s] = []
    if o not in exists_edge['lhs']:
        exists_edge['lhs'][o] = []
    exists_edge['rhs'][s].append(int(o))
    exists_edge['lhs'][o].append(int(s))
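
# e.g. exists_edge['lhs'][o] lists every subject id s for which an edge
# (s, r, o) is already present; get_ranking uses it to drop known neighbors
# from the candidate set when args.target_existed is false.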

target_data = torch.from_numpy(target_data.astype('int64')).to(device)
# print(target_data[:5, :])
ori_results, ori_ranks, ori_totals = evaluation(model, target_data, valid_filters, device,
                                                args.test_batch_size, entityid_to_nodetype,
                                                exists_edge, 'original')
print('Original:', ori_results)
with open(init_record_path, 'wb') as fl:
    pkl.dump([ori_results, ori_ranks, ori_totals], fl)
# raise Exception('Check Original Rank!!!')

thread_name = (args.model + '_' + args.target_split + '_' + args.attack_goal + '_'
               + str(args.reasonable_rate) + str(args.added_edge_num) + str(args.mask_ratio))
if args.direct:
    thread_name += '_direct'
else:
    thread_name += '_nodirect'
if args.seperate:
    thread_name += '_seperate'
else:
    thread_name += '_batch'
thread_name += args.mode
disturbed_data_path = os.path.join(disturbed_path_pre, 'all_{}.txt'.format(thread_name))
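
# Two retraining regimes follow: with args.seperate, each target's adversarial
# edges are injected and the model is retrained one target at a time; in batch
# mode, targets are processed in chunks of 50 with one retraining per chunk.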
if args.seperate:
    # assert len(attack_data) * len(target_disease) == len(target_data)
    assert len(attack_data) == len(target_data)
    final_result = None
    Ranks = []
    Totals = []
    print('Training model {}...'.format(thread_name))
    for i in tqdm(range(len(attack_data))):
        attack_trip = attack_data[i]
        if args.mode == '':
            attack_trip = [attack_trip]
        # target = target_data[i*len(target_disease) : (i+1)*len(target_disease)]
        target = target_data[i : i + 1, :]
        if len(attack_trip) > 0 and int(attack_trip[0][0]) != -1:
            disturbed_data = list(ori_data) + attack_trip
            disturbed_data = np.array(disturbed_data)
            utils.save_data(disturbed_data_path, disturbed_data)
            cmd = 'CUDA_VISIBLE_DEVICES={} python main_multiprocess.py --data {} --model {} --thread-name {}'.format(
                args.cuda_name, args.data, args.model, thread_name)
            os.system(cmd)
            model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim,
                                                          args.input_drop, args.hidden_drop,
                                                          args.feat_drop, thread_name)
            model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, model_name)
            model = utils.load_model(model_path, args, n_ent, n_rel, device)
            a_results, a_ranks, a_total_nums = evaluation(model, target, valid_filters, device,
                                                          args.test_batch_size, entityid_to_nodetype,
                                                          exists_edge)
            assert len(a_ranks) == 1
            if not final_result:
                final_result = a_results
            else:
                for k in final_result.keys():
                    final_result[k] += a_results[k]
            Ranks += a_ranks
            Totals += a_total_nums
        else:
            # no adversarial edge for this target: reuse the original rank
            Ranks += [ori_ranks[i]]
            Totals += [ori_totals[i]]
            final_result['proportion'] += ori_ranks[i] / ori_totals[i]
    for k in final_result.keys():
        final_result[k] /= len(attack_data)
    print('Final !!!')
    print(final_result)
    logger.info('Final !!!!')
    for k, v in final_result.items():
        logger.info('{} : {}'.format(k, v))
    tmp = np.array(Ranks) / np.array(Totals)
    print('Std:', np.std(tmp))
    with open(record_path, 'wb') as fl:
        pkl.dump([final_result, Ranks, Totals], fl)
else:
    assert len(target_data) == len(attack_data)
    print('Attack shape:', len(attack_data))
    Results = []
    Ranks = []
    Totals = []
    for l in range(0, len(target_data), 50):
        r = min(l + 50, len(target_data))
        t_target_data = target_data[l:r]
        t_attack_data = attack_data[l:r]
        t_ori_ranks = ori_ranks[l:r]
        t_ori_totals = ori_totals[l:r]
        # flatten the per-target edge lists into one list of triples
        if args.mode == '':
            if not (args.added_edge_num == '' or int(args.added_edge_num) == 1):
                tt_attack_data = []
                for vv in t_attack_data:
                    tt_attack_data += list(vv)
                t_attack_data = tt_attack_data
        else:
            assert args.mode == 'sentence' or args.mode == 'bioBART'
            tt_attack_data = []
            for vv in t_attack_data:
                tt_attack_data += vv
            t_attack_data = tt_attack_data
        disturbed_data = list(ori_data) + list(t_attack_data)
        utils.save_data(disturbed_data_path, disturbed_data)
        cmd = 'CUDA_VISIBLE_DEVICES={} python main_multiprocess.py --data {} --model {} --thread-name {}'.format(
            args.cuda_name, args.data, args.model, thread_name)
        print('Training model {}...'.format(thread_name))
        os.system(cmd)
        model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim,
                                                      args.input_drop, args.hidden_drop,
                                                      args.feat_drop, thread_name)
        model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, model_name)
        model = utils.load_model(model_path, args, n_ent, n_rel, device)
        a_results, a_ranks, a_totals = evaluation(model, t_target_data, valid_filters, device,
                                                  args.test_batch_size, entityid_to_nodetype,
                                                  exists_edge, attack_data=attack_data[l:r],
                                                  ori_ranks=t_ori_ranks, ori_totals=t_ori_totals)
        print(f'************Current l: {l}\n', a_results)
        assert len(a_ranks) == t_target_data.shape[0]
        Results += [a_results]
        Ranks += list(a_ranks)
        Totals += list(a_totals)
    with open(record_path, 'wb') as fl:
        pkl.dump([Results, Ranks, Totals, index], fl)
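
# Minimal invocation sketch (hypothetical script name and flag spellings,
# assumed from the argument names above; check utils.get_argument_parser and
# utils.add_attack_parameters for the real ones):
#
#   CUDA_VISIBLE_DEVICES=0 python evaluate_attack.py --data <dataset> \
#       --model distmult --target-split <split> --reasonable-rate 0.7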