Spaces: Runtime error

feat: switched model to Mistral AI 7B

- app.py +45 -6
- chatmodel.py +37 -50
app.py
CHANGED
@@ -10,15 +10,54 @@ with gr.Blocks() as ui:
         # Thesis Demo - AI Chat Application with XAI
         ### Select between tabs below for the different views.
         """)
-    with gr.Tab("
+    with gr.Tab("Mistral AI ChatBot"):
         with gr.Row():
             gr.Markdown(
                 """
                 ### ChatBot Demo
-
+                Mitral AI 7B Model fine-tuned for instruction and fully open source (see at [HGF](https://huggingface.co/mistralai/Mistral-7B-v0.1))
                 """)
         with gr.Row():
-            gr.ChatInterface(
+            gr.ChatInterface(
+                chat.interference
+            )
+        with gr.Row():
+            gr.Slider(
+                label="Temperature",
+                value=0.7,
+                minimum=0.0,
+                maximum=1.0,
+                step=0.05,
+                interactive=True,
+                info="Higher values produce more diverse outputs",
+            ),
+            gr.Slider(
+                label="Max new tokens",
+                value=256,
+                minimum=0,
+                maximum=1024,
+                step=64,
+                interactive=True,
+                info="The maximum numbers of new tokens",
+            ),
+            gr.Slider(
+                label="Top-p (nucleus sampling)",
+                value=0.95,
+                minimum=0.0,
+                maximum=1,
+                step=0.05,
+                interactive=True,
+                info="Higher values sample more low-probability tokens",
+            ),
+            gr.Slider(
+                label="Repetition penalty",
+                value=1.1,
+                minimum=1.0,
+                maximum=2.0,
+                step=0.05,
+                interactive=True,
+                info="Penalize repeated tokens",
+            )

     with gr.Tab("SHAP Dashboard"):
         with gr.Row():

@@ -36,12 +75,12 @@ with gr.Blocks() as ui:
                 Visualization Dashboard adopted from [BERTViz](https://github.com/jessevig/bertviz)
                 """)

-    with gr.Tab("
+    with gr.Tab("Mitral Model Overview"):
         with gr.Row():
             gr.Markdown(
                 """
-                ###
-                Adopted from official [model paper](https://arxiv.org/abs/
+                ### Mistral 7B Model & Data Overview for Transparency
+                Adopted from official [model paper](https://arxiv.org/abs/2310.06825) by Mistral AI
                 """)

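As committed, the four sliders are created in their own `gr.Row()` and nothing in the shown hunk passes their values to `chat.interference`. A minimal sketch of one way they could be wired through `gr.ChatInterface`'s `additional_inputs`, assuming `app.py` imports `chatmodel` as `chat` above the shown hunk and a Gradio version that supports `additional_inputs`:

```python
# Sketch only, not the committed code: route the slider values into the chat function.
# Assumes `import chatmodel as chat` (the import is outside the shown hunk).
import gradio as gr
import chatmodel as chat

with gr.Blocks() as ui:
    with gr.Tab("Mistral AI ChatBot"):
        gr.ChatInterface(
            chat.interference,  # generator function, so replies stream into the chat window
            additional_inputs=[
                # Order matches interference(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty)
                gr.Slider(label="Temperature", value=0.7, minimum=0.0, maximum=1.0, step=0.05,
                          info="Higher values produce more diverse outputs"),
                gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1024, step=64,
                          info="The maximum number of new tokens"),
                gr.Slider(label="Top-p (nucleus sampling)", value=0.95, minimum=0.0, maximum=1.0, step=0.05,
                          info="Higher values sample more low-probability tokens"),
                gr.Slider(label="Repetition penalty", value=1.1, minimum=1.0, maximum=2.0, step=0.05,
                          info="Penalize repeated tokens"),
            ],
        )

if __name__ == "__main__":
    ui.launch()
```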
chatmodel.py
CHANGED
@@ -1,61 +1,48 @@
-from
-import torch
-from transformers import AutoTokenizer
+from huggingface_hub import InferenceClient
 import os
+import gradio as gr

 token = os.environ.get("HGFTOKEN")

-llama_pipeline = pipeline(
-    "text-generation",
-    model=model,
-    torch_dtype=torch.float32,
-    device_map="auto",
-    token = token
-)
+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.1"
+)

-def interference(message: str, history: list, ) -> str:
-    system_prompt="You are a helpful assistant providing reasonable answers."

-        num_return_sequences=1,
-        eos_token_id=tokenizer.eos_token_id,
-        max_length=1024,
-    )

-    response = generated_text[len(query):]  # Remove the prompt from the output

-    return response.strip()
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+def interference(
+    prompt, history, temperature=0.7, max_new_tokens=256, top_p=0.95, repetition_penalty=1.1,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
+custom=[
+
+]
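The new `chatmodel.py` builds the Mistral-Instruct prompt by hand and streams tokens back through the generator. A small self-contained sketch of what `format_prompt` produces; the conversation below is made up purely for illustration:

```python
# Sketch reproducing format_prompt() from the diff above to show the prompt layout.
# The history and message are made-up examples, not part of the commit.

def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("What is SHAP?", "SHAP attributes a prediction to input features using Shapley values.")]
print(format_prompt("Explain it in one sentence.", history))
# <s>[INST] What is SHAP? [/INST] SHAP attributes a prediction to input features using Shapley values.</s> [INST] Explain it in one sentence. [/INST]
```

Because `interference` yields the accumulating `output` string for every streamed token, a Gradio chat component that accepts generator functions can render the reply incrementally as it arrives.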