PFEemp2024 commited on
Commit
63775f2
·
1 Parent(s): c2c01a0

add necessary file

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. anonymous_demo/__init__.py +5 -0
  2. anonymous_demo/core/__init__.py +0 -0
  3. anonymous_demo/core/tad/__init__.py +0 -0
  4. anonymous_demo/core/tad/classic/__bert__/README.MD +3 -0
  5. anonymous_demo/core/tad/classic/__bert__/__init__.py +1 -0
  6. anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py +0 -0
  7. anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py +121 -0
  8. anonymous_demo/core/tad/classic/__bert__/models/__init__.py +1 -0
  9. anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py +46 -0
  10. anonymous_demo/core/tad/classic/__init__.py +0 -0
  11. anonymous_demo/core/tad/models/__init__.py +9 -0
  12. anonymous_demo/core/tad/prediction/__init__.py +0 -0
  13. anonymous_demo/core/tad/prediction/tad_classifier.py +518 -0
  14. anonymous_demo/functional/__init__.py +3 -0
  15. anonymous_demo/functional/checkpoint/__init__.py +1 -0
  16. anonymous_demo/functional/checkpoint/checkpoint_manager.py +19 -0
  17. anonymous_demo/functional/config/__init__.py +1 -0
  18. anonymous_demo/functional/config/config_manager.py +64 -0
  19. anonymous_demo/functional/config/tad_config_manager.py +229 -0
  20. anonymous_demo/functional/dataset/__init__.py +1 -0
  21. anonymous_demo/functional/dataset/dataset_manager.py +45 -0
  22. anonymous_demo/network/__init__.py +0 -0
  23. anonymous_demo/network/lcf_pooler.py +28 -0
  24. anonymous_demo/network/lsa.py +73 -0
  25. anonymous_demo/network/sa_encoder.py +199 -0
  26. anonymous_demo/utils/__init__.py +0 -0
  27. anonymous_demo/utils/demo_utils.py +247 -0
  28. anonymous_demo/utils/logger.py +38 -0
  29. checkpoints.zip +3 -0
  30. text_defense/201.SST2/stsa.binary.dev.dat +0 -0
  31. text_defense/201.SST2/stsa.binary.test.dat +0 -0
  32. text_defense/201.SST2/stsa.binary.train.dat +0 -0
  33. text_defense/202.IMDB10K/imdb10k.test.dat +0 -0
  34. text_defense/202.IMDB10K/imdb10k.train.dat +0 -0
  35. text_defense/202.IMDB10K/imdb10k.valid.dat +0 -0
  36. text_defense/204.AGNews10K/AGNews10K.test.dat +0 -0
  37. text_defense/204.AGNews10K/AGNews10K.train.dat +0 -0
  38. text_defense/204.AGNews10K/AGNews10K.valid.dat +0 -0
  39. text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat +0 -0
  40. text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat +0 -0
  41. textattack/__init__.py +39 -0
  42. textattack/__main__.py +6 -0
  43. textattack/attack.py +492 -0
  44. textattack/attack_args.py +763 -0
  45. textattack/attack_recipes/__init__.py +43 -0
  46. textattack/attack_recipes/a2t_yoo_2021.py +74 -0
  47. textattack/attack_recipes/attack_recipe.py +30 -0
  48. textattack/attack_recipes/bae_garg_2019.py +123 -0
  49. textattack/attack_recipes/bert_attack_li_2020.py +95 -0
  50. textattack/attack_recipes/checklist_ribeiro_2020.py +53 -0
anonymous_demo/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ __name__ = "anonymous_demo"
4
+
5
+ from anonymous_demo.functional import TADCheckpointManager
anonymous_demo/core/__init__.py ADDED
File without changes
anonymous_demo/core/tad/__init__.py ADDED
File without changes
anonymous_demo/core/tad/classic/__bert__/README.MD ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ## This is a simple migration of ABSA-PyTorch, which is released under the MIT license
2
+
3
+ Project Address: https://github.com/songyouwei/ABSA-PyTorch
anonymous_demo/core/tad/classic/__bert__/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .models import *
anonymous_demo/core/tad/classic/__bert__/dataset_utils/__init__.py ADDED
File without changes
anonymous_demo/core/tad/classic/__bert__/dataset_utils/data_utils_for_inference.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ from findfile import find_cwd_dir
3
+ from torch.utils.data import Dataset
4
+ from transformers import AutoTokenizer
5
+
6
+
7
class Tokenizer4Pretraining:
    """Thin wrapper around a Hugging Face ``AutoTokenizer`` with a fixed length.

    Every call to :meth:`text_to_sequence` truncates/pads to ``max_seq_len``.
    """

    def __init__(self, max_seq_len, opt, **kwargs):
        """Create the underlying tokenizer.

        :param max_seq_len: Length every encoded sequence is padded/truncated to.
        :param opt: Configuration object; only ``opt.pretrained_bert`` is read.
        :param kwargs: ``offline`` (default ``False``) selects loading from a
            directory located with ``find_cwd_dir`` instead of the model name.
        """
        offline = kwargs.pop("offline", False)
        if offline:
            # Look the model up by the last path component of its name.
            source = find_cwd_dir(opt.pretrained_bert.split("/")[-1])
        else:
            source = opt.pretrained_bert
        lower_case = "uncased" in opt.pretrained_bert
        self.tokenizer = AutoTokenizer.from_pretrained(
            source, do_lower_case=lower_case
        )
        self.max_seq_len = max_seq_len

    def text_to_sequence(self, text, reverse=False, padding="post", truncating="post"):
        """Encode ``text`` to token ids, padded/truncated to ``max_seq_len``.

        ``reverse``, ``padding`` and ``truncating`` are accepted for interface
        compatibility but are not used by this implementation.
        """
        encode_options = {
            "truncation": True,
            "padding": "max_length",
            "max_length": self.max_seq_len,
            "return_tensors": "pt",
        }
        return self.tokenizer.encode(text, **encode_options)
28
+
29
+
30
class BERTTADDataset(Dataset):
    """In-memory inference dataset for the TAD BERT model.

    Each raw sample is a line of text, optionally followed by the marker
    ``!ref!`` and up to three comma-separated integer references in the order
    ``label, is_adv, adv_train_label``. Missing references default to ``-100``.
    """

    def __init__(self, tokenizer, opt):
        """Store the tokenizer/config; the dataset starts empty.

        :param tokenizer: Object exposing ``text_to_sequence(text)``.
        :param opt: Configuration object (kept as ``self.opt``, not read here).
        """
        self.bert_baseline_input_colses = {"bert": ["text_bert_indices"]}

        self.tokenizer = tokenizer
        self.opt = opt
        self.all_data = []

    def parse_sample(self, text):
        """Return the list of samples contained in ``text`` (always one)."""
        return [text]

    def prepare_infer_sample(self, text: str, ignore_error):
        """Replace the dataset content with the single sample ``text``."""
        self.process_data(self.parse_sample(text), ignore_error=ignore_error)

    def prepare_infer_dataset(self, infer_file, ignore_error):
        """Replace the dataset content with the samples read from file(s).

        This method is what ``TADTextClassifier.batch_infer`` calls; it was
        missing from the class, so batch inference failed with AttributeError.

        :param infer_file: A path, or an iterable of paths, of UTF-8 text
            files holding one sample per line.
            NOTE(review): the exact return shape of ``detect_infer_dataset``
            is not visible here — both a single path and a list are accepted.
        :param ignore_error: Forwarded to :meth:`process_data`.
        """
        paths = [infer_file] if isinstance(infer_file, str) else list(infer_file)
        samples = []
        for path in paths:
            with open(path, mode="r", encoding="utf8") as fin:
                for line in fin:
                    # Blank lines carry no sample; skip them up front.
                    if line.strip():
                        samples.extend(self.parse_sample(line))
        self.process_data(samples, ignore_error=ignore_error)

    def process_data(self, samples, ignore_error=True):
        """Tokenize ``samples`` and store them as the dataset content.

        :param samples: Sequence of raw sample strings.
        :param ignore_error: When true, a sample that cannot be processed is
            reported with ``print`` and skipped; otherwise the error is raised.
        :return: The list of processed sample dicts (also kept in
            ``self.all_data``).
        """
        all_data = []
        # Only show a progress bar for larger inputs.
        if len(samples) > 100:
            it = tqdm.tqdm(
                samples, postfix="preparing text classification inference dataloader..."
            )
        else:
            it = samples
        for text in it:
            try:
                # handle for empty lines in inference datasets
                if text is None or "" == text.strip():
                    raise RuntimeError("Invalid Input!")

                if "!ref!" in text:
                    text, _, labels = text.strip().partition("!ref!")
                    text = text.strip()
                    fields = [field.strip() for field in labels.split(",")]
                    if len(fields) == 3:
                        label, is_adv, adv_train_label = fields
                    elif len(fields) == 2:
                        label, is_adv = fields
                        adv_train_label = "-100"
                    elif len(fields) == 1:
                        label = fields[0]
                        adv_train_label = "-100"
                        is_adv = "-100"
                    else:
                        # More than three fields: treat the sample as unlabeled.
                        label = "-100"
                        adv_train_label = "-100"
                        is_adv = "-100"

                    label = int(label)
                    adv_train_label = int(adv_train_label)
                    is_adv = int(is_adv)
                else:
                    text = text.strip()
                    label = -100
                    adv_train_label = -100
                    is_adv = -100

                text_indices = self.tokenizer.text_to_sequence("{}".format(text))

                all_data.append(
                    {
                        # First row of the tokenizer output.
                        "text_bert_indices": text_indices[0],
                        "text_raw": text,
                        "label": label,
                        "adv_train_label": adv_train_label,
                        "is_adv": is_adv,
                    }
                )

            except Exception:
                if ignore_error:
                    print("Ignore error while processing:", text)
                else:
                    # Bare raise keeps the original traceback intact.
                    raise

        self.all_data = all_data
        return self.all_data

    def __getitem__(self, index):
        return self.all_data[index]

    def __len__(self):
        return len(self.all_data)
anonymous_demo/core/tad/classic/__bert__/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tad_bert import TADBERT
anonymous_demo/core/tad/classic/__bert__/models/tad_bert.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers.models.bert.modeling_bert import BertPooler
4
+
5
+ from anonymous_demo.network.sa_encoder import Encoder
6
+
7
+
8
class TADBERT(nn.Module):
    """BERT backbone with three linear heads on a shared pooled representation.

    The heads produce sentiment logits (``dense1``), adversarial-detection
    logits (``dense2``) and adversarial-training logits (``dense3``).
    """

    inputs = ["text_bert_indices"]

    def __init__(self, bert, opt):
        """Build the heads on top of ``bert``.

        :param bert: Backbone model; must expose ``config`` and return a
            mapping with ``"last_hidden_state"`` when called.
        :param opt: Configuration providing ``hidden_dim``, ``class_dim`` and
            ``adv_det_dim``.
        """
        super().__init__()
        self.opt = opt
        self.bert = bert
        self.pooler = BertPooler(bert.config)
        self.dense1 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)
        self.dense2 = nn.Linear(self.opt.hidden_dim, self.opt.adv_det_dim)
        self.dense3 = nn.Linear(self.opt.hidden_dim, self.opt.class_dim)

        # Not used in forward(), but kept: they are registered submodules and
        # therefore part of the state_dict that checkpoints were saved with.
        self.encoder1 = Encoder(self.bert.config, opt=opt)
        self.encoder2 = Encoder(self.bert.config, opt=opt)
        self.encoder3 = Encoder(self.bert.config, opt=opt)

    def forward(self, inputs):
        """Run the backbone and the three heads.

        :param inputs: Sequence whose first element holds the token indices.
        :return: Dict with ``sent_logits``, ``advdet_logits``,
            ``adv_tr_logits``, ``last_hidden_state`` and ``att_score``.
        """
        text_raw_indices = inputs[0]
        last_hidden_state = self.bert(text_raw_indices)["last_hidden_state"]

        # The pooler output is identical for all three heads; compute it once
        # instead of once per head.
        pooled = self.pooler(last_hidden_state)
        sent_logits = self.dense1(pooled)
        advdet_logits = self.dense2(pooled)
        adv_tr_logits = self.dense3(pooled)

        abs_hidden = last_hidden_state.abs()
        # NOTE(review): the sum drops dim 1 while the min keeps it, so for a
        # (batch, seq, hidden) input the subtraction broadcasts (batch, hidden)
        # against (batch, 1, hidden) — confirm this is the intended shape.
        att_score = torch.nn.functional.normalize(
            abs_hidden.sum(dim=1, keepdim=False)
            - abs_hidden.min(dim=1, keepdim=True)[0],
            p=1,
            dim=1,
        )

        return {
            "sent_logits": sent_logits,
            "advdet_logits": advdet_logits,
            "adv_tr_logits": adv_tr_logits,
            "last_hidden_state": last_hidden_state,
            "att_score": att_score,
        }
anonymous_demo/core/tad/classic/__init__.py ADDED
File without changes
anonymous_demo/core/tad/models/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import anonymous_demo.core.tad.classic.__bert__.models
2
+
3
+
4
class BERTTADModelList(list):
    """List of the available BERT-based TAD model classes.

    Each model is also exposed as a class attribute so that membership can be
    tested by name with ``hasattr``.
    """

    TADBERT = anonymous_demo.core.tad.classic.__bert__.TADBERT

    def __init__(self):
        super().__init__([self.TADBERT])
anonymous_demo/core/tad/prediction/__init__.py ADDED
File without changes
anonymous_demo/core/tad/prediction/tad_classifier.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pickle
4
+ import time
5
+
6
+ import torch
7
+ import tqdm
8
+ from findfile import find_file, find_cwd_dir
9
+ from termcolor import colored
10
+
11
+ from torch.utils.data import DataLoader
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModel,
15
+ AutoConfig,
16
+ DebertaV2ForMaskedLM,
17
+ RobertaForMaskedLM,
18
+ BertForMaskedLM,
19
+ )
20
+
21
+ from ....functional.dataset.dataset_manager import detect_infer_dataset
22
+
23
+ from ..models import BERTTADModelList
24
+ from ..classic.__bert__.dataset_utils.data_utils_for_inference import (
25
+ BERTTADDataset,
26
+ Tokenizer4Pretraining,
27
+ )
28
+
29
+ from ....utils.demo_utils import (
30
+ print_args,
31
+ TransformerConnectionError,
32
+ get_device,
33
+ build_embedding_matrix,
34
+ )
35
+
36
+
37
def init_attacker(tad_classifier, defense):
    """Build a TextAttack-based attacker wrapped around ``tad_classifier``.

    :param tad_classifier: Object exposing ``infer(text, print_result=...)``
        whose result contains ``"probs"``.
    :param defense: Recipe key: ``bae``, ``pwws``, ``textfooler``, ``pso``,
        ``iga``, ``ga`` or ``wordbugger``.
    :return: An object with an ``attacker`` attribute, or ``None`` when
        anything fails (the error is printed, not raised).
    """
    try:
        from textattack import Attacker
        from textattack.attack_recipes import (
            BAEGarg2019,
            PWWSRen2019,
            TextFoolerJin2019,
            PSOZang2020,
            IGAWang2019,
            GeneticAlgorithmAlzantot2018,
            DeepWordBugGao2018,
        )
        from textattack.datasets import Dataset
        from textattack.models.wrappers import HuggingFaceModelWrapper

        recipes_by_name = {
            "bae": BAEGarg2019,
            "pwws": PWWSRen2019,
            "textfooler": TextFoolerJin2019,
            "pso": PSOZang2020,
            "iga": IGAWang2019,
            "ga": GeneticAlgorithmAlzantot2018,
            "wordbugger": DeepWordBugGao2018,
        }

        class DemoModelWrapper(HuggingFaceModelWrapper):
            """Adapts the classifier's ``infer`` to TextAttack's call protocol."""

            def __init__(self, model):
                self.model = model  # pipeline = pipeline

            def __call__(self, text_inputs, **kwargs):
                return [
                    self.model.infer(single_text, print_result=False, **kwargs)[
                        "probs"
                    ]
                    for single_text in text_inputs
                ]

        class SentAttacker:
            """Holds an ``Attacker`` built from the chosen recipe."""

            def __init__(self, model, recipe_class=BAEGarg2019):
                recipe = recipe_class.build(DemoModelWrapper(model))
                # The attacker is constructed with a one-item placeholder dataset.
                placeholder = Dataset([("", 0)])
                self.attacker = Attacker(recipe, placeholder)

        return SentAttacker(tad_classifier, recipes_by_name[defense])
    except Exception as err:
        print("Original error:", err)
89
+
90
+
91
def get_mlm_and_tokenizer(text_classifier, config):
    """Create a masked-LM model that reuses the classifier's backbone.

    :param text_classifier: A ``TADTextClassifier`` or a model exposing
        ``bert.base_model``.
    :param config: Configuration providing ``pretrained_bert``.
    :return: ``(mlm_model, tokenizer)``; the MLM class is chosen from the
        ``pretrained_bert`` name (deberta-v3, roberta, otherwise BERT).
    """
    if isinstance(text_classifier, TADTextClassifier):
        holder = text_classifier.model
    else:
        holder = text_classifier
    base_model = holder.bert.base_model

    pretrained_config = AutoConfig.from_pretrained(config.pretrained_bert)
    if "deberta-v3" in config.pretrained_bert:
        mlm_class, backbone_attr = DebertaV2ForMaskedLM, "deberta"
    elif "roberta" in config.pretrained_bert:
        mlm_class, backbone_attr = RobertaForMaskedLM, "roberta"
    else:
        mlm_class, backbone_attr = BertForMaskedLM, "bert"

    mlm = mlm_class(pretrained_config)
    # Swap the freshly built backbone for the classifier's own backbone.
    setattr(mlm, backbone_attr, base_model)
    return mlm, AutoTokenizer.from_pretrained(config.pretrained_bert)
107
+
108
+
109
class TADTextClassifier:
    """Inference wrapper around a TAD model: classification + adversarial detection.

    The classifier is built either from an in-memory ``(model, opt, tokenizer)``
    triple or from a checkpoint location given as a string.
    """

    def __init__(self, model_arg=None, cal_perplexity=False, **kwargs):
        """Load the model, configuration and tokenizer.

        :param model_arg: A checkpoint location (``str``) searched with
            ``find_file`` for ``.state_dict``/``.model``/``.tokenizer``/
            ``.config`` files, or an indexable ``(model, opt, tokenizer)``.
        :param cal_perplexity: When true, a masked-LM is built to report a
            perplexity value per sample.
        :param kwargs: ``auto_device``, ``offline`` and ``eval_batch_size`` are
            read here; the remaining kwargs go to ``Tokenizer4Pretraining``.
        :raises RuntimeError: If loading from a checkpoint location fails.
        """
        self.cal_perplexity = cal_perplexity
        # load from a training
        if not isinstance(model_arg, str):
            print("Load text classifier from training")
            self.model = model_arg[0]
            self.opt = model_arg[1]
            self.tokenizer = model_arg[2]
        else:
            try:
                if "fine-tuned" in model_arg:
                    raise ValueError(
                        "Do not support to directly load a fine-tuned model, please load a .state_dict or .model instead!"
                    )
                print("Load text classifier from", model_arg)
                state_dict_path = find_file(
                    model_arg, key=".state_dict", exclude_key=["__MACOSX"]
                )
                model_path = find_file(
                    model_arg, key=".model", exclude_key=["__MACOSX"]
                )
                tokenizer_path = find_file(
                    model_arg, key=".tokenizer", exclude_key=["__MACOSX"]
                )
                config_path = find_file(
                    model_arg, key=".config", exclude_key=["__MACOSX"]
                )

                print("config: {}".format(config_path))
                print("state_dict: {}".format(state_dict_path))
                print("model: {}".format(model_path))
                print("tokenizer: {}".format(tokenizer_path))

                # SECURITY NOTE: pickle.load / torch.load execute arbitrary
                # code from the file; only load checkpoints you trust.
                with open(config_path, mode="rb") as f:
                    self.opt = pickle.load(f)
                    self.opt.device = get_device(kwargs.pop("auto_device", True))[0]

                if state_dict_path or model_path:
                    if hasattr(BERTTADModelList, self.opt.model.__name__):
                        if state_dict_path:
                            # Read (do not pop) the flag so that the tokenizer
                            # built below also honours offline mode.
                            if kwargs.get("offline", False):
                                self.bert = AutoModel.from_pretrained(
                                    find_cwd_dir(
                                        self.opt.pretrained_bert.split("/")[-1]
                                    )
                                )
                            else:
                                self.bert = AutoModel.from_pretrained(
                                    self.opt.pretrained_bert
                                )
                            self.model = self.opt.model(self.bert, self.opt)
                            self.model.load_state_dict(
                                torch.load(state_dict_path, map_location="cpu")
                            )
                        elif model_path:
                            self.model = torch.load(model_path, map_location="cpu")

                    try:
                        self.tokenizer = Tokenizer4Pretraining(
                            max_seq_len=self.opt.max_seq_len, opt=self.opt, **kwargs
                        )
                    except ValueError:
                        # Fall back to the tokenizer pickled with the checkpoint.
                        if tokenizer_path:
                            with open(tokenizer_path, mode="rb") as f:
                                self.tokenizer = pickle.load(f)
                        else:
                            raise TransformerConnectionError()

            except Exception as e:
                raise RuntimeError(
                    "Exception: {} Fail to load the model from {}! ".format(
                        e, model_arg
                    )
                ) from e

        self.infer_dataloader = None
        self.opt.eval_batch_size = kwargs.pop("eval_batch_size", 128)

        # Self-assignment kept from the original; it reads the option once.
        self.opt.initializer = self.opt.initializer

        if self.cal_perplexity:
            try:
                self.MLM, self.MLM_tokenizer = get_mlm_and_tokenizer(self, self.opt)
            except Exception as e:
                # Best effort: perplexity is reported as "N.A." without an MLM.
                print("Can not build the MLM for perplexity, Exception: {}".format(e))
                self.MLM, self.MLM_tokenizer = None, None

        self.to(self.opt.device)

    def to(self, device=None):
        """Move the model (and the MLM, when one was built) to ``device``."""
        self.opt.device = device
        self.model.to(device)
        # MLM is None when building it failed; hasattr() alone is not enough.
        if getattr(self, "MLM", None) is not None:
            self.MLM.to(device)

    def cpu(self):
        """Move the model (and the MLM, when one was built) to the CPU."""
        self.to("cpu")

    def cuda(self, device="cuda:0"):
        """Move the model (and the MLM, when one was built) to ``device``."""
        self.to(device)

    def _log_write_args(self):
        """Print parameter counts and every configuration value that is not None."""
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            n_params = p.numel()
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        print(
            "n_trainable_params: {0}, n_nontrainable_params: {1}".format(
                n_trainable_params, n_nontrainable_params
            )
        )
        for arg in vars(self.opt):
            if getattr(self.opt, arg) is not None:
                print(">>> {0}: {1}".format(arg, getattr(self.opt, arg)))

    def _build_dataset(self):
        """Create an empty inference dataset for the configured model."""
        if not hasattr(BERTTADModelList, self.opt.model.__name__):
            raise RuntimeError(
                "Unsupported model for inference: {}".format(self.opt.model.__name__)
            )
        return BERTTADDataset(tokenizer=self.tokenizer, opt=self.opt)

    def batch_infer(
        self,
        target_file=None,
        print_result=True,
        save_result=False,
        ignore_error=True,
        defense: str = None,
    ):
        """Run inference over a dataset file.

        :param target_file: Dataset identifier resolved by
            ``detect_infer_dataset(target_file, task="text_defense")``.
        :param print_result: Print each result.
        :param save_result: Save results to
            ``tad_text_classification.result.json`` in the working directory.
        :param ignore_error: Skip samples that cannot be processed.
        :param defense: Optional attack recipe name used to restore samples
            detected as adversarial.
        :return: List of result dicts.
        :raises FileNotFoundError: If no inference dataset is found.
        """
        save_path = os.path.join(os.getcwd(), "tad_text_classification.result.json")

        target_file = detect_infer_dataset(target_file, task="text_defense")
        if not target_file:
            raise FileNotFoundError("Can not find inference datasets!")

        dataset = self._build_dataset()
        dataset.prepare_infer_dataset(target_file, ignore_error=ignore_error)
        self.infer_dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.opt.eval_batch_size,
            pin_memory=True,
            shuffle=False,
        )
        return self._infer(
            save_path=save_path if save_result else None,
            print_result=print_result,
            defense=defense,
        )

    def infer(
        self,
        text: str = None,
        print_result=True,
        ignore_error=True,
        defense: str = None,
    ):
        """Run inference on a single text and return its result dict.

        :raises RuntimeError: If ``text`` is empty or ``None``.
        """
        dataset = self._build_dataset()

        if text:
            dataset.prepare_infer_sample(text, ignore_error=ignore_error)
        else:
            raise RuntimeError("Please specify your datasets path!")
        self.infer_dataloader = DataLoader(
            dataset=dataset, batch_size=self.opt.eval_batch_size, shuffle=False
        )
        return self._infer(print_result=print_result, defense=defense)[0]

    @staticmethod
    def _reference_index(raw_value, index_map):
        """Return ``int(raw_value)`` when it is a key of ``index_map``, else ``""``."""
        index = int(raw_value)
        return index if index in index_map else ""

    @staticmethod
    def _format_tag(name, label, ref, confidence):
        """Format one ``<name:label(ref:.. confidence:..)>`` tag for printing.

        The tag is green when the prediction matches the reference, red when
        it does not, and uncoloured when there is no reference.
        """
        tag = " -> <{}:{}(ref:{} confidence:{})>".format(name, label, ref, confidence)
        if not ref:
            return tag
        return colored(tag, "green" if label == ref else "red")

    def _infer(self, save_path=None, print_result=True, defense=None):
        """Run the model over ``self.infer_dataloader`` and collect results.

        :param save_path: When set, the results are written there.
        :param print_result: Print one line per result.
        :param defense: Optional attack recipe name; samples predicted as
            adversarial (``is_adv_label == "1"``) are re-attacked and re-inferred.
        :return: List of result dicts, one per sample.
        :raises RuntimeError: If the defense step fails.
        """
        correct = {True: "Correct", False: "Wrong"}
        results = []
        # Bound up front so the error message below never hits an unbound name.
        text_raw = ""
        mlm = getattr(self, "MLM", None)

        with torch.no_grad():
            self.model.eval()
            n_correct = 0
            n_labeled = 0

            n_advdet_correct = 0
            n_advdet_labeled = 0
            if len(self.infer_dataloader.dataset) >= 100:
                it = tqdm.tqdm(self.infer_dataloader, postfix="inferring...")
            else:
                it = self.infer_dataloader
            for sample in it:
                inputs = [
                    sample[col].to(self.opt.device) for col in self.opt.inputs_cols
                ]
                outputs = self.model(inputs)
                probs = torch.softmax(outputs["sent_logits"], dim=-1)
                advdet_probs = torch.softmax(outputs["advdet_logits"], dim=-1)
                adv_tr_probs = torch.softmax(outputs["adv_tr_logits"], dim=-1)

                for i, (prob, advdet_prob, adv_tr_prob) in enumerate(
                    zip(probs, advdet_probs, adv_tr_probs)
                ):
                    text_raw = sample["text_raw"][i]

                    pred_label = int(prob.argmax(axis=-1))
                    pred_is_adv_label = int(advdet_prob.argmax(axis=-1))
                    pred_adv_tr_label = int(adv_tr_prob.argmax(axis=-1))
                    ref_label = self._reference_index(
                        sample["label"][i], self.opt.index_to_label
                    )
                    ref_is_adv_label = self._reference_index(
                        sample["is_adv"][i], self.opt.index_to_is_adv
                    )
                    ref_adv_tr_label = self._reference_index(
                        sample["adv_train_label"][i],
                        self.opt.index_to_adv_train_label,
                    )
                    # A reference only counts when it is a real index: "" means
                    # unknown and -100 is the "no label" placeholder.
                    has_ref_label = isinstance(ref_label, int) and ref_label != -100
                    has_ref_is_adv = (
                        isinstance(ref_is_adv_label, int) and ref_is_adv_label != -100
                    )

                    if self.cal_perplexity and mlm is not None:
                        ids = self.MLM_tokenizer(text_raw, return_tensors="pt")
                        ids["labels"] = ids["input_ids"].clone()
                        ids = ids.to(self.opt.device)
                        loss = mlm(**ids)["loss"]
                        perplexity = float(torch.exp(loss / ids["input_ids"].size(1)))
                    else:
                        perplexity = "N.A."

                    result = {
                        "text": text_raw,
                        "label": self.opt.index_to_label[pred_label],
                        "probs": prob.cpu().numpy(),
                        "confidence": float(max(prob)),
                        "ref_label": self.opt.index_to_label[ref_label]
                        if isinstance(ref_label, int)
                        else ref_label,
                        "ref_label_check": correct[pred_label == ref_label]
                        if has_ref_label
                        else "",
                        "is_fixed": False,
                        "is_adv_label": self.opt.index_to_is_adv[pred_is_adv_label],
                        "is_adv_probs": advdet_prob.cpu().numpy(),
                        "is_adv_confidence": float(max(advdet_prob)),
                        "ref_is_adv_label": self.opt.index_to_is_adv[ref_is_adv_label]
                        if isinstance(ref_is_adv_label, int)
                        else ref_is_adv_label,
                        "ref_is_adv_check": correct[
                            pred_is_adv_label == ref_is_adv_label
                        ]
                        if has_ref_is_adv
                        else "",
                        "pred_adv_tr_label": self.opt.index_to_label[pred_adv_tr_label],
                        # Guarded like the other references: "" is not a key.
                        "ref_adv_tr_label": self.opt.index_to_label[ref_adv_tr_label]
                        if isinstance(ref_adv_tr_label, int)
                        else ref_adv_tr_label,
                        "perplexity": perplexity,
                    }
                    if defense:
                        try:
                            # init_attacker returns None on failure; retry then.
                            if getattr(self, "sent_attacker", None) is None:
                                self.sent_attacker = init_attacker(
                                    self, defense.lower()
                                )
                            if result["is_adv_label"] == "1":
                                res = self.sent_attacker.attacker.simple_attack(
                                    text_raw, int(result["label"])
                                )
                                restored_text = res.perturbed_result.attacked_text.text
                                new_infer_res = self.infer(
                                    restored_text,
                                    print_result=False,
                                )
                                result["perturbed_label"] = result["label"]
                                result["label"] = new_infer_res["label"]
                                result["probs"] = new_infer_res["probs"]
                                result["ref_label_check"] = (
                                    correct[int(result["label"]) == ref_label]
                                    if has_ref_label
                                    else ""
                                )
                                result["restored_text"] = restored_text
                                result["is_fixed"] = True
                            else:
                                result["restored_text"] = ""
                                result["is_fixed"] = False

                        except Exception as e:
                            # Nothing is installed here, so report the failure
                            # as it is instead of claiming an installation ran.
                            raise RuntimeError(
                                "Fail to apply the defense '{}': {}. Please make sure "
                                "TextAttack and tensorflow_text are installed.".format(
                                    defense, e
                                )
                            ) from e

                    if has_ref_label:
                        n_labeled += 1

                        if result["label"] == result["ref_label"]:
                            n_correct += 1

                    if has_ref_is_adv:
                        n_advdet_labeled += 1
                        if ref_is_adv_label == pred_is_adv_label:
                            n_advdet_correct += 1

                    results.append(result)

        try:
            if print_result:
                for ex_id, result in enumerate(results):
                    text_printing = result["text"][:]
                    text_info = ""
                    if result["label"] != "-100":
                        text_info += self._format_tag(
                            "CLS",
                            result["label"],
                            result["ref_label"],
                            result["confidence"],
                        )

                    # AdvDet
                    if result["is_adv_label"] != "-100":
                        text_info += self._format_tag(
                            "AdvDet",
                            result["is_adv_label"],
                            result["ref_is_adv_label"],
                            result["is_adv_confidence"],
                        )
                    text_printing += text_info
                    if self.cal_perplexity:
                        text_printing += colored(
                            " --> <perplexity:{}>".format(result["perplexity"]),
                            "yellow",
                        )
                    print("Example {}: {}".format(ex_id, text_printing))
            if save_path:
                with open(save_path, "w", encoding="utf8") as fout:
                    # Results hold numpy arrays, so their str() form is dumped.
                    json.dump(str(results), fout, ensure_ascii=False)
                    print("inference result saved in: {}".format(save_path))
        except Exception as e:
            print("Can not save result: {}, Exception: {}".format(text_raw, e))

        if len(results) > 1:
            print(
                "CLS Acc:{}%".format(100 * n_correct / n_labeled if n_labeled else "")
            )
            print(
                "AdvDet Acc:{}%".format(
                    100 * n_advdet_correct / n_advdet_labeled
                    if n_advdet_labeled
                    else ""
                )
            )

        return results

    def clear_input_samples(self):
        """Drop the samples held by the current inference dataset, if any."""
        # The dataset lives on the dataloader; no ``self.dataset`` is ever set.
        if self.infer_dataloader is not None:
            self.infer_dataloader.dataset.all_data = []
anonymous_demo/functional/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from anonymous_demo.functional.checkpoint.checkpoint_manager import TADCheckpointManager
2
+
3
+ from anonymous_demo.functional.config import TADConfigManager
anonymous_demo/functional/checkpoint/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .checkpoint_manager import TADCheckpointManager
anonymous_demo/functional/checkpoint/checkpoint_manager.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from findfile import find_file
3
+
4
+ from anonymous_demo.core.tad.prediction.tad_classifier import TADTextClassifier
5
+ from anonymous_demo.utils.demo_utils import retry
6
+
7
+
8
class CheckpointManager:
    """Base class for checkpoint managers; defines no behaviour of its own."""
10
+
11
+
12
class TADCheckpointManager(CheckpointManager):
    """Factory for :class:`TADTextClassifier` instances."""

    @staticmethod
    @retry
    def get_tad_text_classifier(checkpoint: str = None, eval_batch_size=128, **kwargs):
        """Build a ``TADTextClassifier`` from ``checkpoint``.

        :param checkpoint: Passed to ``TADTextClassifier`` as its model argument.
        :param eval_batch_size: Batch size used for inference.
        :param kwargs: Forwarded unchanged to ``TADTextClassifier``.
        """
        return TADTextClassifier(
            checkpoint, eval_batch_size=eval_batch_size, **kwargs
        )
anonymous_demo/functional/config/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tad_config_manager import TADConfigManager
anonymous_demo/functional/config/config_manager.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+
3
+ import torch
4
+
5
+ one_shot_messages = set()
6
+
7
+
8
def config_check(args):
    """Validation hook for a parameter dict; currently performs no checks."""
    return None
10
+
11
+
12
class ConfigManager(Namespace):
    """``argparse.Namespace`` backed by a parameter dict with read counters.

    Parameters live in ``self.args``; ``self.args_call_count`` records how
    often each parameter has been read through attribute access.
    """

    def __init__(self, args=None, **kwargs):
        """Create the manager.

        :param args: A parameter dict or a ``Namespace``; empty when omitted.
        :param kwargs: Same parameters as ``Namespace``; stored as ordinary
            attributes.
        """
        if not args:
            args = {}
        super().__init__(**kwargs)

        if isinstance(args, Namespace):
            self.args = vars(args)
            self.args_call_count = {arg: 0 for arg in vars(args)}
        else:
            self.args = args
            self.args_call_count = {arg: 0 for arg in args}

    def __getattribute__(self, arg_name):
        """Return a parameter (counting the read) or a regular attribute."""
        if arg_name in ("args", "args_call_count"):
            return super().__getattribute__(arg_name)
        try:
            value = super().__getattribute__("args")[arg_name]
            call_count = super().__getattribute__("args_call_count")
            # .get(): a parameter added directly to ``args`` has no counter yet.
            call_count[arg_name] = call_count.get(arg_name, 0) + 1
            return value
        except (AttributeError, KeyError, TypeError):
            # Not a managed parameter (or the manager is not initialised yet):
            # fall back to normal attribute lookup.
            return super().__getattribute__(arg_name)

    def __setattr__(self, arg_name, value):
        """Store ``value`` as a parameter, or as a plain attribute on failure."""
        if arg_name in ("args", "args_call_count"):
            super().__setattr__(arg_name, value)
            return

        # ``args`` stays None when the lookup below fails, e.g. while
        # Namespace.__init__ sets kwargs before ``self.args`` exists.
        args = None
        try:
            args = super().__getattribute__("args")
            args[arg_name] = value
            call_count = super().__getattribute__("args_call_count")
            # Writes are not counted; new parameters start at zero reads.
            call_count.setdefault(arg_name, 0)
        except Exception:
            # The value must be stored somewhere, whatever went wrong above.
            super().__setattr__(arg_name, value)

        if args is not None:
            config_check(args)
anonymous_demo/functional/config/tad_config_manager.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from anonymous_demo.functional.config.config_manager import ConfigManager
4
+ from anonymous_demo.core.tad.classic.__bert__.models import TADBERT
5
+
6
# Default parameter sets for the TAD models. Each dict can be changed through
# TADConfigManager.set_tad_config(...) and is turned into a config object by
# the matching TADConfigManager.get_tad_config_*() helper.

# Template: the dict the other parameter sets are merged into.
_tad_config_template = {
    "model": TADBERT,
    "optimizer": "adamw",
    "learning_rate": 0.00002,
    "patience": 99999,
    "pretrained_bert": "microsoft/mdeberta-v3-base",
    "cache_dataset": True,
    "warmup_step": -1,
    "show_metric": False,
    "max_seq_len": 80,
    "dropout": 0,
    "l2reg": 0.000001,
    "num_epoch": 10,
    "batch_size": 16,
    "initializer": "xavier_uniform_",
    "seed": 52,
    "polarities_dim": 3,
    "log_step": 10,
    "evaluate_begin": 0,
    # original note: split train and test datasets into 5 folds and repeat 3 training
    "cross_validate_fold": -1,
    "use_amp": False,
}

# Base parameter set.
_tad_config_base = {
    "model": TADBERT,
    "optimizer": "adamw",
    "learning_rate": 0.00002,
    "pretrained_bert": "microsoft/deberta-v3-base",
    "cache_dataset": True,
    "warmup_step": -1,
    "show_metric": False,
    "max_seq_len": 80,
    "patience": 99999,
    "dropout": 0,
    "l2reg": 0.000001,
    "num_epoch": 10,
    "batch_size": 16,
    "initializer": "xavier_uniform_",
    "seed": 52,
    "polarities_dim": 3,
    "log_step": 10,
    "evaluate_begin": 0,
    # original note: split train and test datasets into 5 folds and repeat 3 training
    "cross_validate_fold": -1
}

# English parameter set.
_tad_config_english = {
    "model": TADBERT,
    "optimizer": "adamw",
    "learning_rate": 0.00002,
    "patience": 99999,
    "pretrained_bert": "microsoft/deberta-v3-base",
    "cache_dataset": True,
    "warmup_step": -1,
    "show_metric": False,
    "max_seq_len": 80,
    "dropout": 0,
    "l2reg": 0.000001,
    "num_epoch": 10,
    "batch_size": 16,
    "initializer": "xavier_uniform_",
    "seed": 52,
    "polarities_dim": 3,
    "log_step": 10,
    "evaluate_begin": 0,
    # original note: split train and test datasets into 5 folds and repeat 3 training
    "cross_validate_fold": -1
}

# Multilingual parameter set.
_tad_config_multilingual = {
    "model": TADBERT,
    "optimizer": "adamw",
    "learning_rate": 0.00002,
    "patience": 99999,
    "pretrained_bert": "microsoft/mdeberta-v3-base",
    "cache_dataset": True,
    "warmup_step": -1,
    "show_metric": False,
    "max_seq_len": 80,
    "dropout": 0,
    "l2reg": 0.000001,
    "num_epoch": 10,
    "batch_size": 16,
    "initializer": "xavier_uniform_",
    "seed": 52,
    "polarities_dim": 3,
    "log_step": 10,
    "evaluate_begin": 0,
    # original note: split train and test datasets into 5 folds and repeat 3 training
    "cross_validate_fold": -1
}

# Chinese parameter set.
_tad_config_chinese = {
    "model": TADBERT,
    "optimizer": "adamw",
    "learning_rate": 0.00002,
    "patience": 99999,
    "cache_dataset": True,
    "warmup_step": -1,
    "show_metric": False,
    "pretrained_bert": "bert-base-chinese",
    "max_seq_len": 80,
    "dropout": 0,
    "l2reg": 0.000001,
    "num_epoch": 10,
    "batch_size": 16,
    "initializer": "xavier_uniform_",
    "seed": 52,
    "polarities_dim": 3,
    "log_step": 10,
    "evaluate_begin": 0,
    # original note: split train and test datasets into 5 folds and repeat 3 training
    "cross_validate_fold": -1
}
121
+
122
+
123
class TADConfigManager(ConfigManager):
    def __init__(self, args, **kwargs):
        """
        Config object for the TAD models.

        The default parameter sets of this module define the following keys:
        model, optimizer, learning_rate, patience, pretrained_bert, cache_dataset,
        warmup_step, show_metric, max_seq_len, dropout, l2reg, num_epoch, batch_size,
        initializer, seed, polarities_dim, log_step, evaluate_begin, cross_validate_fold
        (the template additionally defines use_amp).

        :param args: A parameter dict, forwarded to ConfigManager
        :param kwargs: Extra keyword parameters, forwarded to ConfigManager
        """
        super().__init__(args, **kwargs)

    @staticmethod
    def _named_config(configType):
        """
        Return the module-level default dict registered under a config type.

        :param configType: One of template, base, english, chinese, multilingual, glove
        :raises ValueError: If the type is unknown, or known but not defined in this module
        """
        known_types = ("template", "base", "english", "chinese", "multilingual", "glove")
        if configType not in known_types:
            raise ValueError(
                "Wrong value of config type supplied, please use one from following type: template, base, english, chinese, multilingual, glove"
            )
        # Looked up by name so that a type without a dict in this module
        # (currently "glove") yields a clear error instead of a NameError.
        try:
            return globals()["_tad_config_{}".format(configType)]
        except KeyError:
            raise ValueError(
                "The '{}' config is not defined in this module".format(configType)
            ) from None

    @staticmethod
    def _build_config(configType):
        """
        Build a fresh config: a copy of the template overridden by the requested parameter set.

        The module-level dicts are never modified here, so one request cannot
        leak its values into a later request for another type.
        """
        merged = copy.deepcopy(_tad_config_template)
        if configType != "template":
            merged.update(copy.deepcopy(TADConfigManager._named_config(configType)))
        return TADConfigManager(merged)

    @staticmethod
    def set_tad_config(configType: str, newitem: dict):
        """
        Update one of the module-level default parameter sets in place.

        :param configType: One of template, base, english, chinese, multilingual, glove
        :param newitem: Dict of parameters to add or override
        :raises TypeError: If newitem is not a dict
        :raises ValueError: If configType is unknown or not defined in this module
        """
        if not isinstance(newitem, dict):
            raise TypeError(
                "Wrong type of new config item supplied, please use dict e.g.{'NewConfig': NewValue}"
            )
        TADConfigManager._named_config(configType).update(newitem)

    @staticmethod
    def set_tad_config_template(newitem):
        """Update the template parameter set with newitem."""
        TADConfigManager.set_tad_config("template", newitem)

    @staticmethod
    def set_tad_config_base(newitem):
        """Update the base parameter set with newitem."""
        TADConfigManager.set_tad_config("base", newitem)

    @staticmethod
    def set_tad_config_english(newitem):
        """Update the english parameter set with newitem."""
        TADConfigManager.set_tad_config("english", newitem)

    @staticmethod
    def set_tad_config_chinese(newitem):
        """Update the chinese parameter set with newitem."""
        TADConfigManager.set_tad_config("chinese", newitem)

    @staticmethod
    def set_tad_config_multilingual(newitem):
        """Update the multilingual parameter set with newitem."""
        TADConfigManager.set_tad_config("multilingual", newitem)

    @staticmethod
    def set_tad_config_glove(newitem):
        """Update the glove parameter set with newitem (raises ValueError while no glove set is defined)."""
        TADConfigManager.set_tad_config("glove", newitem)

    @staticmethod
    def get_tad_config_template() -> ConfigManager:
        """Return a new config built from a copy of the template parameter set."""
        return TADConfigManager._build_config("template")

    @staticmethod
    def get_tad_config_base() -> ConfigManager:
        """Return a new config: template overridden by the base parameter set."""
        return TADConfigManager._build_config("base")

    @staticmethod
    def get_tad_config_english() -> ConfigManager:
        """Return a new config: template overridden by the english parameter set."""
        return TADConfigManager._build_config("english")

    @staticmethod
    def get_tad_config_chinese() -> ConfigManager:
        """Return a new config: template overridden by the chinese parameter set."""
        return TADConfigManager._build_config("chinese")

    @staticmethod
    def get_tad_config_multilingual() -> ConfigManager:
        """Return a new config: template overridden by the multilingual parameter set."""
        return TADConfigManager._build_config("multilingual")

    @staticmethod
    def get_tad_config_glove() -> ConfigManager:
        """Return a new config: template overridden by the glove parameter set (raises ValueError while none is defined)."""
        return TADConfigManager._build_config("glove")
anonymous_demo/functional/dataset/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from anonymous_demo.functional.dataset.dataset_manager import detect_infer_dataset
anonymous_demo/functional/dataset/dataset_manager.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from findfile import find_files, find_dir
3
+
4
# Path fragments that mark a file or directory as "not a dataset" when searching.
filter_key_words = [
    ".py",
    ".md",
    "readme",
    "log",
    "result",
    "zip",
    ".state_dict",
    ".model",
    ".png",
    "acc_",
    "f1_",
    ".backup",
    ".bak",
]


def detect_infer_dataset(dataset_path, task="apc"):
    """
    Collect the inference dataset files that belong to dataset_path.

    :param dataset_path: A path to a single file, a directory / dataset name, or an iterable of such names
    :param task: Task keyword used when searching for dataset directories and files
    :return: List of detected dataset file paths
    """
    dataset_file = []
    if isinstance(dataset_path, str):
        if os.path.isfile(dataset_path):
            dataset_file.append(dataset_path)
            return dataset_file
        # A single directory / dataset name must be handled as ONE entry;
        # iterating the string itself would walk it character by character.
        dataset_path = [dataset_path]

    for d in dataset_path:
        if not os.path.exists(d):
            # Not a local path: search below the working directory for a
            # matching dataset directory first.
            search_path = find_dir(
                os.getcwd(),
                [d, task, "dataset"],
                exclude_key=filter_key_words,
                disable_alert=False,
            )
            dataset_file += find_files(
                search_path,
                [".inference", d],
                exclude_key=["train."] + filter_key_words,
            )
        else:
            dataset_file += find_files(
                d, [".inference", task], exclude_key=["train."] + filter_key_words
            )

    return dataset_file
anonymous_demo/network/__init__.py ADDED
File without changes
anonymous_demo/network/lcf_pooler.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
class LCF_Pooler(nn.Module):
    """Pools one hidden vector per sample, chosen via the LCF vector, then applies dense + tanh."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, lcf_vec):
        """
        For every sample pick the hidden state at the middle of the positions
        whose LCF row consists of ones, then project it.

        :param hidden_states: Tensor indexed as [sample][position][feature]
        :param lcf_vec: Tensor indexed as [sample][position][feature]
        :return: Tensor with one pooled and projected vector per sample
        """
        target_device = hidden_states.device
        focus_flags = lcf_vec.detach().cpu().numpy()

        picked = numpy.zeros(
            (hidden_states.shape[0], hidden_states.shape[2]), dtype=numpy.float32
        )
        states = hidden_states.detach().cpu().numpy()
        for sample_idx, sample_flags in enumerate(focus_flags):
            # Positions whose (row - 1.0) sums to exactly zero.
            focus_positions = []
            for position in range(len(sample_flags)):
                if sum(sample_flags[position] - 1.0) == 0:
                    focus_positions.append(position)
            middle = focus_positions[len(focus_positions) // 2]
            picked[sample_idx] = states[sample_idx][middle]

        projected = self.dense(torch.Tensor(picked).to(target_device))
        return self.activation(projected)
anonymous_demo/network/lsa.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from anonymous_demo.network.sa_encoder import Encoder
3
+ from torch import nn
4
+
5
+
6
class LSA(nn.Module):
    """Combines local-context features with left / right window features into one output."""

    def __init__(self, bert, opt):
        """
        :param bert: Model object whose ``config`` attribute is passed to the encoders
        :param opt: Option object; this class reads ``embed_dim``, ``eta`` and ``window`` from it
        """
        super(LSA, self).__init__()
        self.opt = opt

        # One encoder per feature stream: local context, left window, right window.
        self.encoder = Encoder(bert.config, opt)
        self.encoder_left = Encoder(bert.config, opt)
        self.encoder_right = Encoder(bert.config, opt)
        # Projections back to embed_dim for three or two concatenated streams.
        self.linear_window_3h = nn.Linear(opt.embed_dim * 3, opt.embed_dim)
        self.linear_window_2h = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
        # Learnable weights of the left (eta1) and right (eta2) streams, both initialised from opt.eta.
        self.eta1 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))
        self.eta2 = nn.Parameter(torch.tensor(self.opt.eta, dtype=torch.float))

    def forward(
        self,
        global_context_features,
        spc_mask_vec,
        lcf_matrix,
        left_lcf_matrix,
        right_lcf_matrix,
    ):
        """
        Encode the three masked feature streams and merge them according to ``opt.window``.

        :param global_context_features: Features multiplied element-wise with the masks below
        :param spc_mask_vec: Mask applied before the left / right streams are built
        :param lcf_matrix: Mask for the local-context stream
        :param left_lcf_matrix: Mask for the left stream
        :param right_lcf_matrix: Mask for the right stream
        :return: Output of the linear projection over the concatenated streams
        :raises KeyError: If ``opt.window`` is not one of "lr", "rl", "l", "r"
        """
        masked_global_context_features = torch.mul(
            spc_mask_vec, global_context_features
        )

        # # --------------------------------------------------- #
        lcf_features = torch.mul(global_context_features, lcf_matrix)
        lcf_features = self.encoder(lcf_features)
        # # --------------------------------------------------- #
        left_lcf_features = torch.mul(masked_global_context_features, left_lcf_matrix)
        left_lcf_features = self.encoder_left(left_lcf_features)
        # # --------------------------------------------------- #
        right_lcf_features = torch.mul(masked_global_context_features, right_lcf_matrix)
        right_lcf_features = self.encoder_right(right_lcf_features)
        # # --------------------------------------------------- #
        if "lr" == self.opt.window or "rl" == self.opt.window:
            # A weight that has dropped to <= 0 is re-initialised uniformly
            # (skipped when opt.eta == -1).
            if self.eta1 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta1)
                print("reset eta1 to: {}".format(self.eta1.item()))
            if self.eta2 <= 0 and self.opt.eta != -1:
                torch.nn.init.uniform_(self.eta2)
                print("reset eta2 to: {}".format(self.eta2.item()))
            if self.opt.eta >= 0:
                # Weighted concatenation of the three streams.
                cat_features = torch.cat(
                    (
                        lcf_features,
                        self.eta1 * left_lcf_features,
                        self.eta2 * right_lcf_features,
                    ),
                    -1,
                )
            else:
                # Negative eta: concatenate the streams unweighted.
                cat_features = torch.cat(
                    (lcf_features, left_lcf_features, right_lcf_features), -1
                )
            sent_out = self.linear_window_3h(cat_features)
        elif "l" == self.opt.window:
            sent_out = self.linear_window_2h(
                torch.cat((lcf_features, self.eta1 * left_lcf_features), -1)
            )
        elif "r" == self.opt.window:
            sent_out = self.linear_window_2h(
                torch.cat((lcf_features, self.eta2 * right_lcf_features), -1)
            )
        else:
            raise KeyError("Invalid parameter:", self.opt.window)

        return sent_out
anonymous_demo/network/sa_encoder.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional relative position scores."""

    def __init__(self, config):
        """
        :param config: Configuration object; reads ``hidden_size``, ``num_attention_heads`` and
            ``is_decoder``, and optionally ``attention_probs_dropout_prob``,
            ``position_embedding_type``, ``max_position_embeddings`` and ``embedding_size``
        :raises ValueError: If hidden_size is not a multiple of num_attention_heads and the
            config has no ``embedding_size`` attribute
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Dropout falls back to 0 when the config does not define a probability.
        self.dropout = nn.Dropout(
            config.attention_probs_dropout_prob
            if hasattr(config, "attention_probs_dropout_prob")
            else 0
        )
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        # The distance embedding is only created for the relative position variants.
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and swap dimensions 1 and 2."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """
        Compute the attention output for hidden_states.

        :param hidden_states: Input the queries are computed from
        :param attention_mask: Optional mask that is added to the attention scores
        :param head_mask: Optional factor multiplied onto the attention probabilities
        :param encoder_hidden_states: If given, keys and values are computed from it (cross-attention)
        :param encoder_attention_mask: Mask used instead of attention_mask for cross-attention
        :param past_key_value: Optional cached (key, value) pair to reuse or extend
        :param output_attentions: Whether to also return the attention probabilities
        :return: Tuple ``(context_layer,)``, extended by the attention probabilities when
            output_attentions is set and by the (key, value) pair when ``is_decoder`` is set
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Append the newly projected keys / values to the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift distances by max_position_embeddings - 1 so they are valid embedding indices.
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Undo the head split: swap dimensions back and merge (heads, head_size) into all_head_size.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
170
+
171
+
172
class Encoder(nn.Module):
    """Stack of SelfAttention layers, each followed by a tanh."""

    def __init__(self, config, opt, layer_num=1):
        """
        :param config: Configuration object handed to every SelfAttention layer
        :param opt: Option object handed to every SelfAttention layer
        :param layer_num: Number of stacked layers
        """
        super(Encoder, self).__init__()
        self.opt = opt
        self.config = config
        layers = []
        for _ in range(layer_num):
            layers.append(SelfAttention(config, opt))
        self.encoder = nn.ModuleList(layers)
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        """Run x through every layer, applying tanh to the first element of each layer's output."""
        hidden = x
        for layer in self.encoder:
            attended = layer(hidden)[0]
            hidden = self.tanh(attended)
        return hidden
186
+
187
+
188
class SelfAttention(nn.Module):
    """Wraps BertSelfAttention and always calls it with an all-zero attention mask."""

    def __init__(self, config, opt):
        """
        :param config: Configuration object handed to BertSelfAttention
        :param opt: Option object; ``max_seq_len`` is read from it in forward
        """
        super(SelfAttention, self).__init__()
        self.opt = opt
        self.config = config
        self.SA = BertSelfAttention(config)

    def forward(self, inputs):
        """Apply the wrapped attention to inputs and return its output tuple."""
        # All-zero float mask with one entry per position up to opt.max_seq_len.
        mask_shape = (inputs.size(0), 1, 1, self.opt.max_seq_len)
        zero_mask = torch.zeros(mask_shape, dtype=torch.float, device=inputs.device)
        return self.SA(inputs, zero_mask)
anonymous_demo/utils/__init__.py ADDED
File without changes
anonymous_demo/utils/demo_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pickle
4
+ import signal
5
+ import threading
6
+ import time
7
+ import zipfile
8
+
9
+ import gdown
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ import tqdm
14
+ from autocuda import auto_cuda, auto_cuda_name
15
+ from findfile import find_files, find_cwd_file, find_file
16
+ from termcolor import colored
17
+ from functools import wraps
18
+
19
+ from update_checker import parse_version
20
+
21
+ from anonymous_demo import __version__
22
+
23
+
24
def save_args(config, save_path):
    """
    Write every parameter that has been read at least once to a text file.

    :param config: Object exposing ``args`` (parameter dict) and ``args_call_count`` (read counters)
    :param save_path: Path of the file to write; one ``name: value`` line per parameter
    """
    # The context manager closes the file even if formatting or writing fails.
    with open(os.path.join(save_path), mode="w", encoding="utf8") as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write("{}: {}\n".format(arg, config.args[arg]))
30
+
31
+
32
def print_args(config, logger=None, mode=0):
    """
    Report every parameter, sorted by name, together with its read counter.

    :param config: Object exposing ``args`` and ``args_call_count``
    :param logger: Optional logger; when given, lines go to ``logger.info`` instead of stdout
    :param mode: Unused
    """
    for name in sorted(config.args.keys()):
        line = "{0}:{1}\t-->\tCalling Count:{2}".format(
            name, config.args[name], config.args_call_count[name]
        )
        if logger:
            logger.info(line)
        else:
            print(line)
47
+
48
+
49
+ def check_and_fix_labels(label_set: set, label_name, all_data, opt):
50
+ if "-100" in label_set:
51
+ label_to_index = {
52
+ origin_label: int(idx) - 1 if origin_label != "-100" else -100
53
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
54
+ }
55
+ index_to_label = {
56
+ int(idx) - 1 if origin_label != "-100" else -100: origin_label
57
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
58
+ }
59
+ else:
60
+ label_to_index = {
61
+ origin_label: int(idx)
62
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
63
+ }
64
+ index_to_label = {
65
+ int(idx): origin_label
66
+ for origin_label, idx in zip(sorted(label_set), range(len(label_set)))
67
+ }
68
+ if "index_to_label" not in opt.args:
69
+ opt.index_to_label = index_to_label
70
+ opt.label_to_index = label_to_index
71
+
72
+ if opt.index_to_label != index_to_label:
73
+ opt.index_to_label.update(index_to_label)
74
+ opt.label_to_index.update(label_to_index)
75
+ num_label = {l: 0 for l in label_set}
76
+ num_label["Sum"] = len(all_data)
77
+ for item in all_data:
78
+ try:
79
+ num_label[item[label_name]] += 1
80
+ item[label_name] = label_to_index[item[label_name]]
81
+ except Exception as e:
82
+ # print(e)
83
+ num_label[item.polarity] += 1
84
+ item.polarity = label_to_index[item.polarity]
85
+ print("Dataset Label Details: {}".format(num_label))
86
+
87
+
88
def check_and_fix_IOB_labels(label_map, opt):
    """
    Store the inverse of label_map (integer index -> original label) on opt.

    :param label_map: Mapping from original label to an index convertible with ``int``
    :param opt: Option object receiving ``index_to_IOB_label``
    """
    inverse = {}
    for origin_label in label_map:
        inverse[int(label_map[origin_label])] = origin_label
    opt.index_to_IOB_label = inverse
93
+
94
+
95
def get_device(auto_device):
    """
    Resolve the device to use.

    :param auto_device: "allcuda" (mapped to "cuda"), any other device string (used as is),
        a bool (True: ``auto_cuda()``, False: "cpu"), or anything else (``auto_cuda()``)
    :return: Tuple of the device and the name reported by ``auto_cuda_name()``
    """
    if isinstance(auto_device, str):
        device = "cuda" if auto_device == "allcuda" else auto_device
    elif isinstance(auto_device, bool):
        if auto_device:
            device = auto_cuda()
        else:
            device = "cpu"
    else:
        device = auto_cuda()

    # Fall back to the CPU when torch rejects the device.
    try:
        torch.device(device)
    except RuntimeError as e:
        message = "Device assignment error: {}, redirect to CPU".format(e)
        print(colored(message, "red"))
        device = "cpu"

    return device, auto_cuda_name()
113
+
114
+
115
def _load_word_vec(path, word2idx=None, embed_dim=300):
    """
    Read word vectors from a text embedding file.

    Each line holds a word followed by its vector; the last embed_dim
    whitespace-separated tokens are the vector, everything before is the word.

    :param path: Path of the embedding file
    :param word2idx: Vocabulary; only words contained in it are loaded. ``None`` loads every word
    :param embed_dim: Number of vector components per line
    :return: Dict mapping word to a float32 numpy vector
    """
    word_vec = {}
    # The context manager closes the file once all lines have been read.
    with open(path, "r", encoding="utf-8", newline="\n", errors="ignore") as fin:
        for line in tqdm.tqdm(fin.readlines(), postfix="Loading embedding file..."):
            tokens = line.rstrip().split()
            word, vec = " ".join(tokens[:-embed_dim]), tokens[-embed_dim:]
            # The default word2idx=None means "no vocabulary filter".
            if word2idx is None or word in word2idx:
                word_vec[word] = np.asarray(vec, dtype="float32")
    return word_vec
124
+
125
+
126
def build_embedding_matrix(word2idx, embed_dim, dat_fname, opt):
    """
    Return the embedding matrix for word2idx, loading it from the cache when available.

    The cache file lives at ``run/<opt.dataset_name>/<dat_fname>``. On a cache
    miss the matrix is built from the embedding file and then cached.

    :param word2idx: Mapping from word to row index
    :param embed_dim: Number of components per word vector
    :param dat_fname: File name of the cache file
    :param opt: Option object; ``dataset_name`` is read from it
    :return: Matrix with ``len(word2idx) + 2`` rows and embed_dim columns
    """
    os.makedirs("run", exist_ok=True)
    embed_matrix_path = "run/{}".format(os.path.join(opt.dataset_name, dat_fname))
    if os.path.exists(embed_matrix_path):
        print(
            colored(
                "Loading cached embedding_matrix from {} (Please remove all cached files if there is any problem!)".format(
                    embed_matrix_path
                ),
                "green",
            )
        )
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only load cache files that this program wrote itself.
        with open(embed_matrix_path, "rb") as cache_file:
            embedding_matrix = pickle.load(cache_file)
    else:
        # NOTE(review): prepare_glove840_embedding is neither defined nor
        # imported in the visible part of this module - confirm where it comes from.
        glove_path = prepare_glove840_embedding(embed_matrix_path)
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))

        word_vec = _load_word_vec(glove_path, word2idx=word2idx, embed_dim=embed_dim)

        for word, i in tqdm.tqdm(
            word2idx.items(),
            postfix=colored("Building embedding_matrix {}".format(dat_fname), "yellow"),
        ):
            vec = word_vec.get(word)
            # Words without a vector keep their all-zero row.
            if vec is not None:
                embedding_matrix[i] = vec

        # The cache lives in a per-dataset subdirectory that may not exist yet.
        cache_dir = os.path.dirname(embed_matrix_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        with open(embed_matrix_path, "wb") as cache_file:
            pickle.dump(embedding_matrix, cache_file)
    return embedding_matrix
155
+
156
+
157
def pad_and_truncate(
    sequence, maxlen, dtype="int64", padding="post", truncating="post", value=0
):
    """
    Return sequence as an array of exactly maxlen entries.

    :param sequence: Sequence to fit
    :param maxlen: Length of the result
    :param dtype: dtype of the result
    :param padding: "post" writes the values at the start (padding at the end), anything else at the end
    :param truncating: "pre" keeps the last maxlen values, anything else the first maxlen values
    :param value: Padding value
    :return: numpy array of length maxlen
    """
    x = (np.ones(maxlen) * value).astype(dtype)
    if truncating == "pre":
        # sequence[-0:] would keep the WHOLE sequence, so maxlen == 0 needs its own branch.
        trunc = sequence[-maxlen:] if maxlen else sequence[:0]
    else:
        trunc = sequence[:maxlen]
    trunc = np.asarray(trunc, dtype=dtype)
    if len(trunc) == 0:
        # Nothing to copy; x[-0:] = trunc would try to overwrite the whole array.
        return x
    if padding == "post":
        x[: len(trunc)] = trunc
    else:
        x[-len(trunc) :] = trunc
    return x
171
+
172
+
173
class TransformerConnectionError(ValueError):
    """Raised for connection problems; the ``retry`` decorator treats it as retryable."""

    def __init__(self, *args):
        # Forward the optional message so the exception can carry details.
        super().__init__(*args)
176
+
177
+
178
def retry(f):
    """
    Decorator that re-runs f when it fails with a connection-related error.

    f is attempted up to five times with a 60 second pause between attempts.
    If the last attempt fails as well, its exception is re-raised so the
    caller sees the real error instead of a silent ``None``.

    :param f: Function to wrap
    :return: Wrapped function with the same signature
    """

    @wraps(f)
    def decorated(*args, **kwargs):
        attempts_left = 5
        while True:
            try:
                return f(*args, **kwargs)
            except (
                TransformerConnectionError,
                requests.exceptions.RequestException,
                requests.exceptions.ConnectionError,
                requests.exceptions.HTTPError,
                requests.exceptions.ConnectTimeout,
                requests.exceptions.ProxyError,
                requests.exceptions.SSLError,
                requests.exceptions.BaseHTTPError,
            ) as e:
                attempts_left -= 1
                if attempts_left <= 0:
                    # Out of attempts: do not wait again, surface the error.
                    raise
                print(colored("Training Exception: {}, will retry later".format(e)))
                time.sleep(60)

    return decorated
200
+
201
+
202
def save_json(dic, save_path):
    """
    Serialize dic as JSON (non-ASCII characters kept as is) and write it to save_path.

    :param dic: Object to serialize, or a string holding a Python expression for it
    :param save_path: Path of the file to write
    """
    # NOTE(review): eval executes arbitrary code - never pass untrusted strings here.
    payload = eval(dic) if isinstance(dic, str) else dic
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(payload, ensure_ascii=False))
209
+
210
+
211
def load_json(save_path):
    """
    Load a JSON document from save_path.

    The whole file is parsed, so multi-line (pretty-printed) JSON works. If
    that fails, only the first line is parsed, which keeps files readable that
    hold a JSON line followed by other content.

    :param save_path: Path of the file to read
    :return: The decoded JSON value
    :raises json.JSONDecodeError: If neither the file nor its first line is valid JSON
    """
    with open(save_path, "r", encoding="utf-8") as f:
        content = f.read()
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        first_line = content.split("\n", 1)[0].strip()
        return json.loads(first_line)
217
+
218
+
219
def init_optimizer(optimizer):
    """
    Resolve an optimizer given by name or by class.

    :param optimizer: One of the names below, or an object whose ``__name__`` is an attribute of torch.optim
    :return: The optimizer class
    :raises KeyError: If the optimizer is not supported
    """
    optimizers = {
        "adadelta": torch.optim.Adadelta,  # default lr=1.0
        "adagrad": torch.optim.Adagrad,  # default lr=0.01
        "adam": torch.optim.Adam,  # default lr=0.001
        "adamax": torch.optim.Adamax,  # default lr=0.002
        "asgd": torch.optim.ASGD,  # default lr=0.01
        "rmsprop": torch.optim.RMSprop,  # default lr=0.01
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
    }
    if isinstance(optimizer, str) and optimizer in optimizers:
        return optimizers[optimizer]

    # Unknown strings have no __name__; getattr keeps them on the KeyError
    # path below instead of failing with an AttributeError.
    name = getattr(optimizer, "__name__", None)
    if isinstance(name, str) and hasattr(torch.optim, name):
        return optimizer

    raise KeyError(
        "Unsupported optimizer: {}. Please use string or the optimizer objects in torch.optim as your optimizer".format(
            optimizer
        )
    )
anonymous_demo/utils/logger.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+ import time
5
+
6
+ import termcolor
7
+
8
# Timestamp taken once at import time; it is part of every log directory name.
today = time.strftime("%Y%m%d %H%M%S", time.localtime(time.time()))


def get_logger(log_path, log_name="", log_type="training_log"):
    """
    Return a logger that writes to a log file and to stdout.

    The log file is ``<log_path>/logs/<log_name>_<today>/<log_type>.log``;
    the current directory is used when log_path is empty or None.

    :param log_path: Directory under which the ``logs`` directory is created
    :param log_name: Name of the logger, also used as prefix of the log directory
    :param log_type: Base name of the log file
    :return: The configured logger
    """
    # Use the supplied directory when there is one, the current directory otherwise.
    if log_path:
        log_dir = os.path.join(log_path, "logs")
    else:
        log_dir = os.path.join(".", "logs")

    full_path = os.path.join(log_dir, log_name + "_" + today)
    os.makedirs(full_path, exist_ok=True)
    log_path = os.path.join(full_path, "{}.log".format(log_type))
    logger = logging.getLogger(log_name)
    # Handlers are attached only once per logger name to avoid duplicated lines.
    if not logger.handlers:
        formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")

        file_handler = logging.FileHandler(log_path, encoding="utf8")
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        console_handler.setLevel(logging.INFO)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    logger.setLevel(logging.INFO)

    return logger
checkpoints.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f77ae4a45785183900ee874cb318a16b0e2f173b31749a2555215aca93672f26
3
+ size 2456834455
text_defense/201.SST2/stsa.binary.dev.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/201.SST2/stsa.binary.test.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/201.SST2/stsa.binary.train.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/202.IMDB10K/imdb10k.test.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/202.IMDB10K/imdb10k.train.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/202.IMDB10K/imdb10k.valid.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/204.AGNews10K/AGNews10K.test.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/204.AGNews10K/AGNews10K.train.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/204.AGNews10K/AGNews10K.valid.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/206.Amazon_Review_Polarity10K/amazon.test.dat ADDED
The diff for this file is too large to render. See raw diff
 
text_defense/206.Amazon_Review_Polarity10K/amazon.train.dat ADDED
The diff for this file is too large to render. See raw diff
 
textattack/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Welcome to the API references for TextAttack!
2
+
3
+ What is TextAttack?
4
+
5
+ `TextAttack <https://github.com/QData/TextAttack>`__ is a Python framework for adversarial attacks, adversarial training, and data augmentation in NLP.
6
+
7
+ TextAttack makes experimenting with the robustness of NLP models seamless, fast, and easy. It's also useful for NLP model training, adversarial training, and data augmentation.
8
+
9
+ TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own.
10
+ """
11
+ from .attack_args import AttackArgs, CommandLineAttackArgs
12
+ from .augment_args import AugmenterArgs
13
+ from .dataset_args import DatasetArgs
14
+ from .model_args import ModelArgs
15
+ from .training_args import TrainingArgs, CommandLineTrainingArgs
16
+ from .attack import Attack
17
+ from .attacker import Attacker
18
+ from .trainer import Trainer
19
+ from .metrics import Metric
20
+
21
+ from . import (
22
+ attack_recipes,
23
+ attack_results,
24
+ augmentation,
25
+ commands,
26
+ constraints,
27
+ datasets,
28
+ goal_function_results,
29
+ goal_functions,
30
+ loggers,
31
+ metrics,
32
+ models,
33
+ search_methods,
34
+ shared,
35
+ transformations,
36
+ )
37
+
38
+
39
+ name = "textattack"
textattack/__main__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
#!/usr/bin/env python

# Entry point used when this module is executed as a script: it hands control
# to the TextAttack command-line interface.
if __name__ == "__main__":
    # The package import happens only on the script path, inside the guard.
    import textattack

    textattack.commands.textattack_cli.main()
textattack/attack.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Attack Class
3
+ ============
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ from typing import List, Union
8
+
9
+ import lru
10
+ import torch
11
+
12
+ import textattack
13
+ from textattack.attack_results import (
14
+ FailedAttackResult,
15
+ MaximizedAttackResult,
16
+ SkippedAttackResult,
17
+ SuccessfulAttackResult,
18
+ )
19
+ from textattack.constraints import Constraint, PreTransformationConstraint
20
+ from textattack.goal_function_results import GoalFunctionResultStatus
21
+ from textattack.goal_functions import GoalFunction
22
+ from textattack.models.wrappers import ModelWrapper
23
+ from textattack.search_methods import SearchMethod
24
+ from textattack.shared import AttackedText, utils
25
+ from textattack.transformations import CompositeTransformation, Transformation
26
+
27
+
28
class Attack:
    """An attack generates adversarial examples on text.

    An attack is comprised of a goal function, constraints, transformation, and a search method. Use :meth:`attack` method to attack one sample at a time.

    Args:
        goal_function (:class:`~textattack.goal_functions.GoalFunction`):
            A function for determining how well a perturbation is doing at achieving the attack's goal.
        constraints (list of :class:`~textattack.constraints.Constraint` or :class:`~textattack.constraints.PreTransformationConstraint`):
            A list of constraints to add to the attack, defining which perturbations are valid.
        transformation (:class:`~textattack.transformations.Transformation`):
            The transformation applied at each step of the attack.
        search_method (:class:`~textattack.search_methods.SearchMethod`):
            The method for exploring the search space of possible perturbations
        transformation_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the transformations cache
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**15`):
            The number of items to keep in the constraints cache

    Example::

        >>> import textattack
        >>> import transformers

        >>> # Load model, tokenizer, and model_wrapper
        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # Construct our four components for `Attack`
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.constraints.semantics import WordEmbeddingDistance
        >>> from textattack.transformations import WordSwapEmbedding
        >>> from textattack.search_methods import GreedyWordSwapWIR

        >>> goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
        >>> constraints = [
        ...     RepeatModification(),
        ...     StopwordModification(),
        ...     WordEmbeddingDistance(min_cos_sim=0.9)
        ... ]
        >>> transformation = WordSwapEmbedding(max_candidates=50)
        >>> search_method = GreedyWordSwapWIR(wir_method="delete")

        >>> # Construct the actual attack
        >>> attack = Attack(goal_function, constraints, transformation, search_method)

        >>> input_text = "I really enjoyed the new movie that came out last month."
        >>> label = 1 #Positive
        >>> attack_result = attack.attack(input_text, label)
    """

    def __init__(
        self,
        goal_function: GoalFunction,
        constraints: List[Union[Constraint, PreTransformationConstraint]],
        transformation: Transformation,
        search_method: SearchMethod,
        transformation_cache_size=2**15,
        constraint_cache_size=2**15,
    ):
        """Initialize an attack object.

        Validates the four components, splits ``constraints`` into
        pre-transformation and regular constraints, sets up the two LRU
        caches, and attaches the helper callables the search method uses.

        Attacks can be run multiple times.

        Raises:
            AssertionError: If a component has the wrong type.
            ValueError: If the search method reports that it is incompatible
                with the transformation.
        """
        assert isinstance(
            goal_function, GoalFunction
        ), f"`goal_function` must be of type `textattack.goal_functions.GoalFunction`, but got type `{type(goal_function)}`."
        assert isinstance(
            constraints, list
        ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        for c in constraints:
            assert isinstance(
                c, (Constraint, PreTransformationConstraint)
            ), "`constraints` must be a list of `textattack.constraints.Constraint` or `textattack.constraints.PreTransformationConstraint`."
        assert isinstance(
            transformation, Transformation
        ), f"`transformation` must be of type `textattack.transformations.Transformation`, but got type `{type(transformation)}`."
        assert isinstance(
            search_method, SearchMethod
        ), f"`search_method` must be of type `textattack.search_methods.SearchMethod`, but got type `{type(search_method)}`."

        self.goal_function = goal_function
        self.search_method = search_method
        self.transformation = transformation
        # A transformation without an `is_black_box` attribute is treated as black-box.
        self.is_black_box = (
            getattr(transformation, "is_black_box", True) and search_method.is_black_box
        )

        if not self.search_method.check_transformation_compatibility(
            self.transformation
        ):
            raise ValueError(
                f"SearchMethod {self.search_method} incompatible with transformation {self.transformation}"
            )

        # Pre-transformation constraints are handed to the transformation itself;
        # the remaining constraints filter the transformation's output.
        self.constraints = []
        self.pre_transformation_constraints = []
        for constraint in constraints:
            if isinstance(constraint, PreTransformationConstraint):
                self.pre_transformation_constraints.append(constraint)
            else:
                self.constraints.append(constraint)

        # Check if we can use transformation cache for our transformation:
        # caching is only enabled when the transformation (and, for a composite,
        # every sub-transformation) is deterministic.
        if not self.transformation.deterministic:
            self.use_transformation_cache = False
        elif isinstance(self.transformation, CompositeTransformation):
            self.use_transformation_cache = all(
                t.deterministic for t in self.transformation.transformations
            )
        else:
            self.use_transformation_cache = True
        self.transformation_cache_size = transformation_cache_size
        self.transformation_cache = lru.LRU(transformation_cache_size)

        self.constraint_cache_size = constraint_cache_size
        self.constraints_cache = lru.LRU(constraint_cache_size)

        # Give search method access to functions for getting transformations and evaluating them
        self.search_method.get_transformations = self.get_transformations
        # Give search method access to self.goal_function for model query count, etc.
        self.search_method.goal_function = self.goal_function
        # The search method only needs access to the first argument. The second is only used
        # by the attack class when checking whether to skip the sample
        self.search_method.get_goal_results = self.goal_function.get_results

        # Give search method access to get indices which need to be ordered / searched
        self.search_method.get_indices_to_order = self.get_indices_to_order

        self.search_method.filter_transformations = self.filter_transformations

    def clear_cache(self, recursive=True):
        """Empty the constraint cache and, when in use, the transformation cache.

        Args:
            recursive: If true, also clear the goal function's cache and the
                cache of every entry in ``self.constraints`` that exposes a
                ``clear_cache`` method.
        """
        self.constraints_cache.clear()
        if self.use_transformation_cache:
            self.transformation_cache.clear()
        if recursive:
            self.goal_function.clear_cache()
            for constraint in self.constraints:
                if hasattr(constraint, "clear_cache"):
                    constraint.clear_cache()

    def _apply_to_modules(self, move_module):
        """Call ``move_module`` on each ``torch.nn.Module`` reachable from this attack.

        The walk starts at ``self`` and descends through the attribute
        dictionaries of attack components (attack, goal function,
        transformation, search method, constraints, model wrapper) and through
        lists/tuples of transformations or constraints. Object ids are recorded
        so that each object is visited at most once.

        Args:
            move_module: Callable invoked with every ``torch.nn.Module`` found.
        """
        visited = set()
        containers = (
            Attack,
            GoalFunction,
            Transformation,
            SearchMethod,
            Constraint,
            PreTransformationConstraint,
            ModelWrapper,
        )
        sequence_members = (Transformation, Constraint, PreTransformationConstraint)

        def visit(obj):
            visited.add(id(obj))
            if isinstance(obj, torch.nn.Module):
                move_module(obj)
            elif isinstance(obj, containers):
                for s_obj in obj.__dict__.values():
                    if id(s_obj) not in visited:
                        visit(s_obj)
            elif isinstance(obj, (list, tuple)):
                for item in obj:
                    if id(item) not in visited and isinstance(item, sequence_members):
                        visit(item)

        visit(self)

    def cpu_(self):
        """Move any `torch.nn.Module` models that are part of Attack to CPU."""
        self._apply_to_modules(lambda module: module.cpu())

    def cuda_(self):
        """Move any `torch.nn.Module` models that are part of Attack to GPU.

        The target is ``textattack.shared.utils.device``, read at the time each
        module is moved.
        """
        self._apply_to_modules(
            lambda module: module.to(textattack.shared.utils.device)
        )

    def get_indices_to_order(self, current_text, **kwargs):
        """Applies ``pre_transformation_constraints`` to ``text`` to get all
        the indices that can be used to search and order.

        Args:
            current_text: The current ``AttackedText`` for which we need to find the indices that are eligible to be ordered.
        Returns:
            The length and the filtered list of indices which search methods can use to search/order.
        """

        indices_to_order = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            return_indices=True,
            **kwargs,
        )

        len_text = len(indices_to_order)

        # Convert indices_to_order to list for easier shuffling later
        return len_text, list(indices_to_order)

    def _get_transformations_uncached(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``current_text`` with the
        pre-transformation constraints, bypassing the transformation cache.

        No filtering through ``self.constraints`` happens here; that is done by
        :meth:`filter_transformations`.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started. Unused by this method.
        Returns:
            The transformed texts as returned by ``self.transformation``.
        """
        transformed_texts = self.transformation(
            current_text,
            pre_transformation_constraints=self.pre_transformation_constraints,
            **kwargs,
        )

        return transformed_texts

    def get_transformations(self, current_text, original_text=None, **kwargs):
        """Applies ``self.transformation`` to ``text``, then filters the list
        of possible transformations through the applicable constraints.

        Args:
            current_text: The current ``AttackedText`` on which to perform the transformations.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            A filtered list of transformations where each transformation matches the constraints
        Raises:
            RuntimeError: If the attack has no transformation.
        """
        if not self.transformation:
            raise RuntimeError(
                "Cannot call `get_transformations` without a transformation."
            )

        if self.use_transformation_cache:
            cache_key = tuple([current_text] + sorted(kwargs.items()))
            key_is_hashable = utils.hashable(cache_key)
            if key_is_hashable and cache_key in self.transformation_cache:
                # promote transformed_text to the top of the LRU cache
                self.transformation_cache[cache_key] = self.transformation_cache[
                    cache_key
                ]
                transformed_texts = list(self.transformation_cache[cache_key])
            else:
                transformed_texts = self._get_transformations_uncached(
                    current_text, original_text, **kwargs
                )
                if key_is_hashable:
                    # Stored as a tuple so cached results cannot be mutated by callers.
                    self.transformation_cache[cache_key] = tuple(transformed_texts)
        else:
            transformed_texts = self._get_transformations_uncached(
                current_text, original_text, **kwargs
            )

        return self.filter_transformations(
            transformed_texts, current_text, original_text
        )

    def _filter_transformations_uncached(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints`` and records each verdict in the constraints cache.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            The candidates that passed every constraint.
        Raises:
            ValueError: If a constraint compares against the original text but
                ``original_text`` was not provided.
        """
        filtered_texts = transformed_texts[:]
        for C in self.constraints:
            # Nothing left to filter; skip the remaining constraints.
            if len(filtered_texts) == 0:
                break
            if C.compare_against_original:
                if not original_text:
                    raise ValueError(
                        f"Missing `original_text` argument when constraint {type(C)} is set to compare against `original_text`"
                    )

                filtered_texts = C.call_many(filtered_texts, original_text)
            else:
                filtered_texts = C.call_many(filtered_texts, current_text)
        # Default to false for all original transformations.
        for original_transformed_text in transformed_texts:
            self.constraints_cache[(current_text, original_transformed_text)] = False
        # Set unfiltered transformations to True in the cache.
        for filtered_text in filtered_texts:
            self.constraints_cache[(current_text, filtered_text)] = True
        return filtered_texts

    def filter_transformations(
        self, transformed_texts, current_text, original_text=None
    ):
        """Filters a list of potential transformed texts based on
        ``self.constraints``. Utilizes an LRU cache to attempt to avoid
        recomputing common transformations.

        Args:
            transformed_texts: A list of candidate transformed ``AttackedText`` to filter.
            current_text: The current ``AttackedText`` on which the transformation was applied.
            original_text: The original ``AttackedText`` from which the attack started.
        Returns:
            The candidates that satisfy the constraints, sorted by their text.
        """
        # Remove any occurrences of current_text in transformed_texts
        transformed_texts = [
            t for t in transformed_texts if t.text != current_text.text
        ]
        # Populate cache with transformed_texts
        uncached_texts = []
        filtered_texts = []
        for transformed_text in transformed_texts:
            cache_key = (current_text, transformed_text)
            if cache_key not in self.constraints_cache:
                uncached_texts.append(transformed_text)
            else:
                # promote transformed_text to the top of the LRU cache
                passed = self.constraints_cache[cache_key]
                self.constraints_cache[cache_key] = passed
                if passed:
                    filtered_texts.append(transformed_text)
        filtered_texts += self._filter_transformations_uncached(
            uncached_texts, current_text, original_text=original_text
        )
        # Sort transformations to ensure order is preserved between runs
        filtered_texts.sort(key=lambda t: t.text)
        return filtered_texts

    def _attack(self, initial_result):
        """Calls the ``SearchMethod`` to perturb the ``AttackedText`` stored in
        ``initial_result``.

        All caches are cleared once the search has finished.

        Args:
            initial_result: The initial ``GoalFunctionResult`` from which to perturb.

        Returns:
            A ``SuccessfulAttackResult``, ``FailedAttackResult``,
                or ``MaximizedAttackResult``.

        Raises:
            ValueError: If the final goal status is none of ``SUCCEEDED``,
                ``SEARCHING`` or ``MAXIMIZING``.
        """
        final_result = self.search_method(initial_result)
        self.clear_cache()
        if final_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
            result = SuccessfulAttackResult(
                initial_result,
                final_result,
            )
        elif final_result.goal_status == GoalFunctionResultStatus.SEARCHING:
            # The search ended while still searching, i.e. the goal was not reached.
            result = FailedAttackResult(
                initial_result,
                final_result,
            )
        elif final_result.goal_status == GoalFunctionResultStatus.MAXIMIZING:
            result = MaximizedAttackResult(
                initial_result,
                final_result,
            )
        else:
            raise ValueError(f"Unrecognized goal status {final_result.goal_status}")
        return result

    def attack(self, example, ground_truth_output):
        """Attack a single example.

        Args:
            example (:obj:`str`, :obj:`OrderedDict[str, str]` or :class:`~textattack.shared.AttackedText`):
                Example to attack. It can be a single string or an `OrderedDict` where
                keys represent the input fields (e.g. "premise", "hypothesis") and the values are the actual input texts.
                Also accepts :class:`~textattack.shared.AttackedText` that wraps around the input.
            ground_truth_output(:obj:`int`, :obj:`float` or :obj:`str`):
                Ground truth output of `example`.
                For classification tasks, it should be an integer representing the ground truth label.
                For regression tasks (e.g. STS), it should be the target value.
                For seq2seq tasks (e.g. translation), it should be the target string.
        Returns:
            :class:`~textattack.attack_results.AttackResult` that represents the result of the attack.
        """
        assert isinstance(
            example, (str, OrderedDict, AttackedText)
        ), "`example` must either be `str`, `collections.OrderedDict`, `textattack.shared.AttackedText`."
        if isinstance(example, (str, OrderedDict)):
            example = AttackedText(example)

        # `float` is accepted so that the documented regression targets pass validation.
        assert isinstance(
            ground_truth_output, (int, float, str)
        ), "`ground_truth_output` must either be `str`, `int` or `float`."
        goal_function_result, _ = self.goal_function.init_attack_example(
            example, ground_truth_output
        )
        if goal_function_result.goal_status == GoalFunctionResultStatus.SKIPPED:
            return SkippedAttackResult(goal_function_result)
        else:
            result = self._attack(goal_function_result)
            return result

    def __repr__(self):
        """Prints attack parameters in a human-readable string.

        Inspired by the readability of printing PyTorch nn.Modules:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py
        """
        main_str = "Attack" + "("
        lines = []

        lines.append(utils.add_indent(f"(search_method): {self.search_method}", 2))
        # self.goal_function
        lines.append(utils.add_indent(f"(goal_function): {self.goal_function}", 2))
        # self.transformation
        lines.append(utils.add_indent(f"(transformation): {self.transformation}", 2))
        # self.constraints
        constraints = self.constraints + self.pre_transformation_constraints
        if len(constraints):
            constraints_lines = [
                utils.add_indent(f"({i}): {constraint}", 2)
                for i, constraint in enumerate(constraints)
            ]
            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
        else:
            constraints_str = "None"
        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
        # self.is_black_box
        lines.append(utils.add_indent(f"(is_black_box): {self.is_black_box}", 2))
        # NOTE(review): separator whitespace is reproduced exactly as it appears in
        # the rendered diff, which may have collapsed runs of spaces -- confirm
        # against the raw file.
        main_str += "\n " + "\n ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def __getstate__(self):
        """Return the instance state for pickling, with both caches set to ``None``."""
        state = self.__dict__.copy()
        state["transformation_cache"] = None
        state["constraints_cache"] = None
        return state

    def __setstate__(self, state):
        """Restore pickled state and recreate empty caches of the saved sizes."""
        self.__dict__ = state
        self.transformation_cache = lru.LRU(self.transformation_cache_size)
        self.constraints_cache = lru.LRU(self.constraint_cache_size)

    __str__ = __repr__
textattack/attack_args.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AttackArgs Class
3
+ ================
4
+ """
5
+
6
+ from dataclasses import dataclass, field
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+ from typing import Dict, Optional
12
+
13
+ import textattack
14
+ from textattack.shared.utils import ARGS_SPLIT_TOKEN, load_module_from_file
15
+
16
+ from .attack import Attack
17
+ from .dataset_args import DatasetArgs
18
+ from .model_args import ModelArgs
19
+
20
# Registry tables: each maps a short, hyphenated name to the dotted path of a
# class inside the `textattack` package.
# NOTE(review): presumably these names are what users pass on the command line
# and the paths are resolved to classes by the code that consumes these tables;
# confirm against that code, which is not visible here.

# Short name -> dotted path of an attack recipe class.
ATTACK_RECIPE_NAMES = {
    "alzantot": "textattack.attack_recipes.GeneticAlgorithmAlzantot2018",
    "bae": "textattack.attack_recipes.BAEGarg2019",
    "bert-attack": "textattack.attack_recipes.BERTAttackLi2020",
    "faster-alzantot": "textattack.attack_recipes.FasterGeneticAlgorithmJia2019",
    "deepwordbug": "textattack.attack_recipes.DeepWordBugGao2018",
    "hotflip": "textattack.attack_recipes.HotFlipEbrahimi2017",
    "input-reduction": "textattack.attack_recipes.InputReductionFeng2018",
    "kuleshov": "textattack.attack_recipes.Kuleshov2017",
    "morpheus": "textattack.attack_recipes.MorpheusTan2020",
    "seq2sick": "textattack.attack_recipes.Seq2SickCheng2018BlackBox",
    "textbugger": "textattack.attack_recipes.TextBuggerLi2018",
    "textfooler": "textattack.attack_recipes.TextFoolerJin2019",
    "pwws": "textattack.attack_recipes.PWWSRen2019",
    "iga": "textattack.attack_recipes.IGAWang2019",
    "pruthi": "textattack.attack_recipes.Pruthi2019",
    "pso": "textattack.attack_recipes.PSOZang2020",
    "checklist": "textattack.attack_recipes.CheckList2020",
    "clare": "textattack.attack_recipes.CLARE2020",
    "a2t": "textattack.attack_recipes.A2TYoo2021",
}


# Short name -> dotted path of a transformation class listed as black-box.
BLACK_BOX_TRANSFORMATION_CLASS_NAMES = {
    "random-synonym-insertion": "textattack.transformations.RandomSynonymInsertion",
    "word-deletion": "textattack.transformations.WordDeletion",
    "word-swap-embedding": "textattack.transformations.WordSwapEmbedding",
    "word-swap-homoglyph": "textattack.transformations.WordSwapHomoglyphSwap",
    "word-swap-inflections": "textattack.transformations.WordSwapInflections",
    "word-swap-neighboring-char-swap": "textattack.transformations.WordSwapNeighboringCharacterSwap",
    "word-swap-random-char-deletion": "textattack.transformations.WordSwapRandomCharacterDeletion",
    "word-swap-random-char-insertion": "textattack.transformations.WordSwapRandomCharacterInsertion",
    "word-swap-random-char-substitution": "textattack.transformations.WordSwapRandomCharacterSubstitution",
    "word-swap-wordnet": "textattack.transformations.WordSwapWordNet",
    "word-swap-masked-lm": "textattack.transformations.WordSwapMaskedLM",
    "word-swap-hownet": "textattack.transformations.WordSwapHowNet",
    "word-swap-qwerty": "textattack.transformations.WordSwapQWERTY",
}


# Short name -> dotted path of a transformation class listed as white-box.
WHITE_BOX_TRANSFORMATION_CLASS_NAMES = {
    "word-swap-gradient": "textattack.transformations.WordSwapGradientBased"
}


# Short name -> dotted path of a constraint class, grouped by constraint family.
CONSTRAINT_CLASS_NAMES = {
    #
    # Semantics constraints
    #
    "embedding": "textattack.constraints.semantics.WordEmbeddingDistance",
    "bert": "textattack.constraints.semantics.sentence_encoders.BERT",
    "infer-sent": "textattack.constraints.semantics.sentence_encoders.InferSent",
    "thought-vector": "textattack.constraints.semantics.sentence_encoders.ThoughtVector",
    "use": "textattack.constraints.semantics.sentence_encoders.UniversalSentenceEncoder",
    "muse": "textattack.constraints.semantics.sentence_encoders.MultilingualUniversalSentenceEncoder",
    "bert-score": "textattack.constraints.semantics.BERTScore",
    #
    # Grammaticality constraints
    #
    "lang-tool": "textattack.constraints.grammaticality.LanguageTool",
    "part-of-speech": "textattack.constraints.grammaticality.PartOfSpeech",
    "goog-lm": "textattack.constraints.grammaticality.language_models.GoogleLanguageModel",
    "gpt2": "textattack.constraints.grammaticality.language_models.GPT2",
    "learning-to-write": "textattack.constraints.grammaticality.language_models.LearningToWriteLanguageModel",
    "cola": "textattack.constraints.grammaticality.COLA",
    #
    # Overlap constraints
    #
    "bleu": "textattack.constraints.overlap.BLEU",
    "chrf": "textattack.constraints.overlap.chrF",
    "edit-distance": "textattack.constraints.overlap.LevenshteinEditDistance",
    "meteor": "textattack.constraints.overlap.METEOR",
    "max-words-perturbed": "textattack.constraints.overlap.MaxWordsPerturbed",
    #
    # Pre-transformation constraints
    #
    "repeat": "textattack.constraints.pre_transformation.RepeatModification",
    "stopword": "textattack.constraints.pre_transformation.StopwordModification",
    "max-word-index": "textattack.constraints.pre_transformation.MaxWordIndexModification",
}


# Short name -> dotted path of a search method class.
SEARCH_METHOD_CLASS_NAMES = {
    "beam-search": "textattack.search_methods.BeamSearch",
    "greedy": "textattack.search_methods.GreedySearch",
    "ga-word": "textattack.search_methods.GeneticAlgorithm",
    "greedy-word-wir": "textattack.search_methods.GreedyWordSwapWIR",
    "pso": "textattack.search_methods.ParticleSwarmOptimization",
}


# Short name -> dotted path of a goal function class, grouped by task type.
GOAL_FUNCTION_CLASS_NAMES = {
    #
    # Classification goal functions
    #
    "targeted-classification": "textattack.goal_functions.classification.TargetedClassification",
    "untargeted-classification": "textattack.goal_functions.classification.UntargetedClassification",
    "input-reduction": "textattack.goal_functions.classification.InputReduction",
    #
    # Text goal functions
    #
    "minimize-bleu": "textattack.goal_functions.text.MinimizeBleu",
    "non-overlapping-output": "textattack.goal_functions.text.NonOverlappingOutput",
    "text-to-text": "textattack.goal_functions.text.TextToTextGoalFunction",
}
125
+
126
+
127
+ @dataclass
128
+ class AttackArgs:
129
+ """Attack arguments to be passed to :class:`~textattack.Attacker`.
130
+
131
+ Args:
132
+ num_examples (:obj:`int`, `optional`, defaults to :obj:`10`):
133
+ The number of examples to attack. :obj:`-1` for entire dataset.
134
+ num_successful_examples (:obj:`int`, `optional`, defaults to :obj:`None`):
135
+ The number of successful adversarial examples we want. This is different from :obj:`num_examples`
136
+ as :obj:`num_examples` only cares about attacking `N` samples while :obj:`num_successful_examples` aims to keep attacking
137
+ until we have `N` successful cases.
138
+
139
+ .. note::
140
+ If set, this argument overrides `num_examples` argument.
141
+ num_examples_offset (:obj:`int`, `optional`, defaults to :obj:`0`):
142
+ The offset index to start at in the dataset.
143
+ attack_n (:obj:`bool`, `optional`, defaults to :obj:`False`):
144
+ Whether to run attack until total of `N` examples have been attacked (and not skipped).
145
+ shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`):
146
+ If :obj:`True`, we randomly shuffle the dataset before attacking. However, this avoids actually shuffling
147
+ the dataset internally and opts for shuffling the list of indices of examples we want to attack. This means
148
+ :obj:`shuffle` can now be used with checkpoint saving.
149
+ query_budget (:obj:`int`, `optional`, defaults to :obj:`None`):
150
+ The maximum number of model queries allowed per example attacked.
151
+ If not set, we use the query budget set in the :class:`~textattack.goal_functions.GoalFunction` object (which by default is :obj:`float("inf")`).
152
+
153
+ .. note::
154
+ Setting this overwrites the query budget set in :class:`~textattack.goal_functions.GoalFunction` object.
155
+ checkpoint_interval (:obj:`int`, `optional`, defaults to :obj:`None`):
156
+ If set, checkpoint will be saved after attacking every `N` examples. If :obj:`None` is passed, no checkpoints will be saved.
157
+ checkpoint_dir (:obj:`str`, `optional`, defaults to :obj:`"checkpoints"`):
158
+ The directory to save checkpoint files.
159
+ random_seed (:obj:`int`, `optional`, defaults to :obj:`765`):
160
+ Random seed for reproducibility.
161
+ parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
162
+ If :obj:`True`, run attack using multiple CPUs/GPUs.
163
+ num_workers_per_device (:obj:`int`, `optional`, defaults to :obj:`1`):
164
+ Number of worker processes to run per device in parallel mode (i.e. :obj:`parallel=True`). For example, if you are using GPUs and :obj:`num_workers_per_device=2`,
165
+ then 2 processes will be running in each GPU.
166
+ log_to_txt (:obj:`str`, `optional`, defaults to :obj:`None`):
167
+ If set, save attack logs as a `.txt` file to the directory specified by this argument.
168
+ If the last part of the provided path ends with `.txt` extension, it is assumed to be the desired path of the log file.
169
+ log_to_csv (:obj:`str`, `optional`, defaults to :obj:`None`):
170
+ If set, save attack logs as a CSV file to the directory specified by this argument.
171
+ If the last part of the provided path ends with `.csv` extension, it is assumed to be the desired path of the log file.
172
+ csv_coloring_style (:obj:`str`, `optional`, defaults to :obj:`"file"`):
173
+ Method for choosing how to mark perturbed parts of the text. Options are :obj:`"file"`, :obj:`"plain"`, and :obj:`"html"`.
174
+ :obj:`"file"` wraps perturbed parts with double brackets :obj:`[[ <text> ]]` while :obj:`"plain"` does not mark the text in any way.
175
+ log_to_visdom (:obj:`dict`, `optional`, defaults to :obj:`None`):
176
+ If set, Visdom logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.VisdomLogger`.
177
+ Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
178
+ three keys and their corresponding values: :obj:`"env", "port", "hostname"`.
179
+ log_to_wandb(:obj:`dict`, `optional`, defaults to :obj:`None`):
180
+ If set, WandB logger is used with the provided dictionary passed as a keyword arguments to :class:`~textattack.loggers.WeightsAndBiasesLogger`.
181
+ Pass in empty dictionary to use default arguments. For custom logger, the dictionary should have the following
182
+ key and its corresponding value: :obj:`"project"`.
183
+ disable_stdout (:obj:`bool`, `optional`, defaults to :obj:`False`):
184
+ Disable displaying individual attack results to stdout.
185
+ silent (:obj:`bool`, `optional`, defaults to :obj:`False`):
186
+ Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`.
187
+ enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
188
+ Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc.
189
+ """
190
+
191
+ num_examples: int = 10
192
+ num_successful_examples: int = None
193
+ num_examples_offset: int = 0
194
+ attack_n: bool = False
195
+ shuffle: bool = False
196
+ query_budget: int = None
197
+ checkpoint_interval: int = None
198
+ checkpoint_dir: str = "checkpoints"
199
+ random_seed: int = 765 # equivalent to sum((ord(c) for c in "TEXTATTACK"))
200
+ parallel: bool = False
201
+ num_workers_per_device: int = 1
202
+ log_to_txt: str = None
203
+ log_to_csv: str = None
204
+ log_summary_to_json: str = None
205
+ csv_coloring_style: str = "file"
206
+ log_to_visdom: dict = None
207
+ log_to_wandb: dict = None
208
+ disable_stdout: bool = False
209
+ silent: bool = False
210
+ enable_advance_metrics: bool = False
211
+ metrics: Optional[Dict] = None
212
+
213
def __post_init__(self):
    """Validate the numeric arguments once the dataclass fields are set.

    When ``num_successful_examples`` is truthy it takes precedence and
    ``num_examples`` is reset to ``None``.

    Raises:
        AssertionError: If ``num_examples``, ``num_successful_examples``,
            ``query_budget``, ``checkpoint_interval`` or
            ``num_workers_per_device`` is outside its accepted range.
    """
    if self.num_successful_examples:
        # A target number of successful attacks overrides the plain count.
        self.num_examples = None
    if self.num_examples:
        assert (
            self.num_examples >= 0 or self.num_examples == -1
        ), "`num_examples` must be greater than or equal to 0 or equal to -1."
    if self.num_successful_examples:
        # The message names the argument that is actually being checked.
        assert (
            self.num_successful_examples >= 0
        ), "`num_successful_examples` must be greater than or equal to 0."

    # Falsy values (``None`` or ``0``) skip the two range checks below.
    if self.query_budget:
        assert self.query_budget > 0, "`query_budget` must be greater than 0."

    if self.checkpoint_interval:
        assert (
            self.checkpoint_interval > 0
        ), "`checkpoint_interval` must be greater than 0."

    assert (
        self.num_workers_per_device > 0
    ), "`num_workers_per_device` must be greater than 0."
236
+
237
@classmethod
def _add_parser_args(cls, parser):
    """Register every attack-related command line option on ``parser``.

    Defaults are read from a freshly constructed ``cls()`` instance.

    Args:
        parser: The argparse parser (or compatible object) to extend.

    Returns:
        The same ``parser`` object, with the options added.
    """
    defaults = cls()

    # ``--num-examples`` and ``--num-successful-examples`` are mutually exclusive.
    example_count_group = parser.add_mutually_exclusive_group(required=False)
    exclusive_options = (
        (
            ("--num-examples", "-n"),
            dict(
                type=int,
                default=defaults.num_examples,
                help="The number of examples to process, -1 for entire dataset.",
            ),
        ),
        (
            ("--num-successful-examples",),
            dict(
                type=int,
                default=defaults.num_successful_examples,
                help="The number of successful adversarial examples we want.",
            ),
        ),
    )
    for flags, options in exclusive_options:
        example_count_group.add_argument(*flags, **options)

    # Remaining options, listed in the order they are registered on the parser.
    general_options = (
        (
            ("--num-examples-offset", "-o"),
            dict(
                type=int,
                required=False,
                default=defaults.num_examples_offset,
                help="The offset to start at in the dataset.",
            ),
        ),
        (
            ("--query-budget", "-q"),
            dict(
                type=int,
                default=defaults.query_budget,
                help="The maximum number of model queries allowed per example attacked. Setting this overwrites the query budget set in `GoalFunction` object.",
            ),
        ),
        (
            ("--shuffle",),
            dict(
                action="store_true",
                default=defaults.shuffle,
                help="If `True`, shuffle the samples before we attack the dataset. Default is False.",
            ),
        ),
        (
            ("--attack-n",),
            dict(
                action="store_true",
                default=defaults.attack_n,
                help="Whether to run attack until `n` examples have been attacked (not skipped).",
            ),
        ),
        (
            ("--checkpoint-dir",),
            dict(
                required=False,
                type=str,
                default=defaults.checkpoint_dir,
                help="The directory to save checkpoint files.",
            ),
        ),
        (
            ("--checkpoint-interval",),
            dict(
                required=False,
                type=int,
                default=defaults.checkpoint_interval,
                help="If set, checkpoint will be saved after attacking every N examples. If not set, no checkpoints will be saved.",
            ),
        ),
        (
            ("--random-seed",),
            dict(
                default=defaults.random_seed,
                type=int,
                help="Random seed for reproducibility.",
            ),
        ),
        (
            ("--parallel",),
            dict(
                action="store_true",
                default=defaults.parallel,
                help="Run attack using multiple GPUs.",
            ),
        ),
        (
            ("--num-workers-per-device",),
            dict(
                default=defaults.num_workers_per_device,
                type=int,
                help="Number of worker processes to run per device.",
            ),
        ),
        (
            ("--log-to-txt",),
            dict(
                nargs="?",
                default=defaults.log_to_txt,
                const="",
                type=str,
                help="Path to which to save attack logs as a text file. Set this argument if you want to save text logs. "
                "If the last part of the path ends with `.txt` extension, the path is assumed to path for output file.",
            ),
        ),
        (
            ("--log-to-csv",),
            dict(
                nargs="?",
                default=defaults.log_to_csv,
                const="",
                type=str,
                help="Path to which to save attack logs as a CSV file. Set this argument if you want to save CSV logs. "
                "If the last part of the path ends with `.csv` extension, the path is assumed to path for output file.",
            ),
        ),
        (
            ("--log-summary-to-json",),
            dict(
                nargs="?",
                default=defaults.log_summary_to_json,
                const="",
                type=str,
                help="Path to which to save attack summary as a JSON file. Set this argument if you want to save attack results summary in a JSON. "
                "If the last part of the path ends with `.json` extension, the path is assumed to path for output file.",
            ),
        ),
        (
            ("--csv-coloring-style",),
            dict(
                default=defaults.csv_coloring_style,
                type=str,
                help='Method for choosing how to mark perturbed parts of the text in CSV logs. Options are "file" and "plain". '
                '"file" wraps text with double brackets `[[ <text> ]]` while "plain" does not mark any text. Default is "file".',
            ),
        ),
        (
            ("--log-to-visdom",),
            dict(
                nargs="?",
                default=None,
                # The JSON text is decoded by ``type=json.loads``.
                const='{"env": "main", "port": 8097, "hostname": "localhost"}',
                type=json.loads,
                help="Set this argument if you want to log attacks to Visdom. The dictionary should have the following "
                'three keys and their corresponding values: `"env", "port", "hostname"`. '
                'Example for command line use: `--log-to-visdom {"env": "main", "port": 8097, "hostname": "localhost"}`.',
            ),
        ),
        (
            ("--log-to-wandb",),
            dict(
                nargs="?",
                default=None,
                const='{"project": "textattack"}',
                type=json.loads,
                help="Set this argument if you want to log attacks to WandB. The dictionary should have the following "
                'key and its corresponding value: `"project"`. '
                'Example for command line use: `--log-to-wandb {"project": "textattack"}`.',
            ),
        ),
        (
            ("--disable-stdout",),
            dict(
                action="store_true",
                default=defaults.disable_stdout,
                help="Disable logging attack results to stdout",
            ),
        ),
        (
            ("--silent",),
            dict(
                action="store_true",
                default=defaults.silent,
                help="Disable all logging",
            ),
        ),
        (
            ("--enable-advance-metrics",),
            dict(
                action="store_true",
                default=defaults.enable_advance_metrics,
                help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.",
            ),
        ),
    )
    for flags, options in general_options:
        parser.add_argument(*flags, **options)

    return parser
388
+
389
@classmethod
def create_loggers_from_args(cls, args):
    """Create an ``AttackLogManager`` configured from an ``AttackArgs`` object.

    Text, CSV and summary-JSON outputs are added when the matching
    ``log_to_txt`` / ``log_to_csv`` / ``log_summary_to_json`` attribute is not
    ``None``; Visdom and Weights & Biases are enabled when their dictionaries
    are not ``None``.

    Args:
        args: Instance of ``cls`` carrying the logging options.

    Returns:
        The configured ``textattack.loggers.AttackLogManager``.

    Raises:
        AssertionError: If ``args`` is not an instance of ``cls``.
    """
    # Report the expected class itself; ``type(cls)`` would always print ``type``.
    assert isinstance(
        args, cls
    ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

    def _prepare_log_file(target, extension, default_name):
        """Return the output file path for ``target`` and create its directory."""
        if target.lower().endswith(extension):
            # The caller supplied a complete file path.
            file_path = target
        else:
            # Otherwise ``target`` is a directory; use a timestamped file name.
            file_path = os.path.join(target, default_name)
        dir_path = os.path.dirname(file_path) or "."
        # ``exist_ok`` avoids the check-then-create race of a separate exists() test.
        os.makedirs(dir_path, exist_ok=True)
        return file_path

    # Create logger
    attack_log_manager = textattack.loggers.AttackLogManager(args.metrics)

    # Get current time for file naming
    timestamp = time.strftime("%Y-%m-%d-%H-%M")

    # if '--log-to-txt' specified with arguments
    if args.log_to_txt is not None:
        txt_file_path = _prepare_log_file(
            args.log_to_txt, ".txt", f"{timestamp}-log.txt"
        )
        color_method = "file"
        attack_log_manager.add_output_file(txt_file_path, color_method)

    # if '--log-to-csv' specified with arguments
    if args.log_to_csv is not None:
        csv_file_path = _prepare_log_file(
            args.log_to_csv, ".csv", f"{timestamp}-log.csv"
        )
        # "plain" is passed on as ``None``; other styles are passed through.
        color_method = (
            None if args.csv_coloring_style == "plain" else args.csv_coloring_style
        )
        attack_log_manager.add_output_csv(csv_file_path, color_method)

    # if '--log-summary-to-json' specified with arguments
    if args.log_summary_to_json is not None:
        summary_json_file_path = _prepare_log_file(
            args.log_summary_to_json,
            ".json",
            f"{timestamp}-attack_summary_log.json",
        )
        attack_log_manager.add_output_summary_json(summary_json_file_path)

    # Visdom
    if args.log_to_visdom is not None:
        attack_log_manager.enable_visdom(**args.log_to_visdom)

    # Weights & Biases
    if args.log_to_wandb is not None:
        attack_log_manager.enable_wandb(**args.log_to_wandb)

    # Stdout: disable_color() when stdout is not a TTY, enable_stdout() otherwise.
    if not args.disable_stdout and not sys.stdout.isatty():
        attack_log_manager.disable_color()
    elif not args.disable_stdout:
        attack_log_manager.enable_stdout()

    return attack_log_manager
465
+
466
+
467
@dataclass
class _CommandLineAttackArgs:
    """Attack args for command line execution. This requires more arguments to
    create ``Attack`` object as specified.

    Args:
        transformation (:obj:`str`, `optional`, defaults to :obj:`"word-swap-embedding"`):
            Name of transformation to use.
        constraints (:obj:`list[str]`, `optional`, defaults to :obj:`["repeat", "stopword"]`):
            List of names of constraints to use.
        goal_function (:obj:`str`, `optional`, defaults to :obj:`"untargeted-classification"`):
            Name of goal function to use.
        search_method (:obj:`str`, `optional`, defaults to :obj:`"greedy-word-wir"`):
            Name of search method to use.
        attack_recipe (:obj:`str`, `optional`, defaults to :obj:`None`):
            Name of attack recipe to use.
            .. note::
                Setting this overrides any previous selection of transformation, constraints, goal function, and search method.
        attack_from_file (:obj:`str`, `optional`, defaults to :obj:`None`):
            Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specify which variable to import from the file.
            .. note::
                If this is set, it overrides any previous selection of transformation, constraints, goal function, and search method
        interactive (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, carry attack in interactive mode.
        parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If `True`, attack in parallel.
        model_batch_size (:obj:`int`, `optional`, defaults to :obj:`32`):
            The batch size for making queries to the victim model.
        model_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the model results cache at once.
        constraint_cache_size (:obj:`int`, `optional`, defaults to :obj:`2**18`):
            The maximum number of items to keep in the constraints cache at once.
    """

    transformation: str = "word-swap-embedding"
    constraints: list = field(default_factory=lambda: ["repeat", "stopword"])
    goal_function: str = "untargeted-classification"
    search_method: str = "greedy-word-wir"
    attack_recipe: str = None
    attack_from_file: str = None
    interactive: bool = False
    parallel: bool = False
    model_batch_size: int = 32
    model_cache_size: int = 2**18
    constraint_cache_size: int = 2**18

    @classmethod
    def _add_parser_args(cls, parser):
        """Add listed args to command line parser.

        Defaults are read from a freshly constructed ``cls()`` instance.
        Returns the same ``parser`` object.
        """
        default_obj = cls()
        transformation_names = set(BLACK_BOX_TRANSFORMATION_CLASS_NAMES.keys()) | set(
            WHITE_BOX_TRANSFORMATION_CLASS_NAMES.keys()
        )
        parser.add_argument(
            "--transformation",
            type=str,
            required=False,
            default=default_obj.transformation,
            help='The transformation to apply. Usage: "--transformation {transformation}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(transformation_names),
        )
        parser.add_argument(
            "--constraints",
            type=str,
            required=False,
            nargs="*",
            default=default_obj.constraints,
            help='Constraints to add to the attack. Usage: "--constraints {constraint}:{arg_1}={value_1},{arg_3}={value_3}". Choices: '
            + str(CONSTRAINT_CLASS_NAMES.keys()),
        )
        goal_function_choices = ", ".join(GOAL_FUNCTION_CLASS_NAMES.keys())
        parser.add_argument(
            "--goal-function",
            "-g",
            default=default_obj.goal_function,
            help=f"The goal function to use. choices: {goal_function_choices}",
        )
        # Search method, recipe and attack-from-file are mutually exclusive.
        attack_group = parser.add_mutually_exclusive_group(required=False)
        search_choices = ", ".join(SEARCH_METHOD_CLASS_NAMES.keys())
        attack_group.add_argument(
            "--search-method",
            "--search",
            "-s",
            type=str,
            required=False,
            default=default_obj.search_method,
            help=f"The search method to use. choices: {search_choices}",
        )
        attack_group.add_argument(
            "--attack-recipe",
            "--recipe",
            "-r",
            type=str,
            required=False,
            default=default_obj.attack_recipe,
            help="full attack recipe (overrides provided goal function, transformation & constraints)",
            choices=ATTACK_RECIPE_NAMES.keys(),
        )
        attack_group.add_argument(
            "--attack-from-file",
            type=str,
            required=False,
            default=default_obj.attack_from_file,
            help="Path of `.py` file from which to load attack from. Use `<path>^<variable_name>` to specifiy which variable to import from the file.",
        )
        parser.add_argument(
            "--interactive",
            action="store_true",
            default=default_obj.interactive,
            help="Whether to run attacks interactively.",
        )
        parser.add_argument(
            "--model-batch-size",
            type=int,
            default=default_obj.model_batch_size,
            help="The batch size for making calls to the model.",
        )
        parser.add_argument(
            "--model-cache-size",
            type=int,
            default=default_obj.model_cache_size,
            help="The maximum number of items to keep in the model results cache at once.",
        )
        parser.add_argument(
            "--constraint-cache-size",
            type=int,
            default=default_obj.constraint_cache_size,
            help="The maximum number of items to keep in the constraints cache at once.",
        )

        return parser

    @classmethod
    def _create_transformation_from_args(cls, args, model_wrapper):
        """Create `Transformation` based on provided `args` and
        `model_wrapper`.

        ``args.transformation`` is either a bare name or
        ``<name><ARGS_SPLIT_TOKEN><params>``.

        Raises:
            ValueError: If the transformation name is not registered.
        """

        transformation_name = args.transformation
        if ARGS_SPLIT_TOKEN in transformation_name:
            # Split only on the first token so the params may contain it too.
            transformation_name, params = transformation_name.split(
                ARGS_SPLIT_TOKEN, 1
            )

            # NOTE(security): ``params`` comes from the command line and is
            # passed to eval(); only use this with trusted input.
            if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model, {params})"
                )
            elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}({params})"
                )
            else:
                raise ValueError(
                    f"Error: unsupported transformation {transformation_name}"
                )
        else:
            if transformation_name in WHITE_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{WHITE_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}(model_wrapper.model)"
                )
            elif transformation_name in BLACK_BOX_TRANSFORMATION_CLASS_NAMES:
                transformation = eval(
                    f"{BLACK_BOX_TRANSFORMATION_CLASS_NAMES[transformation_name]}()"
                )
            else:
                raise ValueError(
                    f"Error: unsupported transformation {transformation_name}"
                )
        return transformation

    @classmethod
    def _create_goal_function_from_args(cls, args, model_wrapper):
        """Create `GoalFunction` based on provided `args` and
        `model_wrapper`.

        After construction the query budget (when truthy), model cache size
        and batch size from ``args`` are assigned onto the goal function.

        Raises:
            ValueError: If the goal function name is not registered.
        """

        goal_function = args.goal_function
        if ARGS_SPLIT_TOKEN in goal_function:
            # Split only on the first token so the params may contain it too.
            goal_function_name, params = goal_function.split(ARGS_SPLIT_TOKEN, 1)
            if goal_function_name not in GOAL_FUNCTION_CLASS_NAMES:
                raise ValueError(
                    f"Error: unsupported goal_function {goal_function_name}"
                )
            # NOTE(security): ``params`` is command-line text passed to eval().
            goal_function = eval(
                f"{GOAL_FUNCTION_CLASS_NAMES[goal_function_name]}(model_wrapper, {params})"
            )
        elif goal_function in GOAL_FUNCTION_CLASS_NAMES:
            goal_function = eval(
                f"{GOAL_FUNCTION_CLASS_NAMES[goal_function]}(model_wrapper)"
            )
        else:
            raise ValueError(f"Error: unsupported goal_function {goal_function}")
        if args.query_budget:
            goal_function.query_budget = args.query_budget
        goal_function.model_cache_size = args.model_cache_size
        goal_function.batch_size = args.model_batch_size
        return goal_function

    @classmethod
    def _create_constraints_from_args(cls, args):
        """Create list of `Constraints` based on provided `args`.

        Returns an empty list when ``args.constraints`` is empty or ``None``.

        Raises:
            ValueError: If a constraint name is not registered.
        """

        if not args.constraints:
            return []

        _constraints = []
        for constraint in args.constraints:
            if ARGS_SPLIT_TOKEN in constraint:
                # Split only on the first token so the params may contain it too.
                constraint_name, params = constraint.split(ARGS_SPLIT_TOKEN, 1)
                if constraint_name not in CONSTRAINT_CLASS_NAMES:
                    raise ValueError(f"Error: unsupported constraint {constraint_name}")
                # NOTE(security): ``params`` is command-line text passed to eval().
                _constraints.append(
                    eval(f"{CONSTRAINT_CLASS_NAMES[constraint_name]}({params})")
                )
            elif constraint in CONSTRAINT_CLASS_NAMES:
                _constraints.append(eval(f"{CONSTRAINT_CLASS_NAMES[constraint]}()"))
            else:
                raise ValueError(f"Error: unsupported constraint {constraint}")

        return _constraints

    @classmethod
    def _create_attack_from_args(cls, args, model_wrapper):
        """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
        ``Attack`` object.

        The attack is taken, in order of precedence, from ``args.attack_recipe``,
        from ``args.attack_from_file``, or assembled from the goal function,
        transformation, constraints and search method named in ``args``.

        Raises:
            AssertionError: If ``args`` is not an instance of ``cls``.
            ValueError: If a recipe, search method or file attribute cannot be
                resolved.
        """

        # Report the expected class itself; ``type(cls)`` would always print ``type``.
        assert isinstance(
            args, cls
        ), f"Expect args to be of type `{cls}`, but got type `{type(args)}`."

        if args.attack_recipe:
            if ARGS_SPLIT_TOKEN in args.attack_recipe:
                recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN, 1)
                if recipe_name not in ATTACK_RECIPE_NAMES:
                    raise ValueError(f"Error: unsupported recipe {recipe_name}")
                # NOTE(security): ``params`` is command-line text passed to eval().
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
                )
            elif args.attack_recipe in ATTACK_RECIPE_NAMES:
                recipe = eval(
                    f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
                )
            else:
                raise ValueError(f"Invalid recipe {args.attack_recipe}")
            if args.query_budget:
                recipe.goal_function.query_budget = args.query_budget
            recipe.goal_function.model_cache_size = args.model_cache_size
            recipe.constraint_cache_size = args.constraint_cache_size
            return recipe
        elif args.attack_from_file:
            if ARGS_SPLIT_TOKEN in args.attack_from_file:
                # Split on the last token: the variable name follows it, while
                # the file path before it may itself contain the token.
                attack_file, attack_name = args.attack_from_file.rsplit(
                    ARGS_SPLIT_TOKEN, 1
                )
            else:
                attack_file, attack_name = args.attack_from_file, "attack"
            attack_module = load_module_from_file(attack_file)
            if not hasattr(attack_module, attack_name):
                raise ValueError(
                    f"Loaded `{attack_file}` but could not find `{attack_name}`."
                )
            attack_func = getattr(attack_module, attack_name)
            return attack_func(model_wrapper)
        else:
            goal_function = cls._create_goal_function_from_args(args, model_wrapper)
            transformation = cls._create_transformation_from_args(args, model_wrapper)
            constraints = cls._create_constraints_from_args(args)
            if ARGS_SPLIT_TOKEN in args.search_method:
                search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN, 1)
                if search_name not in SEARCH_METHOD_CLASS_NAMES:
                    raise ValueError(f"Error: unsupported search {search_name}")
                # NOTE(security): ``params`` is command-line text passed to eval().
                search_method = eval(
                    f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
                )
            elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
                search_method = eval(
                    f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
                )
            else:
                # It is the search method, not the attack, that is unknown here.
                raise ValueError(f"Error: unsupported search {args.search_method}")

            return Attack(
                goal_function,
                constraints,
                transformation,
                search_method,
                constraint_cache_size=args.constraint_cache_size,
            )
750
+
751
+
752
+ # This neat trick allows us to reorder the arguments to avoid TypeErrors commonly found when inheriting dataclass.
753
+ # https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
754
@dataclass
class CommandLineAttackArgs(AttackArgs, _CommandLineAttackArgs, DatasetArgs, ModelArgs):
    """Dataclass combining the attack, command-line attack, dataset and model arguments."""

    @classmethod
    def _add_parser_args(cls, parser):
        """Register the options of every argument group on ``parser`` and return it."""
        # Registration order: model, dataset, command-line attack, attack.
        argument_groups = (ModelArgs, DatasetArgs, _CommandLineAttackArgs, AttackArgs)
        for argument_group in argument_groups:
            parser = argument_group._add_parser_args(parser)
        return parser
textattack/attack_recipes/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """.. _attack_recipes:
2
+
3
+ Attack Recipes Package:
4
+ ========================
5
+
6
+ We provide a number of pre-built attack recipes, which correspond to attacks from the literature. To run an attack recipe from the command line, run::
7
+
8
+ textattack attack --recipe [recipe_name]
9
+
10
+ To initialize an attack in Python script, use::
11
+
12
+ <recipe name>.build(model_wrapper)
13
+
14
+ For example, ``attack = InputReductionFeng2018.build(model)`` creates `attack`, an object of type ``Attack`` with the goal function, transformation, constraints, and search method specified in that paper. This object can then be used just like any other attack; for example, by calling ``attack.attack_dataset``.
15
+
16
+ TextAttack supports the following attack recipes (each recipe's documentation contains a link to the corresponding paper):
17
+
18
+ .. contents:: :local:
19
+ """
20
+
21
+ from .attack_recipe import AttackRecipe
22
+
23
+ from .a2t_yoo_2021 import A2TYoo2021
24
+ from .bae_garg_2019 import BAEGarg2019
25
+ from .bert_attack_li_2020 import BERTAttackLi2020
26
+ from .genetic_algorithm_alzantot_2018 import GeneticAlgorithmAlzantot2018
27
+ from .faster_genetic_algorithm_jia_2019 import FasterGeneticAlgorithmJia2019
28
+ from .deepwordbug_gao_2018 import DeepWordBugGao2018
29
+ from .hotflip_ebrahimi_2017 import HotFlipEbrahimi2017
30
+ from .input_reduction_feng_2018 import InputReductionFeng2018
31
+ from .kuleshov_2017 import Kuleshov2017
32
+ from .morpheus_tan_2020 import MorpheusTan2020
33
+ from .seq2sick_cheng_2018_blackbox import Seq2SickCheng2018BlackBox
34
+ from .textbugger_li_2018 import TextBuggerLi2018
35
+ from .textfooler_jin_2019 import TextFoolerJin2019
36
+ from .pwws_ren_2019 import PWWSRen2019
37
+ from .iga_wang_2019 import IGAWang2019
38
+ from .pruthi_2019 import Pruthi2019
39
+ from .pso_zang_2020 import PSOZang2020
40
+ from .checklist_ribeiro_2020 import CheckList2020
41
+ from .clare_li_2020 import CLARE2020
42
+ from .french_recipe import FrenchRecipe
43
+ from .spanish_recipe import SpanishRecipe
textattack/attack_recipes/a2t_yoo_2021.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A2T (A2T: Attack for Adversarial Training Recipe)
3
+ ==================================================
4
+
5
+ """
6
+
7
+ from textattack import Attack
8
+ from textattack.constraints.grammaticality import PartOfSpeech
9
+ from textattack.constraints.pre_transformation import (
10
+ InputColumnModification,
11
+ MaxModificationRate,
12
+ RepeatModification,
13
+ StopwordModification,
14
+ )
15
+ from textattack.constraints.semantics import WordEmbeddingDistance
16
+ from textattack.constraints.semantics.sentence_encoders import BERT
17
+ from textattack.goal_functions import UntargetedClassification
18
+ from textattack.search_methods import GreedyWordSwapWIR
19
+ from textattack.transformations import WordSwapEmbedding, WordSwapMaskedLM
20
+
21
+ from .attack_recipe import AttackRecipe
22
+
23
+
24
class A2TYoo2021(AttackRecipe):
    """Towards Improving Adversarial Training of NLP Models.

    (Yoo et al., 2021)

    https://arxiv.org/abs/2109.00544
    """

    @staticmethod
    def build(model_wrapper, mlm=False):
        """Build attack recipe.

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Model wrapper containing both the model and the tokenizer.
            mlm (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`True`, load `A2T-MLM` attack. Otherwise, load regular `A2T` attack.

        Returns:
            :class:`~textattack.Attack`: A2T attack.
        """
        # Constraints shared by both the embedding and the masked-LM variants.
        constraints = [RepeatModification(), StopwordModification()]
        input_column_modification = InputColumnModification(
            ["premise", "hypothesis"], {"premise"}
        )
        constraints.append(input_column_modification)
        constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
        constraints.append(MaxModificationRate(max_rate=0.1, min_threshold=4))
        sent_encoder = BERT(
            model_name="stsb-distilbert-base", threshold=0.9, metric="cosine"
        )
        constraints.append(sent_encoder)

        if mlm:
            transformation = WordSwapMaskedLM(
                method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
            )
        else:
            transformation = WordSwapEmbedding(max_candidates=20)
            # The embedding-distance constraint is only added for the embedding swap.
            constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))

        #
        # Goal is untargeted classification
        #
        goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
        #
        # Greedily swap words with "Word Importance Ranking".
        #
        search_method = GreedyWordSwapWIR(wir_method="gradient")

        return Attack(goal_function, constraints, transformation, search_method)
textattack/attack_recipes/attack_recipe.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Attack Recipe Class
3
+ ========================
4
+
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+
9
+ from textattack import Attack
10
+
11
+
12
class AttackRecipe(Attack, ABC):
    """Abstract base for recipes that build an NLP adversarial attack from the literature."""

    @staticmethod
    @abstractmethod
    def build(model_wrapper, **kwargs):
        """Return a pre-built :class:`~textattack.Attack` matching a published attack.

        Subclasses must override this method; the base implementation only raises.

        Args:
            model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
                Wrapper holding the victim model and its tokenizer. It is handed to
                the :class:`~textattack.goal_functions.GoalFunction` when the attack
                is constructed.
            kwargs:
                Additional keyword arguments.
        Returns:
            :class:`~textattack.Attack`
        """
        raise NotImplementedError
textattack/attack_recipes/bae_garg_2019.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BAE (BAE: BERT-Based Adversarial Examples)
3
+ ============================================
4
+
5
+ """
6
+ from textattack.constraints.grammaticality import PartOfSpeech
7
+ from textattack.constraints.pre_transformation import (
8
+ RepeatModification,
9
+ StopwordModification,
10
+ )
11
+ from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
12
+ from textattack.goal_functions import UntargetedClassification
13
+ from textattack.search_methods import GreedyWordSwapWIR
14
+ from textattack.transformations import WordSwapMaskedLM
15
+
16
+ from .attack_recipe import AttackRecipe
17
+
18
+
19
class BAEGarg2019(AttackRecipe):
    """BAE-R word-replacement attack (Siddhant Garg and Goutham Ramakrishnan, 2019).

    BAE: BERT-based Adversarial Examples for Text Classification.

    https://arxiv.org/pdf/2004.01970

    The paper presents four attack modes built from a Replace (R) and an
    Insert (I) operation, applied to each token t of the input S:

        • BAE-R: Replace token t (See Algorithm 1)
        • BAE-I: Insert a token to the left or right of t
        • BAE-R/I: Either replace token t or insert a
          token to the left or right of t
        • BAE-R+I: First replace token t, then insert a
          token to the left or right of t

    This recipe is "attack mode" 1, BAE-R (word replacement).
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the BAE-R attack against ``model_wrapper``.

        Args:
            model_wrapper: Wrapper around the victim model; it is passed to
                the untargeted-classification goal function.

        Returns:
            BAEGarg2019: The attack wired with its goal function,
            constraints, transformation and search method.
        """
        # "In this paper, we present a simple yet novel technique: BAE (BERT-based
        # Adversarial Examples), which uses a language model (LM) for token
        # replacement to best fit the overall context. We perturb an input sentence
        # by either replacing a token or inserting a new token in the sentence, by
        # means of masking a part of the input and using a LM to fill in the mask."
        #
        # Only the top K=50 candidates from the MLM predictions are considered.
        #
        # [from email correspondence with the author]
        # "When choosing the top-K candidates from the BERT masked LM, we filter out
        # the sub-words and only retain the whole words (by checking if they are
        # present in the GloVE vocabulary)"
        masked_lm_swap = WordSwapMaskedLM(
            method="bae", max_candidates=50, min_confidence=0.0
        )

        restrictions = [
            # Never modify the same word twice, and never touch stopwords.
            RepeatModification(),
            StopwordModification(),
            # "For the R operations we add an additional check for grammatical
            # correctness of the generated adversarial example by filtering out
            # predicted tokens that do not form the same part of speech (POS) as
            # the original token t_i in the sentence."
            PartOfSpeech(allow_verb_noun_swap=True),
            # "To ensure semantic similarity on introducing perturbations in the
            # input text, we filter the set of top-K masked tokens (K is a
            # pre-defined constant) predicted by BERT-MLM using a Universal
            # Sentence Encoder (USE) (Cer et al., 2018)-based sentence similarity
            # scorer."
            #
            # "[We] set a threshold of 0.8 for the cosine similarity between
            # USE-based embeddings of the adversarial and input text."
            #
            # [from email correspondence with the author]
            # "For a fair comparison of the benefits of using a BERT-MLM in our
            # paper, we retained the majority of TextFooler's specifications.
            # Thus we:
            # 1. Use the USE for comparison within a window of size 15 around the
            #    word being replaced/inserted.
            # 2. Set the similarity score threshold to 0.1 for inputs shorter than
            #    the window size (this translates roughly to almost always
            #    accepting the new text).
            # 3. Perform the USE similarity thresholding of 0.8 with respect to
            #    the text just before the replacement/insertion and not the
            #    original text (For example: at the 3rd R/I operation, we compute
            #    the USE score on a window of size 15 of the text obtained after
            #    the first 2 R/I operations and not the original text).
            # ...
            # To address point (3) from above, compare the USE with the original
            # text at each iteration instead of the current one (While doing this
            # change for the R-operation is trivial, doing it for the I-operation
            # with the window based USE comparison might be more involved)."
            #
            # Because the BAE code is based on the TextFooler code, the threshold
            # is adjusted to account for the missing / pi in the cosine similarity
            # comparison: 1 - (1 - 0.8) / pi = 1 - (0.2 / pi) = 0.936338023.
            #
            # NOTE(review): skip_text_shorter_than_window=True presumably stands in
            # for point 2 (near-always accepting short inputs) - confirm against
            # the constraint's implementation.
            UniversalSentenceEncoder(
                threshold=0.936338023,
                metric="cosine",
                compare_against_original=True,
                window_size=15,
                skip_text_shorter_than_window=True,
            ),
        ]

        # The goal is untargeted classification.
        objective = UntargetedClassification(model_wrapper)

        # "We estimate the token importance Ii of each token
        # t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the
        # decrease in probability of predicting the correct label y, similar
        # to (Jin et al., 2019).
        #
        # • "If there are multiple tokens can cause C to misclassify S when they
        #   replace the mask, we choose the token which makes Sadv most similar to
        #   the original S based on the USE score."
        # • "If no token causes misclassification, we choose the perturbation that
        #   decreases the prediction probability P(C(Sadv)=y) the most."
        importance_ranked_search = GreedyWordSwapWIR(wir_method="delete")

        return BAEGarg2019(
            objective, restrictions, masked_lm_swap, importance_ranked_search
        )
textattack/attack_recipes/bert_attack_li_2020.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BERT-Attack:
3
+ ============================================================
4
+
5
+ (BERT-Attack: Adversarial Attack Against BERT Using BERT)
6
+
7
+ .. warning::
8
+ This attack is super slow
9
+ (see https://github.com/QData/TextAttack/issues/586)
10
+ Consider using smaller values for "max_candidates".
11
+
12
+ """
13
+ from textattack import Attack
14
+ from textattack.constraints.overlap import MaxWordsPerturbed
15
+ from textattack.constraints.pre_transformation import (
16
+ RepeatModification,
17
+ StopwordModification,
18
+ )
19
+ from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
20
+ from textattack.goal_functions import UntargetedClassification
21
+ from textattack.search_methods import GreedyWordSwapWIR
22
+ from textattack.transformations import WordSwapMaskedLM
23
+
24
+ from .attack_recipe import AttackRecipe
25
+
26
+
27
class BERTAttackLi2020(AttackRecipe):
    """Li, L., Ma, R., Guo, Q., Xiangyang, X., Xipeng, Q. (2020).

    BERT-ATTACK: Adversarial Attack Against BERT Using BERT

    https://arxiv.org/abs/2004.09984

    Word-replacement attack whose candidate substitutes come from a masked
    language model (``WordSwapMaskedLM`` with ``method="bert-attack"``).
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the BERT-Attack recipe against ``model_wrapper``.

        Args:
            model_wrapper: Wrapper around the victim model; it is passed to
                the untargeted-classification goal function.

        Returns:
            Attack: The attack wired with its goal function, constraints,
            transformation and search method.
        """
        # [from correspondence with the author]
        # Candidate size K is set to 48 for all data-sets.
        masked_lm_swap = WordSwapMaskedLM(method="bert-attack", max_candidates=48)

        restrictions = [
            # Never modify the same word twice, and never touch stopwords.
            RepeatModification(),
            StopwordModification(),
            # "We only take ε percent of the most important words since we tend to
            # keep perturbations minimum."
            #
            # [from correspondence with the author]
            # "Word percentage allowed to change is set to 0.4 for most data-sets,
            # this parameter is trivial since most attacks only need a few changes.
            # This epsilon is only used to avoid too much queries on those very
            # hard samples."
            MaxWordsPerturbed(max_percent=0.4),
            # "As used in TextFooler (Jin et al., 2019), we also use Universal
            # Sentence Encoder (Cer et al., 2018) to measure the semantic
            # consistency between the adversarial sample and the original
            # sequence. To balance between semantic preservation and attack
            # success rate, we set up a threshold of semantic similarity score to
            # filter the less similar examples."
            #
            # [from correspondence with the author]
            # "Over the full texts, after generating all the adversarial samples,
            # we filter out low USE score samples. Thus the success rate is lower
            # but the USE score can be higher. (actually USE score is not a golden
            # metric, so we simply measure the USE score over the final texts for
            # a comparison with TextFooler). For datasets like IMDB, we set a
            # higher threshold between 0.4-0.7; for datasets like MNLI, we set
            # threshold between 0-0.2."
            #
            # A real-world threshold cannot be tuned on the training data, so the
            # TextAttack implementation uses one fixed value - 0.2 was judged the
            # fairest choice.
            UniversalSentenceEncoder(
                threshold=0.2,
                metric="cosine",
                compare_against_original=True,
                window_size=None,
            ),
        ]

        # The goal is untargeted classification.
        objective = UntargetedClassification(model_wrapper)

        # "We first select the words in the sequence which have a high significance
        # influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
        # the input sentence, and oy(S) denote the logit output by the target model
        # for correct label y, the importance score Iwi is defined as
        # Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
        # is the sentence after replacing wi with [MASK]. Then we rank all the words
        # according to the ranking score Iwi in descending order to create word list
        # L."
        #
        # NOTE(review): the paper masks with [MASK] while this passes
        # wir_method="unk" - presumably the closest available ranking; confirm.
        importance_ranked_search = GreedyWordSwapWIR(wir_method="unk")

        return Attack(
            objective, restrictions, masked_lm_swap, importance_ranked_search
        )
textattack/attack_recipes/checklist_ribeiro_2020.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CheckList:
3
+ =========================
4
+
5
+ (Beyond Accuracy: Behavioral Testing of NLP models with CheckList)
6
+
7
+ """
8
+ from textattack import Attack
9
+ from textattack.constraints.pre_transformation import RepeatModification
10
+ from textattack.goal_functions import UntargetedClassification
11
+ from textattack.search_methods import GreedySearch
12
+ from textattack.transformations import (
13
+ CompositeTransformation,
14
+ WordSwapChangeLocation,
15
+ WordSwapChangeName,
16
+ WordSwapChangeNumber,
17
+ WordSwapContract,
18
+ WordSwapExtend,
19
+ )
20
+
21
+ from .attack_recipe import AttackRecipe
22
+
23
+
24
class CheckList2020(AttackRecipe):
    """Invariance-testing attack from "Beyond Accuracy: Behavioral Testing of
    NLP models with CheckList", Ribeiro et al., 2020.

    The recipe combines several of the perturbations used in the paper's
    Invariance Testing Method: Contraction, Extension, Changing Names,
    Number, Location.

    https://arxiv.org/abs/2005.04118
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the CheckList attack against ``model_wrapper``.

        Args:
            model_wrapper: Wrapper around the victim model; it is passed to
                the untargeted-classification goal function.

        Returns:
            Attack: The attack wired with its goal function, constraints,
            transformation and search method.
        """
        # The individual perturbations, composed into one transformation.
        invariance_swaps = [
            WordSwapExtend(),
            WordSwapContract(),
            WordSwapChangeName(),
            WordSwapChangeNumber(),
            WordSwapChangeLocation(),
        ]
        combined_swaps = CompositeTransformation(invariance_swaps)

        # Without this constraint, extend and contract would keep modifying each
        # other's changes and form an infinite loop.
        restrictions = [RepeatModification()]

        # Untargeted attack driven by a plain greedy search.
        objective = UntargetedClassification(model_wrapper)
        greedy = GreedySearch()

        return Attack(objective, restrictions, combined_swaps, greedy)