# test_skim / blackbox_utils / Attack_base.py
# Hugging Face file-viewer header (not part of the program):
#   uploaded by adamtayzzz - "Update blackbox_utils/Attack_base.py"
#   commit 185186c (verified) - no virus - 8.77 kB
import torch
import torch.nn as nn
# import jieba
import string
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import time
from datetime import datetime
import os
from sklearn.linear_model import LinearRegression
from torch.multiprocessing import Process,Pool
from transformers import BertTokenizer
os.environ['TOKENIZERS_PARALLELISM']='True'
# torch.autograd.set_detect_anomaly(True)
class BaseAttack:
    """Common plumbing for black-box attacks on a token-skimming classifier.

    The class wraps a model/tokenizer pair and offers helpers to tokenize one
    example, run a timed forward pass, summarise the skim statistics exposed
    on the model output, and fit a linear model on pre-recorded
    timing/ratio measurements stored under ``flops_count/``.

    Subclasses are expected to override :meth:`run_attack` and
    :meth:`compute_loss`.
    """

    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
        """Store the configuration and put the model on ``device`` in eval mode.

        Args:
            name: Identifier of the attacked model; also used by
                :meth:`linear_regression` as the ``flops_count/`` sub-directory.
            model: Module exposing ``.to()``, ``.eval()`` and ``__call__(**batch)``.
            tokenizer: Callable tokenizer applied in :meth:`preprocess_function`.
            device: Device the model and (optionally) the inputs are moved to.
            max_per: Maximum number of tokens the attack may modify.
            padding: ``padding`` argument forwarded to the tokenizer.
            max_length: ``max_length`` argument forwarded to the tokenizer.
            label_to_id: Label mapping (stored only; unused in this class).
            sentence1_key: Name of the first text field (stored only).
            sentence2_key: Name of the second text field, or ``None`` for
                single-sentence tasks.
        """
        self.name = name
        self.tokenizer = tokenizer
        self.device = device
        self.model = model.to(self.device)
        self.model.eval()
        self.padding = padding
        self.max_length = max_length
        self.label_to_id = label_to_id
        self.sentence1_key = sentence1_key
        self.sentence2_key = sentence2_key
        # Maximum number of tokens that may be modified.
        self.max_per = max_per
        # The regression fit is intentionally not run at construction time;
        # call ``linear_regression()`` explicitly when ``w``/``b`` are needed.
        self.random_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def run_attack(self, x):
        """Run the attack on ``x``; no-op placeholder for subclasses."""
        pass

    def compute_loss(self, x):
        """Score a candidate ``x``; no-op placeholder for subclasses."""
        pass

    def preprocess_function(self, examples, to_device=True):
        """Tokenize one example and return a batch of size one.

        Args:
            examples: Indexable holding the first text at ``[0]`` and, when
                ``sentence2_key`` is not ``None``, the second text at ``[1]``.
            to_device: Move the resulting tensors to ``self.device`` if true.

        Returns:
            dict mapping every tokenizer output key to a tensor with a leading
            batch dimension of 1.
        """
        if self.sentence2_key is None:
            texts = (examples[0],)
        else:
            texts = (examples[0], examples[1])
        encoded = self.tokenizer(*texts, padding=self.padding, max_length=self.max_length, truncation=True)

        batch = {}
        for key, values in encoded.items():
            tensor = torch.tensor(values).unsqueeze(0)
            batch[key] = tensor.to(self.device) if to_device else tensor
        return batch

    def get_pred(self, input_):
        """Return the arg-max class index predicted for ``input_``.

        ``get_prob`` returns ``(outputs, running_time)``, so the tuple has to
        be unpacked before ``.logits`` can be read.
        """
        outputs, _ = self.get_prob(input_)
        return outputs.logits.argmax(dim=-1)

    def get_prob(self, input_):
        """Run a forward pass and measure its wall-clock duration.

        The measured interval covers both tokenization and the model call.

        Returns:
            tuple ``(outputs, running_time)`` where ``running_time`` is in
            seconds.
        """
        started = datetime.now()
        batch = self.preprocess_function(input_)
        outputs = self.model(**batch)
        finished = datetime.now()
        running_time = (finished - started).total_seconds()
        return outputs, running_time

    def output_analysis(self, outputs):
        """Summarise the skim statistics carried by a model output.

        Args:
            outputs: Object exposing ``skim_loss``, ``tokens_remained`` and an
                iterable ``layer_tokens_remained``.

        Returns:
            tuple ``(skim_loss, tokens_remained, layers_result)``: the mean of
            each of the first two tensors and a list with the mean of every
            per-layer entry.
        """
        # A single output is analysed, so each "average" is just the mean over
        # the elements of the corresponding tensor.
        skim_loss = torch.mean(outputs.skim_loss)
        tokens_remained = torch.mean(outputs.tokens_remained)
        layers_result = [torch.mean(layer) for layer in outputs.layer_tokens_remained]
        return skim_loss, tokens_remained, layers_result

    def load_data(self, model_path_key, mode='train'):
        """Load (or build and cache) the regression data for one split.

        Data lives in ``flops_count/{model_path_key}/{mode}``. If
        ``process_data.pth`` exists it is returned as is; otherwise it is
        derived from ``time_list.pth``, ``ratio_list.pth`` and
        ``text_len_list_tokenizer.pth`` and then saved.

        Returns:
            dict with ``'x'`` (time divided by ``token_num * 10**8``) and
            ``'y'`` (ratio values), each expanded to a column array.
        """
        path = f'flops_count/{model_path_key}/{mode}'
        cache_file = f'{path}/process_data.pth'
        # NOTE(review): torch.load unpickles its input; only point this at
        # trusted, locally produced files.
        if os.path.exists(cache_file):
            print(f'loading data from {path}')
            data = torch.load(cache_file)
        else:
            time_list = torch.load(f'{path}/time_list.pth')
            ratio_list = torch.load(f'{path}/ratio_list.pth')
            token_num_list = torch.load(f'{path}/text_len_list_tokenizer.pth')

            ratios = [ratio.item() for ratio in ratio_list]
            y = np.expand_dims(np.array(ratios), axis=1)

            # ``elapsed`` (not ``time``) so the imported ``time`` module is
            # not shadowed inside this method.
            per_token_times = [
                elapsed / (token_num * (10 ** 8))
                for elapsed, token_num in zip(time_list, token_num_list)
            ]
            x = np.expand_dims(np.array(per_token_times), axis=1)

            data = {'x': x, 'y': y}
            torch.save(data, cache_file)
        return data

    def predict(self, x):
        """Evaluate the fitted line ``w * x + b``.

        Requires :meth:`linear_regression` to have set ``self.w``/``self.b``.
        """
        return self.w * x + self.b

    def linear_regression(self):
        """Fit ``y ~ w * x + b`` on the train split and report both scores.

        Sets ``self.w`` and ``self.b`` from the fitted coefficients and prints
        the train/test scores, the parameters and ``predict(0.8)``.
        """
        print("=" * 20)
        print('Linear Regression Generation')
        data_train = self.load_data(self.name, mode='train')
        data_test = self.load_data(self.name, mode='test')
        reg = LinearRegression().fit(data_train['x'], data_train['y'])
        train_score = reg.score(data_train['x'], data_train['y'])
        test_score = reg.score(data_test['x'], data_test['y'])
        print(f'train set score: {train_score}')
        print(f'test set score: {test_score}')
        self.w = reg.coef_[0][0]
        self.b = reg.intercept_[0]
        print("w:", self.w)
        print("b:", self.b)
        print(self.predict(0.8))
class MyAttack(BaseAttack):
    """Greedy black-box attack skeleton.

    Each round asks :meth:`mutation` for candidate texts, scores them with
    :meth:`compute_loss` and continues from the highest-scoring candidate.
    Concrete attacks must override :meth:`compute_loss` and :meth:`mutation`.
    """

    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
        """Forward the configuration to ``BaseAttack`` and reset the statistics."""
        super().__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)
        # Characters available for insertion: digits followed by ASCII letters.
        self.insert_character = string.digits + string.ascii_letters
        # One entry is appended to each list per ``run_attack`` call.
        self.origin_ratio = []
        self.attack_ratio = []
        self.layer_result = []
        self.origin_layer_result = []

    @torch.no_grad()
    def select_best(self, new_strings):
        """Return the candidate with the highest loss and that loss.

        Candidates are scored by ``self.compute_loss`` in a pool of four
        worker processes.

        Args:
            new_strings: Non-empty sequence of candidate texts.

        Returns:
            tuple ``(best_string, best_loss)``.

        Raises:
            ValueError: If ``new_strings`` is empty.
        """
        if len(new_strings) == 0:
            # Fail before spawning workers; the original surfaced this as an
            # opaque ValueError from ``np.argmax`` on an empty array.
            raise ValueError('select_best requires at least one candidate')

        # NOTE(review): mapping a bound method sends ``self`` to the workers;
        # confirm this behaves as intended when the model lives on a GPU.
        with Pool(processes=4) as pool:
            loss_list = pool.map(self.compute_loss, new_strings)

        idx = np.argmax(np.array(loss_list))
        return new_strings[idx], loss_list[idx]

    def compute_loss(self, xxx):
        """Score one candidate text; must be implemented by subclasses."""
        raise NotImplementedError

    def mutation(self, current_adv_text, grad=None, modify_pos=None):
        """Produce candidate texts from ``current_adv_text``.

        ``grad`` and ``modify_pos`` are optional because :meth:`run_attack`
        calls this hook with the current text only. Must be implemented by
        subclasses.
        """
        raise NotImplementedError

    def run_attack(self, text):
        """Search for a perturbed version of ``text`` that maximises the loss.

        The first (and, when present, second) entry of ``text`` is stripped
        of surrounding spaces and periods in place before the search starts.

        Args:
            text: Mutable sequence holding one or two input strings.

        Returns:
            tuple ``(best_adv_text, best_loss, best_tokens_remained,
            best_layer_result)``. If no candidate reaches a loss above zero,
            the stripped original text is returned with ``best_loss == 0``.
        """
        text[0] = text[0].strip(" .")
        # Single-sentence inputs have no second entry to clean.
        if len(text) > 1:
            text[1] = text[1].strip(" .")
        print(f'Origin Text: {text}')
        current_adv_text = deepcopy(text)

        # Start from the unmodified text so the result is always defined,
        # even when ``max_per`` is 0 or no candidate beats the initial loss.
        best_adv_text = deepcopy(text)
        best_loss = 0

        output, _ = self.get_prob(current_adv_text)
        origin_skim_loss, origin_ratio_, origin_layer_result_ = self.output_analysis(output)
        print(origin_skim_loss, origin_ratio_)
        self.origin_ratio.append(origin_ratio_.item())
        self.origin_layer_result.append(origin_layer_result_)

        # ``max_per`` bounds the number of perturbation rounds.
        for _ in range(self.max_per):
            # Generate the candidates for this round of modification.
            new_strings = self.mutation(current_adv_text)
            # The search always continues from this round's best candidate,
            # whether or not it improves on the overall best.
            current_adv_text, current_loss = self.select_best(new_strings)
            if current_loss > best_loss:
                best_adv_text = deepcopy(current_adv_text)
                best_loss = current_loss
                print(best_adv_text)

        output, _ = self.get_prob(best_adv_text)
        _, best_tokens_remained, best_layer_result = self.output_analysis(output)
        self.attack_ratio.append(best_tokens_remained.item())
        self.layer_result.append(best_layer_result)
        print(f'Malicious Text: {best_adv_text}')
        print(f'Origin Ratio: {self.origin_ratio[-1]} Attack Ratio: {self.attack_ratio[-1]}')
        print(f'Layer Result: {self.layer_result[-1]}')
        return best_adv_text, best_loss, best_tokens_remained, best_layer_result