File size: 9,820 Bytes
fce1f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
'''
A file modified from https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
'''
#%%
import logging
import time
from tqdm import tqdm
import io
import pandas as pd
import numpy as np
import os
import json

import argparse
import torch
import random

from yaml import parse

from model import Conve, Distmult, Complex

logger = logging.getLogger(__name__)
#%%
def generate_dicts(data_path):
    '''Load the entity and relation dictionaries stored under ``data_path``.

    Reads ``entities_dict.json`` and ``relations_dict.json`` from the given
    directory.

    Args:
        data_path: Directory that contains the two JSON dictionary files.

    Returns:
        Tuple ``(n_ent, n_rel, ent_to_id, rel_to_id)``: the number of keys in
        each dictionary followed by the dictionaries themselves.
    '''
    # JSON text is UTF-8; pass the encoding explicitly instead of relying on
    # the platform/locale default used by open().
    with open(os.path.join(data_path, 'entities_dict.json'), 'r', encoding='utf-8') as f:
        ent_to_id = json.load(f)
    with open(os.path.join(data_path, 'relations_dict.json'), 'r', encoding='utf-8') as f:
        rel_to_id = json.load(f)

    # len() on a dict counts its keys directly; no intermediate list is needed.
    n_ent = len(ent_to_id)
    n_rel = len(rel_to_id)

    return n_ent, n_rel, ent_to_id, rel_to_id

def save_data(file_name, data):
    '''Write ``data`` to ``file_name`` as tab-separated text, one item per line.

    Args:
        file_name: Path of the file to (over)write.
        data: Iterable of items; every item is itself an iterable whose
            elements are converted with ``str`` and joined by tabs.
    '''
    with open(file_name, 'w') as out:
        out.writelines(
            "\t".join(str(field) for field in row) + "\n"
            for row in data
        )

def load_data(file_name):
    '''Read a header-less, tab-separated file and return its distinct rows.

    Every column is read as ``str``; duplicate rows are removed before the
    result is returned.

    Args:
        file_name: Path of the tab-separated file.

    Returns:
        The de-duplicated table as the array given by ``DataFrame.values``.
    '''
    table = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    return table.drop_duplicates().values

def seed_all(seed=1):
    '''Seed every random number generator this module relies on.

    Seeds ``random``, NumPy and PyTorch (CPU and all CUDA devices), stores the
    seed in the ``PYTHONHASHSEED`` environment variable, and turns on the
    cuDNN deterministic flag.

    Args:
        seed: Seed value handed to each generator. Default: 1.
    '''
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True

def add_model(args, n_ent, n_rel):
    '''Instantiate the embedding model selected by ``args.model``.

    Args:
        args: Namespace whose ``model`` attribute is ``None``, ``'distmult'``,
            ``'complex'`` or ``'conve'``; it is forwarded to the model
            constructor unchanged.
        n_ent: Number of entities, forwarded to the model constructor.
        n_rel: Number of relations, forwarded to the model constructor.

    Returns:
        A ``Distmult`` (also used when ``args.model`` is ``None``),
        ``Complex`` or ``Conve`` instance.

    Raises:
        Exception: If ``args.model`` is none of the values listed above.
    '''
    choice = args.model

    # DistMult doubles as the fallback when no model name was given.
    if choice is None or choice == 'distmult':
        return Distmult(args, n_ent, n_rel)
    if choice == 'complex':
        return Complex(args, n_ent, n_rel)
    if choice == 'conve':
        return Conve(args, n_ent, n_rel)

    raise Exception("Unknown model!")

def load_model(model_path, args, n_ent, n_rel, device):
    '''Build a model and restore its pre-trained parameters from disk.

    Args:
        model_path: Path of the checkpoint; the loaded object must provide a
            ``'state_dict'`` entry.
        args: Namespace forwarded to ``add_model`` to choose and configure
            the model.
        n_ent: Number of entities, forwarded to ``add_model``.
        n_rel: Number of relations, forwarded to ``add_model``.
        device: Device the model is moved to and the checkpoint tensors are
            mapped onto.

    Returns:
        The model with the saved parameters loaded, switched to eval mode.
    '''
    # add a model and load the pre-trained params
    model = add_model(args, n_ent, n_rel)
    model.to(device)
    logger.info('Loading saved model from %s', model_path)

    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted source should be passed in.
    # map_location places the stored tensors on ``device``; without it a
    # checkpoint saved on a GPU cannot be restored on a host lacking that GPU.
    state = torch.load(model_path, map_location=device)
    model_params = state['state_dict']
    for key, value in model_params.items():
        logger.info('Key:%s, Size:%s, Count:%s', key, value.size(), value.numel())

    model.load_state_dict(model_params)
    model.eval()
    logger.info(model)

    return model

def add_eval_parameters(parser):
    '''Register the evaluation command-line options on ``parser``.

    Adds ``--eval-mode``, the required ``--cuda-name``, and the boolean
    switches ``--direct`` and ``--seperate``.

    Args:
        parser: ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, so calls can be chained.
    '''
    option_table = (
        ('--eval-mode', dict(type=str, default='all', help='Method to evaluate the attack performance. Default: all. (all or single)')),
        ('--cuda-name', dict(type=str, required=True, help='Start a main thread on each cuda.')),
        ('--direct', dict(action='store_true', help='Directly add edge or not.')),
        ('--seperate', dict(action='store_true', help='Evaluate seperatly or not')),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser

def add_attack_parameters(parser):
    '''Register the attack-related command-line options on ``parser``.

    Covers target-triple selection (``--target-*``), candidate-edge
    generation (``--attack-goal``, ``--neighbor-num``, ``--candidate-mode``,
    ``--reasonable-rate``, ``--attack-batch-size``, ``--template-mode``), the
    ``--update-lissa`` switch and the language-model scoring options
    (``--GPT-batch-size``, ``--LM-softmax``, ``--LMprob-mode``).

    Args:
        parser: ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same ``parser`` object, so calls can be chained.
    '''

    # parser.add_argument('--target-split', type=str, default='0_100_1', help='Ranks to use for target set. Values are 0 for ranks==1; 1 for ranks <=10; 2 for ranks>10 and ranks<=100. Default: 1')
    parser.add_argument('--target-split', type=str, default='min', help='Methods for target triple selection. Default: min. (min or top_?, top means top_0.1)')
    parser.add_argument('--target-size', type=int, default=50, help='Number of target triples. Default: 50')
    parser.add_argument('--target-existed', action='store_true', help='Whether the targeted s_?_o already exists.')

    # parser.add_argument('--budget', type=int, default=1, help='Budget for each target triple for each corruption side')

    parser.add_argument('--attack-goal', type = str, default='single', help='Attack goal. Default: single. (single or global)')
    parser.add_argument('--neighbor-num', type = int, default=20, help='Max neighbor num for each side. Default: 20')
    parser.add_argument('--candidate-mode', type = str, default='quadratic', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate')
    # parser.add_argument('--neighbor-num', type = int, default=200, help='Max neighbor num for each side. Default: 200')
    # parser.add_argument('--candidate-mode', type = str, default='linear', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--attack-batch-size', type=int, default=256, help='Batch size for processing neighbours of target')
    parser.add_argument('--template-mode', type=str, default = 'manual', help = 'Template mode for transforming edge to single sentense. Default: manual. (manual or auto)')

    parser.add_argument('--update-lissa', action='store_true', help = 'Update lissa cache or not.')

    parser.add_argument('--GPT-batch-size', type=int, default = 64, help = 'Batch size for GPT2 when calculating LM score. Default: 64')
    parser.add_argument('--LM-softmax', action='store_true', help = 'Use a softmax head on LM prob or not.')
    # The help text used to advertise "absolute" although the actual default is "relative".
    parser.add_argument('--LMprob-mode', type=str, default='relative', help = 'Use the absolute LM score or calculate the destruction score when target word is replaced. Default: relative. (absolute or relative)')

    return parser

def get_argument_parser():
    '''Generate an argument parser

    Builds the base ``argparse.ArgumentParser`` holding the dataset/model
    choice, optimisation settings, batch sizes and embedding/convolution
    hyper-parameters.

    Returns:
        The configured ``argparse.ArgumentParser``.
    '''
    parser = argparse.ArgumentParser(description='Graph embedding')

    parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed (default: 1)')

    parser.add_argument('--data', type=str, default='GNBR', help='Dataset to use: { GNBR }')
    parser.add_argument('--model', type=str, default='distmult', help='Choose from: {distmult, complex, transe, conve}')

    parser.add_argument('--transe-margin', type=float, default=0.0, help='Margin value for TransE scoring function. Default:0.0')
    parser.add_argument('--transe-norm', type=int, default=2, help='P-norm value for TransE scoring function. Default:2')

    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--lr-decay', type=float, default=0.0, help='Weight decay value to use in the optimizer. Default: 0.0')
    parser.add_argument('--max-norm', action='store_true', help='Option to add unit max norm constraint to entity embeddings')

    # Help texts below state the value actually passed as ``default``.
    parser.add_argument('--train-batch-size', type=int, default=64, help='Batch size for train split (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=128, help='Batch size for test split (default: 128)')
    parser.add_argument('--valid-batch-size', type=int, default=128, help='Batch size for valid split (default: 128)')
    parser.add_argument('--KG-valid-rate', type = float, default=0.1, help='Validation rate during KG embedding training. (default: 0.1)')

    parser.add_argument('--save-influence-map', action='store_true', help='Save the influence map during training for gradient rollback.')
    parser.add_argument('--add-reciprocals', action='store_true')

    parser.add_argument('--embedding-dim', type=int, default=128, help='The embedding dimension (1D). Default: 128')
    parser.add_argument('--stack-width', type=int, default=16, help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 16')
    #parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
    parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.')
    parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.')
    parser.add_argument('--feat-drop', type=float, default=0.3, help='Dropout for the convolutional features. Default: 0.3.')
    # The single-dash spellings are kept for existing command lines; the
    # double-dash aliases match every other option and map to the same dest.
    parser.add_argument('-num-filters', '--num-filters', default=32,   type=int, help='Number of filters for convolution')
    parser.add_argument('-kernel-size', '--kernel-size', default=3, type=int, help='Kernel Size for convolution')

    # store_true flag: the bias is off unless the switch is given.
    parser.add_argument('--use-bias', action='store_true', help='Use a bias in the convolutional layer. Default: False')

    parser.add_argument('--reg-weight', type=float, default=5e-2, help='Weight for regularization. Default: 5e-2')
    parser.add_argument('--reg-norm', type=int, default=3, help='Norm for regularization. Default: 3')
    # parser.add_argument('--resume', action='store_true', help='Restore a saved model.')
    # parser.add_argument('--resume-split', type=str, default='test', help='Split to evaluate a restored model')
    # parser.add_argument('--reproduce-results', action='store_true', help='Use the hyperparameters to reproduce the results.')
    # parser.add_argument('--original-data', type=str, default='FB15k-237', help='Dataset to use; this option is needed to set the hyperparams to reproduce the results for training after attack, default: FB15k-237')
    return parser

def set_hyperparams(args):
    '''Overwrite hyper-parameters on ``args`` with fixed per-model values.

    For ``'distmult'``, ``'complex'`` and ``'conve'`` a model-specific set of
    training values is written; any other ``args.model`` gets none of them.
    The damping/LiSSA/scale values are written in every case.

    Args:
        args: Namespace with a ``model`` attribute; modified in place.

    Returns:
        The same ``args`` object.
    '''
    model_overrides = {
        'distmult': {'lr': 0.005, 'train_batch_size': 1024, 'reg_norm': 3},
        'complex': {'lr': 0.005, 'reg_norm': 3, 'input_drop': 0.4, 'train_batch_size': 1024},
        'conve': {'lr': 0.005, 'train_batch_size': 1024, 'reg_weight': 0.0},
    }
    # Alternative setting kept for reference: scale = 500, lissa_batch_size = 100
    # (damping, lissa_repeat and lissa_depth unchanged).
    common_overrides = {
        'damping': 0.01,
        'lissa_repeat': 1,
        'lissa_depth': 1,
        'scale': 400,
        'lissa_batch_size': 300,
    }

    for overrides in (model_overrides.get(args.model, {}), common_overrides):
        for attr_name, attr_value in overrides.items():
            setattr(args, attr_name, attr_value)
    return args