File size: 1,957 Bytes
d4dd3c5
 
 
 
 
 
 
67a34bd
d4dd3c5
a597c76
d4dd3c5
 
a597c76
d4dd3c5
 
a597c76
d4dd3c5
 
a597c76
d4dd3c5
 
 
a597c76
f301e04
d4dd3c5
 
a597c76
d4dd3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6d622b
d4dd3c5
67a34bd
d4dd3c5
 
 
 
f6d622b
d4dd3c5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# plotting functions

# external imports
import numpy as np
import matplotlib.pyplot as plt


def plot_seq(seq_values: list, method: str = ""):
    """Draw a signed bar chart of per-token attribution values.

    Bars with a positive value are drawn in red (``#ff0051``); all other
    bars (zero or negative) are drawn in blue (``#008bfb``). Each bar is
    annotated with its value formatted to two decimals.

    Args:
        seq_values: Non-empty iterable of ``(token, importance)`` pairs, in
            the order the tokens should appear on the x-axis.
        method: Name of the attribution method; interpolated into the title.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure is the newly
        created plot (callers can invoke ``.show()`` / ``.savefig()`` on it).

    Raises:
        ValueError: If ``seq_values`` contains no pairs.
    """
    # materialise once so generators are supported and emptiness can be checked
    pairs = list(seq_values)
    if not pairs:
        raise ValueError(
            "seq_values must contain at least one (token, importance) pair"
        )

    # separate the tokens and their corresponding importance values
    tokens, importance = zip(*pairs)

    # convert importance values to numpy array for conditional coloring
    importance = np.array(importance)

    # determine the colors based on the sign of the importance values
    colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]

    # Figure size in inches. The height follows the largest attribution value,
    # but matplotlib requires a positive, finite size: an all-negative/zero
    # input would otherwise raise, and small fractional attributions would
    # yield an unreadably flat figure. Clamp both dimensions to a minimum.
    min_width, min_height = 2.0, 2.0
    peak = float(np.max(importance))
    fig_height = peak if np.isfinite(peak) and peak > min_height else min_height
    fig_width = max(len(tokens) * 0.9, min_width)

    # create a bar plot
    plt.figure(figsize=(fig_width, fig_height))
    x_positions = range(len(tokens))  # Positions for the bars

    # creating vertical bar plot
    bar_width = 0.8
    plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)

    # annotating each bar with its value
    padding = 0.1  # Padding for text annotation
    for x, (y, color) in enumerate(zip(importance, colors)):
        positive = y > 0
        sign = "+" if positive else ""
        plt.annotate(
            f"{sign}{y:.2f}",  # Format the value with sign
            # place the label just beyond the tip of the bar
            xy=(x, y + padding if positive else y - padding),
            ha="center",
            color=color,
            va="bottom" if positive else "top",  # Vertical alignment
            fontweight="bold",  # Bold text
            bbox={
                "facecolor": "white",
                "edgecolor": "none",
                "boxstyle": "round,pad=0.1",
            },  # White background
        )

    # setting plot properties, labels, and title
    plt.axhline(0, color="black", linewidth=1)
    plt.title(f"Input Token Attribution with {method}")
    plt.xlabel("Input Tokens", labelpad=0.5)
    plt.ylabel("Attribution")
    plt.xticks(x_positions, tokens, rotation=45)

    # adjusting y-axis limits to ensure there's enough space for labels
    y_min, y_max = plt.ylim()
    y_range = y_max - y_min
    plt.ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range)

    return plt