""" | |
Managing Attack Logs. | |
======================== | |
""" | |
from typing import Dict, Optional | |
from textattack.metrics.attack_metrics import ( | |
AttackQueries, | |
AttackSuccessRate, | |
WordsPerturbed, | |
) | |
from textattack.metrics.quality_metrics import Perplexity, USEMetric | |
from . import ( | |
CSVLogger, | |
FileLogger, | |
JsonSummaryLogger, | |
VisdomLogger, | |
WeightsAndBiasesLogger, | |
) | |


class AttackLogManager:
    """Logs the results of an attack to all attached loggers."""

    # metrics maps strings (metric names) to textattack.metrics.Metric objects
    metrics: Dict

    def __init__(self, metrics: Optional[Dict]):
        self.loggers = []
        self.results = []
        self.enable_advance_metrics = False
        if metrics is None:
            self.metrics = {}
        else:
            self.metrics = metrics
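        # Note: values in ``self.metrics`` are expected to expose a
        # ``calculate(results)`` method (e.g. a ``textattack.metrics.Metric``
        # object); ``log_summary`` calls it on every entry and adds one row
        # per metric to the summary table.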

    def enable_stdout(self):
        """Attaches a ``FileLogger`` that writes to standard output."""
        self.loggers.append(FileLogger(stdout=True))

    def enable_visdom(self):
        """Attaches a ``VisdomLogger`` for logging to a Visdom server."""
        self.loggers.append(VisdomLogger())

    def enable_wandb(self, **kwargs):
        """Attaches a ``WeightsAndBiasesLogger``; ``kwargs`` are passed through to it."""
        self.loggers.append(WeightsAndBiasesLogger(**kwargs))

    def disable_color(self):
        """Attaches a stdout ``FileLogger`` that uses plain (file-style) coloring."""
        self.loggers.append(FileLogger(stdout=True, color_method="file"))

    def add_output_file(self, filename, color_method):
        """Attaches a ``FileLogger`` that writes to ``filename``."""
        self.loggers.append(FileLogger(filename=filename, color_method=color_method))

    def add_output_csv(self, filename, color_method):
        """Attaches a ``CSVLogger`` that writes results to ``filename``."""
        self.loggers.append(CSVLogger(filename=filename, color_method=color_method))

    def add_output_summary_json(self, filename):
        """Attaches a ``JsonSummaryLogger`` that writes summary metrics to ``filename``."""
        self.loggers.append(JsonSummaryLogger(filename=filename))

    def log_result(self, result):
        """Logs an ``AttackResult`` on each of ``self.loggers``."""
        self.results.append(result)
        for logger in self.loggers:
            logger.log_attack_result(result)

    def log_results(self, results):
        """Logs an iterable of ``AttackResult`` objects on each of
        ``self.loggers``, then logs a final summary."""
        for result in results:
            self.log_result(result)
        self.log_summary()

    def log_summary_rows(self, rows, title, window_id):
        """Logs a table of summary rows on each of ``self.loggers``."""
        for logger in self.loggers:
            logger.log_summary_rows(rows, title, window_id)

    def log_sep(self):
        """Logs a separator on each of ``self.loggers``."""
        for logger in self.loggers:
            logger.log_sep()

    def flush(self):
        """Flushes each of ``self.loggers``."""
        for logger in self.loggers:
            logger.flush()

    def log_attack_details(self, attack_name, model_name):
        """Logs the attack algorithm and model name as a summary table."""
        # @TODO log a more complete set of attack details
        attack_detail_rows = [
            ["Attack algorithm:", attack_name],
            ["Model:", model_name],
        ]
        self.log_summary_rows(attack_detail_rows, "Attack Details", "attack_details")

    def log_summary(self):
        """Calculates aggregate metrics over ``self.results``, logs them as a
        summary table, and logs a histogram of the number of words perturbed."""
        total_attacks = len(self.results)
        if total_attacks == 0:
            return

        # Default metrics - calculated on every attack.
        attack_success_stats = AttackSuccessRate().calculate(self.results)
        words_perturbed_stats = WordsPerturbed().calculate(self.results)
        attack_query_stats = AttackQueries().calculate(self.results)

        # @TODO generate this table based on user input - each column in specific class
        # Example to demonstrate:
        # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
        summary_table_rows = [
            [
                "Number of successful attacks:",
                attack_success_stats["successful_attacks"],
            ],
            ["Number of failed attacks:", attack_success_stats["failed_attacks"]],
            ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
            [
                "Original accuracy:",
                str(attack_success_stats["original_accuracy"]) + "%",
            ],
            [
                "Accuracy under attack:",
                str(attack_success_stats["attack_accuracy_perc"]) + "%",
            ],
            [
                "Attack success rate:",
                str(attack_success_stats["attack_success_rate"]) + "%",
            ],
            [
                "Average perturbed word %:",
                str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%",
            ],
            [
                "Average num. words per input:",
                words_perturbed_stats["avg_word_perturbed"],
            ],
        ]

        summary_table_rows.append(
            ["Avg num queries:", attack_query_stats["avg_num_queries"]]
        )

        # User-supplied metrics, keyed by display name.
        for metric_name, metric in self.metrics.items():
            summary_table_rows.append([metric_name, metric.calculate(self.results)])

        if self.enable_advance_metrics:
            perplexity_stats = Perplexity().calculate(self.results)
            use_stats = USEMetric().calculate(self.results)

            summary_table_rows.append(
                [
                    "Average Original Perplexity:",
                    perplexity_stats["avg_original_perplexity"],
                ]
            )

            summary_table_rows.append(
                [
                    "Average Attack Perplexity:",
                    perplexity_stats["avg_attack_perplexity"],
                ]
            )

            summary_table_rows.append(
                ["Average Attack USE Score:", use_stats["avg_attack_use_score"]]
            )

        self.log_summary_rows(
            summary_table_rows, "Attack Results", "attack_results_summary"
        )

        # Show histogram of words changed.
        numbins = max(words_perturbed_stats["max_words_changed"], 10)
        for logger in self.loggers:
            logger.log_hist(
                words_perturbed_stats["num_words_changed_until_success"][:numbins],
                numbins=numbins,
                title="Num Words Perturbed",
                window_id="num_words_perturbed",
            )
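

# Minimal usage sketch (comment only, not executed): assumes a list of
# ``AttackResult`` objects named ``attack_results``, e.g. produced by running a
# TextAttack attack; the filenames below are placeholders.
#
#   log_manager = AttackLogManager(metrics=None)
#   log_manager.enable_stdout()
#   log_manager.add_output_csv("attack_results.csv", color_method="file")
#   log_manager.add_output_summary_json("attack_summary.json")
#   log_manager.log_results(attack_results)  # logs each result, then the summary
#   log_manager.flush()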