import copy
import time
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import nltk
import string
from copy import deepcopy
from torchprofile import profile_macs
from datetime import datetime
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer
from blackbox_utils.Attack_base import MyAttack
class CharacterAttack(MyAttack):
    # TODO: keep a list so that each call only mutates a different token position
    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
        super(CharacterAttack, self).__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)
    def compute_importance(self, text):
        # Rank token ids by how much the loss changes when each token is removed.
        current_tensor = self.preprocess_function(text)["input_ids"][0]
        word_losses = {}
        for idx in range(1, len(current_tensor) - 1):
            sentence_tokens_without = torch.cat([current_tensor[:idx], current_tensor[idx + 1:]])
            sentence_without = self.tokenizer.decode(sentence_tokens_without)
            sentence_without = [sentence_without, text[1]]
            word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
        # Token ids sorted by descending loss, i.e. most important first.
        word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
        return word_losses
    def compute_loss(self, text):
        # Profile the model's MACs on this input, then normalise by token length.
        inputs = self.preprocess_function(text)
        shift_inputs = (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])
        macs = profile_macs(self.model, shift_inputs)
        result = self.random_tokenizer(*inputs, padding=self.padding, max_length=self.max_length, truncation=True)
        token_length = len(result["input_ids"])
        macs_per_token = macs / (token_length * 10 ** 8)
        return self.predict(macs_per_token)
    def mutation(self, current_adv_text):
        current_tensor = self.preprocess_function(current_adv_text)
        current_tensor = current_tensor["input_ids"][0]
        new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
        return new_strings
    def transfer(self, c: str) -> str:
        # Flip the case of a single ASCII letter; leave other characters unchanged.
        if c in string.ascii_lowercase:
            return c.upper()
        elif c in string.ascii_uppercase:
            return c.lower()
        return c
    def character_replace_mutation(self, current_text, current_tensor):
        important_tensor = self.compute_importance(current_text)
        new_strings = [current_text]
        # Walk the vocabulary ids in importance order and mutate the first token
        # that actually appears in the current text.
        for t in important_tensor:
            if int(t) not in current_tensor:
                continue
            ori_decode_token = self.tokenizer.decode([int(t)])
            ori_token = ori_decode_token
            # Skip single-character tokens and tokens not present in the raw sentence.
            if len(ori_token) == 1 or ori_token not in current_text[0]:  # TODO
                continue
            # Insert an arbitrary character at every position of the token.
            candidate = [ori_token[:i] + insert + ori_token[i:] for i in range(len(ori_token)) for insert in self.insert_character]
            # Flip the case of one character at a time.
            candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:] for i in range(1, len(ori_token))]
            # Replace at most one occurrence of the original token with each candidate.
            new_strings += [[current_text[0].replace(ori_token, c, 1), current_text[1]] for c in candidate]
            # Return as soon as one token yields valid mutations.
            if len(new_strings) > 1:
                return new_strings
        return new_strings
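

# Minimal usage sketch (not part of the original file). The checkpoint names and
# hyperparameter values below are illustrative assumptions, and the base class
# `MyAttack` is assumed to provide `preprocess_function`, `predict`,
# `random_tokenizer`, and `insert_character` as used above.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")          # assumed checkpoint
    model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)  # assumed victim model
    attack = CharacterAttack(
        name="character",           # assumed attack name
        model=model,
        tokenizer=tokenizer,
        device=device,
        max_per=5,                  # assumed perturbation budget
        padding="max_length",
        max_length=128,
        label_to_id=None,           # assumed: task without a label mapping
        sentence1_key="sentence1",
        sentence2_key="sentence2",
    )
    # Inputs follow the [sentence1, sentence2] convention used by the methods above.
    candidates = attack.mutation(["An example sentence to perturb.", None])
    print(len(candidates), "candidate mutations generated")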