# test_skim/blackbox_utils/my_attack.py
import copy
import time
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import nltk
import string
from copy import deepcopy
from torchprofile import profile_macs
from datetime import datetime
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer
from blackbox_utils.Attack_base import MyAttack
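
# CharacterAttack builds character-level adversarial candidates (inserting a character or
# flipping the case of a letter inside an important token) and scores them by the model's
# estimated MACs per token, so the search favours inputs that are computationally more
# expensive for the victim model; this appears to target efficiency rather than accuracy.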
class CharacterAttack(MyAttack):
    # TODO: keep a list so that each round mutates a different token position
    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
        super(CharacterAttack, self).__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)
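
    # compute_importance performs a leave-one-out scan over the tokenised first sentence:
    # each non-special token is removed in turn, the remaining text is re-scored with
    # compute_loss, and the token ids are returned sorted so that the tokens whose removal
    # changes the score the most come first.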
    def compute_importance(self, text):
        current_tensor = self.preprocess_function(text)["input_ids"][0]
        word_losses = {}
        # Skip the special tokens at both ends ([CLS] / [SEP]).
        for idx in range(1, len(current_tensor) - 1):
            # Score the text with this one token removed.
            sentence_tokens_without = torch.cat([current_tensor[:idx], current_tensor[idx + 1:]])
            sentence_without = self.tokenizer.decode(sentence_tokens_without)
            sentence_without = [sentence_without, text[1]]
            word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
        # Token ids sorted by descending loss: most important first.
        word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
        return word_losses
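
    # compute_loss estimates the victim model's cost on a candidate text: torchprofile's
    # profile_macs counts multiply-accumulate operations for one forward pass, the count is
    # normalised by token length and a constant 1e8, and the scaled value is passed to
    # self.predict (defined in the MyAttack base class).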
    def compute_loss(self, text):
        inputs = self.preprocess_function(text)
        shift_inputs = (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])
        # Count multiply-accumulate operations for one forward pass on this input.
        macs = profile_macs(self.model, shift_inputs)
        result = self.random_tokenizer(*inputs, padding=self.padding, max_length=self.max_length, truncation=True)
        token_length = len(result["input_ids"])
        # Normalise by token count (scaled by 1e8) so longer inputs do not trivially score higher.
        macs_per_token = macs / (token_length * 10**8)
        return self.predict(macs_per_token)
    def mutation(self, current_adv_text):
        # Generate candidate adversarial texts from the current one by character-level edits.
        current_tensor = self.preprocess_function(current_adv_text)["input_ids"][0]
        new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
        return new_strings
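
    # mutation (and character_replace_mutation below) return a list of [sentence1, sentence2]
    # candidates; only the first important token that actually appears in the current text is
    # perturbed per call, so the outer search in MyAttack presumably calls mutation repeatedly.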
    @staticmethod
    def transfer(c: str):
        if c in string.ascii_lowercase:
            return c.upper()
        elif c in string.ascii_uppercase:
            return c.lower()
        return c
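    # For example, transfer('a') -> 'A', transfer('A') -> 'a', and transfer('3') -> '3'.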
    def character_replace_mutation(self, current_text, current_tensor):
        important_tensor = self.compute_importance(current_text)
        new_strings = [current_text]
        # Walk over token ids in importance order and mutate the first one found in the text.
        for t in important_tensor:
            if int(t) not in current_tensor:
                continue
            ori_decode_token = self.tokenizer.decode([int(t)])
            ori_token = ori_decode_token
            # Skip single-character tokens and tokens that do not appear verbatim in the first sentence.
            if len(ori_token) == 1 or ori_token not in current_text[0]:  # TODO
                continue
            # Insert one character at an arbitrary position inside the token.
            candidate = [ori_token[:i] + insert + ori_token[i:] for i in range(len(ori_token)) for insert in self.insert_character]
            # Flip the case of one character inside the token.
            candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:] for i in range(1, len(ori_token))]
            # Replace at most one occurrence of the token in the first sentence.
            new_strings += [[current_text[0].replace(ori_token, c, 1), current_text[1]] for c in candidate]
            # Return as soon as one token has produced valid mutations.
            if len(new_strings) > 1:
                return new_strings
        return new_strings
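

# A minimal, self-contained sketch of the two character-level edits used by
# character_replace_mutation (character insertion and case flipping). The token
# "movie" and the tilde are illustrative stand-ins; the real attack draws insert
# characters from self.insert_character, which is defined in the MyAttack base class.
def _demo_character_mutations(token: str = "movie", insert_characters: str = "~"):
    inserted = [token[:i] + c + token[i:] for i in range(len(token)) for c in insert_characters]
    flipped = [token[:i - 1] + CharacterAttack.transfer(token[i - 1]) + token[i:] for i in range(1, len(token))]
    return inserted + flipped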