Commit 229e14c
Parent(s): 11174d4
feat/fix: fixing code issues, adding plotting functions

Files changed:
- .gitignore +1 -0
- backend/controller.py +3 -3
- explanation/interpret_captum.py +0 -40
- explanation/interpret_shap.py +0 -72
- explanation/markup.py +12 -12
- explanation/plotting.py +0 -0
- explanation/visualize.py +0 -52
- explanation/visualize_att.py +0 -0
- model/mistral.py +10 -10
- requirements.txt +1 -3
.gitignore
CHANGED
@@ -2,3 +2,4 @@
 __pycache__/
 /start-venv.sh
 /components/iframe/dist/
+.venv
backend/controller.py
CHANGED
@@ -10,7 +10,7 @@ from model import mistral
 from explanation import (
     interpret_shap as shap_int,
     interpret_captum as cpt_int,
-    visualize as viz,
+    visualize_att as viz,
 )
 
 
@@ -33,10 +33,10 @@ def interference(
 
     if model_selection.lower() == "mistral":
         model = mistral
-        print("
+        print("Indentified model as Mistral")
     else:
         model = godel
-        print("
+        print("Indentified model as GODEL")
 
     # if a XAI approach is selected, grab the XAI module instance
     if xai_selection in ("SHAP", "Attention"):
explanation/interpret_captum.py
CHANGED
@@ -1,40 +0,0 @@
-# external imports
-from captum.attr import LLMAttribution, TextTokenInput, KernelShap
-import torch
-
-# internal imports
-from utils import formatting as fmt
-from .markup import markup_text
-
-
-# main explain function that returns a chat with explanations
-def chat_explained(model, prompt):
-    model.set_config({})
-
-    # creating llm attribution class with KernelSHAP and Mistal Model, Tokenizer
-    llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
-
-    # generation attribution
-    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
-    attribution_result = llm_attribution.attribute(
-        attribution_input, gen_args=model.CONFIG.to_dict()
-    )
-
-    # extracting values and input tokens
-    values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
-    input_tokens = fmt.format_tokens(attribution_result.input_tokens)
-
-    # raising error if mismatch occurs
-    if len(attribution_result.input_tokens) != len(values):
-        raise RuntimeError("values and input len mismatch")
-
-    # getting response text, graphic placeholder and marked text object
-    response_text = fmt.format_output_text(attribution_result.output_tokens)
-    graphic = (
-        "<div style='text-align: center; font-family:arial;'><h4>Attention"
-        "Intepretation with Captum doesn't support an interactive graphic.</h4></div>"
-    )
-    marked_text = markup_text(input_tokens, values, variant="captum")
-
-    # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
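For context, the emptied module followed Captum's LLM attribution pattern. A minimal standalone sketch of that flow, assuming captum>=0.7 and transformers; "gpt2" is a small stand-in model here, not the Mistral/GODEL models the app actually loads:

# minimal sketch of the Captum LLM attribution flow (assumptions noted above)
from captum.attr import KernelShap, LLMAttribution, TextTokenInput
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# wrap the perturbation-based attribution method for use with a generative LM
llm_attr = LLMAttribution(KernelShap(model), tokenizer)

# attribute the generated continuation back onto the prompt tokens
inp = TextTokenInput("The capital of France is", tokenizer)
result = llm_attr.attribute(inp, gen_args={"max_new_tokens": 10})

print(result.input_tokens)            # prompt tokens
print(result.seq_attr.cpu().numpy())  # one attribution score per prompt token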
explanation/interpret_shap.py
CHANGED
@@ -1,72 +0,0 @@
-# interpret module that implements the interpretability method
-
-# external imports
-from shap import models, maskers, plots, PartitionExplainer
-import torch
-
-# internal imports
-from utils import formatting as fmt
-from .markup import markup_text
-
-# global variables
-TEACHER_FORCING = None
-TEXT_MASKER = None
-
-
-# main explain function that returns a chat with explanations
-def chat_explained(model, prompt):
-    model.set_config({})
-
-    # create the shap explainer
-    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
-
-    # get the shap values for the prompt
-    shap_values = shap_explainer([prompt])
-
-    # create the explanation graphic and marked text array
-    graphic = create_graphic(shap_values)
-    marked_text = markup_text(
-        shap_values.data[0], shap_values.values[0], variant="shap"
-    )
-
-    # create the response text
-    response_text = fmt.format_output_text(shap_values.output_names)
-
-    # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
-
-
-# function used to wrap the model with a shap model
-def wrap_shap(model):
-    # calling global variants
-    global TEXT_MASKER, TEACHER_FORCING
-
-    # set the device to cuda if gpu is available
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # updating the model settings
-    model.set_config()
-
-    # (re)initialize the shap models and masker
-    # creating a shap text_generation model
-    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
-    # wrapping the text generation model in a teacher forcing model
-    TEACHER_FORCING = models.TeacherForcing(
-        text_generation,
-        model.TOKENIZER,
-        device=str(device),
-        similarity_model=model.MODEL,
-        similarity_tokenizer=model.TOKENIZER,
-    )
-    # setting the text masker as an empty string
-    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
-
-
-# graphic plotting function that creates a html graphic (as string) for the explanation
-def create_graphic(shap_values):
-
-    # create the html graphic using shap text plot function
-    graphic_html = plots.text(shap_values, display=False)
-
-    # return the html graphic as string to display in iFrame
-    return str(graphic_html)
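For context, the removed module built a PartitionExplainer over a TeacherForcing-wrapped generation model. A simpler standalone sketch of the same idea uses shap's pipeline support; this is an illustration under stated assumptions (shap and transformers installed, "gpt2" as a small stand-in), not the repo's exact setup:

# sketch: SHAP values for a text-generation pipeline, plus the HTML plot
# the app embedded in an iFrame
import shap
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2", max_new_tokens=10)
explainer = shap.Explainer(generator)

shap_values = explainer(["The capital of France is"])

# shap.plots.text returns standalone HTML when display=False
html = str(shap.plots.text(shap_values, display=False))

print(shap_values.data[0])    # input tokens
print(shap_values.values[0])  # per-token attribution values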
explanation/markup.py
CHANGED
@@ -66,16 +66,16 @@ def color_codes():
     return {
         # -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
        # 0: white (assuming default light mode)
-        # +1 to +5 light pink to
-        "-5": "#
-        "-4": "#
-        "-3": "#
-        "-2": "#
-        "-1": "#
-        "0": "#
-        "
-        "
-        "
-        "
-        "
+        # +1 to +5 light pink to strng magenta
+        "-5": "#008bfb",
+        "-4": "#68a1fd",
+        "-3": "#96b7fe",
+        "-2": "#bcceff",
+        "-1:": "#dee6ff",
+        "0": "#ffffff",
+        "1": "#ffd9d9",
+        "2": "#ffb3b5",
+        "3": "#ff8b92",
+        "4": "#ff5c71",
+        "5": "#ff0051",
     }
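The palette keys bucketed attribution scores from -5 to +5. A hypothetical sketch of how such a table gets applied; bucket_color and the max-abs normalization are assumptions for illustration, not the repo's actual markup_text logic:

# hypothetical helper, not from the repo: scale a score into the -5..+5
# buckets keyed above and look up its hex color
def bucket_color(value: float, max_abs: float, codes: dict) -> str:
    if max_abs == 0:
        return codes["0"]
    bucket = max(-5, min(5, round(5 * value / max_abs)))
    return codes[str(bucket)]

# e.g. bucket_color(0.8, 1.0, color_codes()) -> "#ff5c71" (bucket 4)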
explanation/plotting.py
ADDED
(new empty file)
explanation/visualize.py
DELETED
@@ -1,52 +0,0 @@
-# visualization module that creates an attention visualization
-
-
-# internal imports
-from utils import formatting as fmt
-from .markup import markup_text
-
-
-# chat function that returns an answer
-# and marked text based on attention
-def chat_explained(model, prompt):
-
-    # get encoded input
-    encoder_input_ids = model.TOKENIZER(
-        prompt, return_tensors="pt", add_special_tokens=True
-    ).input_ids
-    # generate output together with attentions of the model
-    decoder_input_ids = model.MODEL.generate(
-        encoder_input_ids, output_attentions=True, **model.CONFIG
-    )
-
-    # get input and output text as list of strings
-    encoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
-    )
-    decoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
-    )
-
-    # get attention values for the input and output vectors
-    # using already generated input and output
-    attention_output = model.MODEL(
-        input_ids=encoder_input_ids,
-        decoder_input_ids=decoder_input_ids,
-        output_attentions=True,
-    )
-
-    # averaging attention across layers
-    averaged_attention = fmt.avg_attention(attention_output)
-
-    # format response text for clean output
-    response_text = fmt.format_output_text(decoder_text)
-    # setting placeholder for iFrame graphic
-    graphic = (
-        "<div style='text-align: center; font-family:arial;'><h4>Attention"
-        " Visualization doesn't support an interactive graphic.</h4></div>"
-    )
-    # creating marked text using markup_text function and attention
-    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
-
-    # returning response, graphic and marked text array
-    return response_text, graphic, marked_text
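The deleted module delegated the layer averaging to fmt.avg_attention, which lives in utils.formatting and is not part of this diff. A sketch of what averaging attention across layers typically looks like for an encoder-decoder forward pass; this is an assumption about its behavior, not the repo's code:

# sketch: collapse per-layer, per-head cross-attention into one
# [target_len, source_len] map (assumes a transformers Seq2Seq output)
import torch

def avg_attention(model_output):
    # cross_attentions: tuple of num_layers tensors,
    # each shaped [batch, heads, tgt_len, src_len]
    stacked = torch.stack(model_output.cross_attentions)
    # average over layers (dim 0) and heads (dim 2), drop the batch dim
    return stacked.mean(dim=(0, 2)).squeeze(0)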
explanation/visualize_att.py
ADDED
(new empty file)
model/mistral.py
CHANGED
@@ -41,13 +41,11 @@ CONFIG.update(**{
 
 
 # function to (re) set config
-def set_config(config: dict):
+def set_config(config_dict: dict):
 
-    # if config dict is given,
-    if
-
-    else:
-        CONFIG.update(**{
+    # if config dict is not given, set to default
+    if config_dict == {}:
+        config_dict = {
             "temperature": 0.7,
             "max_new_tokens": 50,
             "max_length": 50,
@@ -55,7 +53,9 @@ def set_config(config: dict):
             "repetition_penalty": 1.2,
             "do_sample": True,
             "seed": 42,
-        })
+        }
+
+    CONFIG.update(**dict)
 
 
 # advanced formatting function that takes into a account a conversation history
@@ -77,9 +77,9 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: str):
     """
     else:
         # takes the very first exchange and the system prompt as base
-        prompt =
-
-
+        prompt = f"""
+        <s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
+        """
 
     # adds conversation history to the prompt
    for conversation in history[1:]:
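For illustration, this is what the rebuilt f-string produces for a minimal one-exchange history; the values are invented, and the snippet is written flush-left (in the file, the triple-quoted block also carries its 8-space indentation into the string):

# illustrative values only, not from the repo
system_prompt = "You are a helpful assistant."
history = [("Hi there!", "Hello! How can I help?")]

prompt = f"""
<s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
"""
# prompt == '\n<s>[INST] You are a helpful assistant. Hi there! [/INST] Hello! How can I help?</s>\n'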
requirements.txt
CHANGED
@@ -2,7 +2,7 @@ gradio~=4.7.1
 transformers~=4.35.2
 torch~=2.1.1
 shap
-captum
+captum @ git+https://github.com/LennardZuendorf/thesis-captum.git
 bertviz~=1.4.0
 accelerate~=0.24.1
 bitsandbytes
@@ -13,9 +13,7 @@ uvicorn~=0.24.0
 tinydb~=4.8.0
 black~=23.12.0
 pylint~=3.0.0
-seaborn~=0.13.0
 numpy
 matplotlib
 pre-commit
-ipython
 gradio-iframe~=0.0.10
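The new captum line is a PEP 508 direct reference, so pip pulls the author's fork straight from GitHub instead of PyPI. The equivalent ad-hoc install (requires git on PATH) would be:

pip install "captum @ git+https://github.com/LennardZuendorf/thesis-captum.git"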