Spaces:
Runtime error
Runtime error
Commit
·
f5ebee7
1
Parent(s):
940d70a
feat: removing plots, updating iframe height, minor changes
Browse files- backend/controller.py +5 -6
- explanation/interpret_shap.py +2 -58
- explanation/markup.py +5 -2
- explanation/visualize.py +3 -66
- main.py +8 -12
- public/about.md +1 -1
- utils/formatting.py +8 -1
backend/controller.py
CHANGED
@@ -40,7 +40,7 @@ def interference(
|
|
40 |
raise RuntimeError("There was an error in the selected XAI approach.")
|
41 |
|
42 |
# call the explained chat function
|
43 |
-
prompt_output, history_output, xai_graphic,
|
44 |
explained_chat(
|
45 |
model=godel,
|
46 |
xai=xai,
|
@@ -61,17 +61,16 @@ def interference(
|
|
61 |
knowledge=knowledge,
|
62 |
)
|
63 |
# set XAI outputs to disclaimer html/none
|
64 |
-
xai_graphic,
|
65 |
"""
|
66 |
<div style="text-align: center"><h4>Without Selected XAI Approach,
|
67 |
no graphic will be displayed</h4></div>
|
68 |
""",
|
69 |
-
None,
|
70 |
[("", "")],
|
71 |
)
|
72 |
|
73 |
# return the outputs
|
74 |
-
return prompt_output, history_output, xai_graphic,
|
75 |
|
76 |
|
77 |
# simple chat function that calls the model
|
@@ -98,10 +97,10 @@ def explained_chat(
|
|
98 |
prompt = model.format_prompt(message, history, system_prompt, knowledge)
|
99 |
|
100 |
# generating an answer using the xai methods explain and respond function
|
101 |
-
answer, xai_graphic,
|
102 |
|
103 |
# updating the chat history with the new answer
|
104 |
history.append((message, answer))
|
105 |
|
106 |
# returning the updated history, xai graphic and xai plot elements
|
107 |
-
return "", history, xai_graphic,
|
|
|
40 |
raise RuntimeError("There was an error in the selected XAI approach.")
|
41 |
|
42 |
# call the explained chat function
|
43 |
+
prompt_output, history_output, xai_graphic, xai_markup = (
|
44 |
explained_chat(
|
45 |
model=godel,
|
46 |
xai=xai,
|
|
|
61 |
knowledge=knowledge,
|
62 |
)
|
63 |
# set XAI outputs to disclaimer html/none
|
64 |
+
xai_graphic, xai_markup = (
|
65 |
"""
|
66 |
<div style="text-align: center"><h4>Without Selected XAI Approach,
|
67 |
no graphic will be displayed</h4></div>
|
68 |
""",
|
|
|
69 |
[("", "")],
|
70 |
)
|
71 |
|
72 |
# return the outputs
|
73 |
+
return prompt_output, history_output, xai_graphic, xai_markup
|
74 |
|
75 |
|
76 |
# simple chat function that calls the model
|
|
|
97 |
prompt = model.format_prompt(message, history, system_prompt, knowledge)
|
98 |
|
99 |
# generating an answer using the xai methods explain and respond function
|
100 |
+
answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
|
101 |
|
102 |
# updating the chat history with the new answer
|
103 |
history.append((message, answer))
|
104 |
|
105 |
# returning the updated history, xai graphic and xai plot elements
|
106 |
+
return "", history, xai_graphic, xai_markup
|
explanation/interpret_shap.py
CHANGED
@@ -26,18 +26,13 @@ def chat_explained(model, prompt):
|
|
26 |
|
27 |
# create the explanation graphic and plot
|
28 |
graphic = create_graphic(shap_values)
|
29 |
-
plot = create_plot(
|
30 |
-
values=shap_values.values[0],
|
31 |
-
output_names=shap_values.output_names,
|
32 |
-
input_names=shap_values.data[0],
|
33 |
-
)
|
34 |
marked_text = markup_text(
|
35 |
shap_values.data[0], shap_values.values[0], variant="shap"
|
36 |
)
|
37 |
|
38 |
# create the response text
|
39 |
response_text = fmt.format_output_text(shap_values.output_names)
|
40 |
-
return response_text, graphic,
|
41 |
|
42 |
|
43 |
def wrap_shap(model):
|
@@ -67,55 +62,4 @@ def create_graphic(shap_values):
|
|
67 |
graphic_html = plots.text(shap_values, display=False)
|
68 |
|
69 |
# return the html graphic as string
|
70 |
-
return str(graphic_html)
|
71 |
-
|
72 |
-
|
73 |
-
# creating an attention heatmap plot using matplotlib/seaborn
|
74 |
-
# CREDIT: adopted from official Matplotlib documentation
|
75 |
-
## see https://matplotlib.org/stable/
|
76 |
-
def create_plot(values, output_names, input_names):
|
77 |
-
|
78 |
-
# Set seaborn style to dark
|
79 |
-
sns.set(style="white")
|
80 |
-
fig, ax = plt.subplots()
|
81 |
-
|
82 |
-
# Setting figure size
|
83 |
-
fig.set_size_inches(
|
84 |
-
max(values.shape[1] * 2, 10),
|
85 |
-
max(values.shape[0] * 1, 5),
|
86 |
-
)
|
87 |
-
|
88 |
-
# Plotting the heatmap with Seaborn's color palette
|
89 |
-
im = ax.imshow(
|
90 |
-
values,
|
91 |
-
vmax=values.max(),
|
92 |
-
vmin=values.min(),
|
93 |
-
cmap=sns.color_palette("vlag_r", as_cmap=True),
|
94 |
-
aspect="auto",
|
95 |
-
)
|
96 |
-
|
97 |
-
# Creating colorbar
|
98 |
-
cbar = ax.figure.colorbar(im, ax=ax)
|
99 |
-
cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
|
100 |
-
cbar.ax.yaxis.set_tick_params(color="black")
|
101 |
-
plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
|
102 |
-
|
103 |
-
# Setting ticks and labels with white color for visibility
|
104 |
-
ax.set_yticks(np.arange(len(input_names)), labels=input_names)
|
105 |
-
ax.set_xticks(np.arange(len(output_names)), labels=output_names)
|
106 |
-
plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
|
107 |
-
plt.setp(ax.get_yticklabels(), color="black")
|
108 |
-
|
109 |
-
# Adjusting tick labels
|
110 |
-
ax.tick_params(
|
111 |
-
top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
|
112 |
-
)
|
113 |
-
|
114 |
-
# Adding text annotations with appropriate contrast
|
115 |
-
for i in range(values.shape[0]):
|
116 |
-
for j in range(values.shape[1]):
|
117 |
-
val = values[i, j]
|
118 |
-
color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
|
119 |
-
ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
|
120 |
-
|
121 |
-
return plt
|
|
|
26 |
|
27 |
# create the explanation graphic and plot
|
28 |
graphic = create_graphic(shap_values)
|
|
|
|
|
|
|
|
|
|
|
29 |
marked_text = markup_text(
|
30 |
shap_values.data[0], shap_values.values[0], variant="shap"
|
31 |
)
|
32 |
|
33 |
# create the response text
|
34 |
response_text = fmt.format_output_text(shap_values.output_names)
|
35 |
+
return response_text, graphic, marked_text
|
36 |
|
37 |
|
38 |
def wrap_shap(model):
|
|
|
62 |
graphic_html = plots.text(shap_values, display=False)
|
63 |
|
64 |
# return the html graphic as string
|
65 |
+
return str(graphic_html)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explanation/markup.py
CHANGED
@@ -11,10 +11,13 @@ from utils import formatting as fmt
|
|
11 |
def markup_text(input_text: list, text_values: ndarray, variant: str):
|
12 |
bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
|
13 |
|
14 |
-
# Flatten the
|
|
|
15 |
if variant == "shap":
|
16 |
text_values = np.transpose(text_values)
|
17 |
-
|
|
|
|
|
18 |
|
19 |
# Determine the minimum and maximum values
|
20 |
min_val, max_val = np.min(text_values), np.max(text_values)
|
|
|
11 |
def markup_text(input_text: list, text_values: ndarray, variant: str):
|
12 |
bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
|
13 |
|
14 |
+
# Flatten the values depending on the source
|
15 |
+
# attention is averaged, SHAP summed up
|
16 |
if variant == "shap":
|
17 |
text_values = np.transpose(text_values)
|
18 |
+
text_values = fmt.flatten_attribution(text_values)
|
19 |
+
else:
|
20 |
+
text_values = fmt.flatten_attention(text_values)
|
21 |
|
22 |
# Determine the minimum and maximum values
|
23 |
min_val, max_val = np.min(text_values), np.max(text_values)
|
explanation/visualize.py
CHANGED
@@ -34,74 +34,11 @@ def chat_explained(model, prompt):
|
|
34 |
output_attentions=True,
|
35 |
)
|
36 |
|
37 |
-
averaged_attention = avg_attention(attention_output)
|
38 |
|
39 |
-
# create the response text
|
40 |
response_text = fmt.format_output_text(decoder_text)
|
41 |
-
plot = create_plot(averaged_attention, (encoder_text, decoder_text))
|
42 |
marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
|
43 |
|
44 |
-
return response_text, "",
|
45 |
|
46 |
-
|
47 |
-
# creating an attention heatmap plot using matplotlib/seaborn
|
48 |
-
# CREDIT: adopted from official Matplotlib documentation
|
49 |
-
## see https://matplotlib.org/stable/
|
50 |
-
def create_plot(averaged_attention_weights, enc_dec_texts: tuple):
|
51 |
-
# transpose the attention weights
|
52 |
-
averaged_attention_weights = np.transpose(averaged_attention_weights)
|
53 |
-
|
54 |
-
# get the encoder and decoder tokens in text form
|
55 |
-
encoder_tokens = enc_dec_texts[0]
|
56 |
-
decoder_tokens = enc_dec_texts[1]
|
57 |
-
|
58 |
-
# set seaborn style to dark and initialize figure and axis
|
59 |
-
sns.set(style="white")
|
60 |
-
fig, ax = plt.subplots()
|
61 |
-
|
62 |
-
# Setting figure size
|
63 |
-
fig.set_size_inches(
|
64 |
-
max(averaged_attention_weights.shape[1] * 2, 10),
|
65 |
-
max(averaged_attention_weights.shape[0] * 1, 5),
|
66 |
-
)
|
67 |
-
|
68 |
-
# Plotting the heatmap with seaborn's color palette
|
69 |
-
im = ax.imshow(
|
70 |
-
averaged_attention_weights,
|
71 |
-
vmax=averaged_attention_weights.max(),
|
72 |
-
vmin=-averaged_attention_weights.min(),
|
73 |
-
cmap=sns.color_palette("rocket", as_cmap=True),
|
74 |
-
aspect="auto",
|
75 |
-
)
|
76 |
-
|
77 |
-
# Creating colorbar
|
78 |
-
cbar = ax.figure.colorbar(im, ax=ax)
|
79 |
-
cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
|
80 |
-
cbar.ax.yaxis.set_tick_params(color="black")
|
81 |
-
plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
|
82 |
-
|
83 |
-
# Setting ticks and labels with black color for visibility
|
84 |
-
ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
|
85 |
-
ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
|
86 |
-
ax.set_title("Attention Weights by Token")
|
87 |
-
plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
|
88 |
-
plt.setp(ax.get_yticklabels(), color="black")
|
89 |
-
|
90 |
-
# Adding text annotations with appropriate contrast
|
91 |
-
for i in range(averaged_attention_weights.shape[0]):
|
92 |
-
for j in range(averaged_attention_weights.shape[1]):
|
93 |
-
val = averaged_attention_weights[i, j]
|
94 |
-
color = (
|
95 |
-
"white"
|
96 |
-
if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
|
97 |
-
else "black"
|
98 |
-
)
|
99 |
-
ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
|
100 |
-
|
101 |
-
# return the plot
|
102 |
-
return plt
|
103 |
-
|
104 |
-
|
105 |
-
def avg_attention(attention_values):
|
106 |
-
attention = attention_values.cross_attentions[0][0].detach().numpy()
|
107 |
-
return np.mean(attention, axis=0)
|
|
|
34 |
output_attentions=True,
|
35 |
)
|
36 |
|
37 |
+
averaged_attention = fmt.avg_attention(attention_output)
|
38 |
|
39 |
+
# create the response text and marked text for ui
|
40 |
response_text = fmt.format_output_text(decoder_text)
|
|
|
41 |
marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
|
42 |
|
43 |
+
return response_text, "", marked_text
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
@@ -180,6 +180,10 @@ with gr.Blocks(
|
|
180 |
" scripts: hieroglyphs, Demotic, and Greek."
|
181 |
),
|
182 |
],
|
|
|
|
|
|
|
|
|
183 |
],
|
184 |
inputs=[user_prompt, knowledge_input],
|
185 |
)
|
@@ -197,22 +201,14 @@ with gr.Blocks(
|
|
197 |
with gr.Row(variant="panel"):
|
198 |
# wraps the explanation html to display it statically
|
199 |
xai_interactive = iFrame(
|
200 |
-
label="
|
201 |
value=(
|
202 |
'<div style="text-align: center"><h4>No Graphic to Display'
|
203 |
" (Yet)</h4></div>"
|
204 |
),
|
|
|
205 |
show_label=True,
|
206 |
)
|
207 |
-
# row and accordion to display an explanation plot (if applicable)
|
208 |
-
with gr.Row():
|
209 |
-
with gr.Accordion("Token Wise Explanation Plot", open=False):
|
210 |
-
gr.Markdown("""
|
211 |
-
#### Plotted Values
|
212 |
-
Values have been excluded for readability. See colorbar for value indication.
|
213 |
-
""")
|
214 |
-
# plot component that takes a matplotlib figure as input
|
215 |
-
xai_plot = gr.Plot(label="Token Level Explanation")
|
216 |
|
217 |
# functions to trigger the controller
|
218 |
## takes information for the chat and the xai selection
|
@@ -221,13 +217,13 @@ with gr.Blocks(
|
|
221 |
submit_btn.click(
|
222 |
interference,
|
223 |
[user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
|
224 |
-
[user_prompt, chatbot, xai_interactive,
|
225 |
)
|
226 |
# function triggered by the enter key
|
227 |
user_prompt.submit(
|
228 |
interference,
|
229 |
[user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
|
230 |
-
[user_prompt, chatbot, xai_interactive,
|
231 |
)
|
232 |
|
233 |
# final row to show legal information
|
|
|
180 |
" scripts: hieroglyphs, Demotic, and Greek."
|
181 |
),
|
182 |
],
|
183 |
+
[
|
184 |
+
"Does money buy happiness?",
|
185 |
+
""
|
186 |
+
],
|
187 |
],
|
188 |
inputs=[user_prompt, knowledge_input],
|
189 |
)
|
|
|
201 |
with gr.Row(variant="panel"):
|
202 |
# wraps the explanation html to display it statically
|
203 |
xai_interactive = iFrame(
|
204 |
+
label="Interactive Explanation",
|
205 |
value=(
|
206 |
'<div style="text-align: center"><h4>No Graphic to Display'
|
207 |
" (Yet)</h4></div>"
|
208 |
),
|
209 |
+
height="600px",
|
210 |
show_label=True,
|
211 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
# functions to trigger the controller
|
214 |
## takes information for the chat and the xai selection
|
|
|
217 |
submit_btn.click(
|
218 |
interference,
|
219 |
[user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
|
220 |
+
[user_prompt, chatbot, xai_interactive, xai_text],
|
221 |
)
|
222 |
# function triggered by the enter key
|
223 |
user_prompt.submit(
|
224 |
interference,
|
225 |
[user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
|
226 |
+
[user_prompt, chatbot, xai_interactive, xai_text],
|
227 |
)
|
228 |
|
229 |
# final row to show legal information
|
public/about.md
CHANGED
@@ -7,7 +7,7 @@ This research tackles the rise of LLM based applications such a chatbots and exp
|
|
7 |
## Links
|
8 |
|
9 |
- [GitHub Repository](https://github.com/LennardZuendorf/thesis-webapp) - The GitHub repository of this project.
|
10 |
-
- [HTW Berlin](https://www.htw-berlin.de/) - The University I have built this project for, as part of my thesis.
|
11 |
|
12 |
|
13 |
## Implementation
|
|
|
7 |
## Links
|
8 |
|
9 |
- [GitHub Repository](https://github.com/LennardZuendorf/thesis-webapp) - The GitHub repository of this project.
|
10 |
+
- [HTW Berlin](https://www.htw-berlin.de/en/) - The University I have built this project for, as part of my thesis.
|
11 |
|
12 |
|
13 |
## Implementation
|
utils/formatting.py
CHANGED
@@ -66,5 +66,12 @@ def format_tokens(tokens: list):
|
|
66 |
|
67 |
|
68 |
# function to flatten values into a 2d list by averaging the explanation values
|
69 |
-
def
|
|
|
|
|
|
|
70 |
return np.mean(values, axis=axis)
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
|
68 |
# function to flatten values into a 2d list by averaging the explanation values
|
69 |
+
def flatten_attribution(values: ndarray, axis: int = 0):
|
70 |
+
return np.sum(values, axis=axis)
|
71 |
+
|
72 |
+
def flatten_attention(values: ndarray, axis: int = 0):
|
73 |
return np.mean(values, axis=axis)
|
74 |
+
|
75 |
+
def avg_attention(attention_values):
|
76 |
+
attention = attention_values.cross_attentions[0][0].detach().numpy()
|
77 |
+
return np.mean(attention, axis=0)
|