"""
Managing Attack Logs.
========================
"""
from typing import Dict, Optional

from textattack.metrics.attack_metrics import (
    AttackQueries,
    AttackSuccessRate,
    WordsPerturbed,
)
from textattack.metrics.quality_metrics import Perplexity, USEMetric

from . import (
    CSVLogger,
    FileLogger,
    JsonSummaryLogger,
    VisdomLogger,
    WeightsAndBiasesLogger,
)


class AttackLogManager:
    """Logs the results of an attack to all attached loggers."""

    # metrics maps strings (metric names) to textattack.metric.Metric objects
    metrics: Dict

    def __init__(self, metrics: Optional[Dict]):
        self.loggers = []
        self.results = []
        self.enable_advance_metrics = False
        if metrics is None:
            self.metrics = {}
        else:
            self.metrics = metrics
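
    # Illustrative sketch of how the ``metrics`` dict is consumed (the dict
    # literal and row label below are examples, not part of this module):
    # ``log_summary`` appends one summary row per entry by calling each
    # object's ``calculate(results)`` method, so any Metric-like object works:
    #
    #     manager = AttackLogManager(metrics={"Avg num queries": AttackQueries()})
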
    def enable_stdout(self):
        self.loggers.append(FileLogger(stdout=True))

    def enable_visdom(self):
        self.loggers.append(VisdomLogger())

    def enable_wandb(self, **kwargs):
        self.loggers.append(WeightsAndBiasesLogger(**kwargs))

    def disable_color(self):
        # Despite the name, this attaches a stdout logger; its
        # color_method="file" setting avoids ANSI terminal color codes.
        self.loggers.append(FileLogger(stdout=True, color_method="file"))

    def add_output_file(self, filename, color_method):
        self.loggers.append(FileLogger(filename=filename, color_method=color_method))

    def add_output_csv(self, filename, color_method):
        self.loggers.append(CSVLogger(filename=filename, color_method=color_method))

    def add_output_summary_json(self, filename):
        self.loggers.append(JsonSummaryLogger(filename=filename))
    def log_result(self, result):
        """Logs an ``AttackResult`` on each of ``self.loggers``."""
        self.results.append(result)
        for logger in self.loggers:
            logger.log_attack_result(result)

    def log_results(self, results):
        """Logs an iterable of ``AttackResult`` objects on each of
        ``self.loggers``, then logs a summary of all results."""
        for result in results:
            self.log_result(result)
        self.log_summary()
    def log_summary_rows(self, rows, title, window_id):
        for logger in self.loggers:
            logger.log_summary_rows(rows, title, window_id)

    def log_sep(self):
        for logger in self.loggers:
            logger.log_sep()

    def flush(self):
        for logger in self.loggers:
            logger.flush()

    def log_attack_details(self, attack_name, model_name):
        # @TODO log a more complete set of attack details
        attack_detail_rows = [
            ["Attack algorithm:", attack_name],
            ["Model:", model_name],
        ]
        self.log_summary_rows(attack_detail_rows, "Attack Details", "attack_details")
    def log_summary(self):
        total_attacks = len(self.results)
        if total_attacks == 0:
            return

        # Default metrics - calculated on every attack
        attack_success_stats = AttackSuccessRate().calculate(self.results)
        words_perturbed_stats = WordsPerturbed().calculate(self.results)
        attack_query_stats = AttackQueries().calculate(self.results)

        # @TODO generate this table based on user input - each column in specific class
        # Example to demonstrate:
        # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
        summary_table_rows = [
            [
                "Number of successful attacks:",
                attack_success_stats["successful_attacks"],
            ],
            ["Number of failed attacks:", attack_success_stats["failed_attacks"]],
            ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
            [
                "Original accuracy:",
                str(attack_success_stats["original_accuracy"]) + "%",
            ],
            [
                "Accuracy under attack:",
                str(attack_success_stats["attack_accuracy_perc"]) + "%",
            ],
            [
                "Attack success rate:",
                str(attack_success_stats["attack_success_rate"]) + "%",
            ],
            [
                "Average perturbed word %:",
                str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%",
            ],
            [
                "Average num. words per input:",
                words_perturbed_stats["avg_word_perturbed"],
            ],
        ]

        summary_table_rows.append(
            ["Avg num queries:", attack_query_stats["avg_num_queries"]]
        )

        for metric_name, metric in self.metrics.items():
            summary_table_rows.append([metric_name, metric.calculate(self.results)])

        if self.enable_advance_metrics:
            perplexity_stats = Perplexity().calculate(self.results)
            use_stats = USEMetric().calculate(self.results)

            summary_table_rows.append(
                [
                    "Average Original Perplexity:",
                    perplexity_stats["avg_original_perplexity"],
                ]
            )

            summary_table_rows.append(
                [
                    "Average Attack Perplexity:",
                    perplexity_stats["avg_attack_perplexity"],
                ]
            )

            summary_table_rows.append(
                ["Average Attack USE Score:", use_stats["avg_attack_use_score"]]
            )

        self.log_summary_rows(
            summary_table_rows, "Attack Results", "attack_results_summary"
        )

        # Show histogram of words changed.
        numbins = max(words_perturbed_stats["max_words_changed"], 10)
        for logger in self.loggers:
            logger.log_hist(
                words_perturbed_stats["num_words_changed_until_success"][:numbins],
                numbins=numbins,
                title="Num Words Perturbed",
                window_id="num_words_perturbed",
            )
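

# A minimal usage sketch (illustrative only; ``results`` stands in for a list
# of ``AttackResult`` objects produced by an attack run elsewhere):
#
#     manager = AttackLogManager(metrics=None)
#     manager.enable_stdout()                     # echo each result to stdout
#     manager.add_output_csv("results.csv", color_method="file")
#     manager.log_results(results)  # logs every result, then the summary table
#     manager.flush()               # flush all attached loggers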