Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import glob
|
4 |
+
import io
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import struct
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from html import escape
|
10 |
+
|
11 |
+
import msgpack
|
12 |
+
import streamlit as st
|
13 |
+
import torch
|
14 |
+
import tqdm
|
15 |
+
from huggingface_hub import HfFileSystem
|
16 |
+
from transformers import AutoTokenizer
|
17 |
+
|
18 |
+
# Must run before any other Streamlit command; the viewer uses the full page width.
st.set_page_config(layout="wide")

# Runtime configuration, overridable through environment variables.
# MODEL_NAME doubles as the tokenizer id and as the prefix of the
# "<MODEL_NAME>-viewer-data" dataset that holds the precomputed files.
MODEL_NAME = os.getenv("MODEL_NAME", "MonetLLM/monet-vd-1.4B-100BT-hf")
# Number of tokens of context rendered around each activated token.
CONTEXT_WINDOW = int(os.getenv("CONTEXT_WINDOW", "12"))
# An expert is offered by the "Random" button only when it has more stored
# examples than this threshold.
CANDIDATE_THRESHOLD = int(os.getenv("CANDIDATE_THRESHOLD", "50"))

# CSS injected by st_horizontal(): any Streamlit vertical block that contains a
# ".horizontal-marker" element is turned into a wrapping flex row.
HORIZONTAL_STYLE = """<style class="hide-element">
/* Hides the style container and removes the extra spacing */
.element-container:has(.hide-element) {
display: none;
}
/*
The selector for >.element-container is necessary to avoid selecting the whole
body of the streamlit app, which is also a stVerticalBlock.
*/
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) {
display: flex;
flex-direction: row !important;
flex-wrap: wrap;
gap: 0.5rem;
align-items: baseline;
}
/* Buttons and their parent container all have a width of 704px, which we need to override */
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) div {
width: max-content !important;
}
/* Just an example of how you would style buttons, if desired */
/*
div[data-testid="stVerticalBlock"]:has(> .element-container .horizontal-marker) button {
border-color: red;
}
*/
</style>"""
|
51 |
+
|
52 |
+
|
53 |
+
def _load_trailing_table(path: str):
    """Return the msgpack-encoded table stored at the end of ``path``.

    The last 4 bytes of the file are read as a big-endian unsigned integer
    giving the byte size of the packed table that immediately precedes them.
    """
    with open(path, "rb") as fp:
        fp.seek(-4, io.SEEK_END)
        (table_size,) = struct.unpack(">I", fp.read(4))

        # Jump back over the size field and the table itself, then decode it.
        fp.seek(-(table_size + 4), io.SEEK_END)
        return msgpack.Unpacker(fp).unpack()


@st.cache_resource
def prepare_routing_resources():
    """Download the viewer data files and load everything the app needs.

    Files of the ``datasets/{MODEL_NAME}-viewer-data`` repository that are not
    already present in the working directory are downloaded first. The result
    is cached by Streamlit through ``st.cache_resource``.

    Returns:
        A tuple ``(input_tokens, examples_tables, routing_tables, candidates,
        tokenizer)`` where

        * ``input_tokens`` is the object loaded from ``inputs.pt``;
        * ``examples_tables[i]`` is the trailing table of
          ``examples-{i}.msgpack`` (used elsewhere as per-expert file offsets);
        * ``routing_tables[i]`` is the trailing table of
          ``routings-{i}.msgpack``;
        * ``candidates[i]`` lists the record indices of ``examples-{i}.msgpack``
          whose record holds more than ``CANDIDATE_THRESHOLD`` entries;
        * ``tokenizer`` is the tokenizer loaded for ``MODEL_NAME``.
    """
    fs = HfFileSystem()
    for filename in fs.glob(f"datasets/{MODEL_NAME}-viewer-data/*"):
        if not os.path.exists(os.path.basename(filename)):
            print(f"[*] Download {filename}...")
            fs.download(filename, ".")

    input_tokens = torch.load("inputs.pt")

    # One examples file per routing group, numbered consecutively from 0.
    num_groups = len(glob.glob("examples-*.msgpack"))
    examples_tables = [
        _load_trailing_table(f"examples-{i}.msgpack")
        for i in tqdm.trange(num_groups)
    ]

    candidates = []
    for i, table in enumerate(tqdm.tqdm(examples_tables)):
        with open(f"examples-{i}.msgpack", "rb") as fp:
            # Records are read sequentially from the start of the file; this
            # assumes the j-th record belongs to table entry j.
            unpacker = msgpack.Unpacker(fp)
            candidates.append(
                [
                    j
                    for j in range(len(table))
                    if len(unpacker.unpack()) > CANDIDATE_THRESHOLD
                ]
            )

    routing_tables = [
        _load_trailing_table(f"routings-{i}.msgpack")
        for i in tqdm.trange(num_groups)
    ]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return input_tokens, examples_tables, routing_tables, candidates, tokenizer
|
92 |
+
|
93 |
+
|
94 |
+
# Load (or fetch from Streamlit's resource cache) the shared data once per
# script run and expose it as module-level names used by the code below.
(
    input_tokens,
    examples_tables,
    routing_tables,
    candidates,
    tokenizer,
) = prepare_routing_resources()
|
97 |
+
|
98 |
+
|
99 |
+
def render_routing_examples_in_html(router_index: int, expert_id: int) -> str:
    """Render the stored activation examples of one expert as an HTML table.

    Args:
        router_index: Index of the routing group; selects the
            ``examples-{router_index}.msgpack`` and
            ``routings-{router_index}.msgpack`` files.
        expert_id: Index of the expert inside that group.

    Returns:
        An HTML fragment with one table row per example: the activated token
        with its routing score, the surrounding context with every token
        shaded by its routing score for this expert, and the ``(i, j)``
        position of the token in ``input_tokens``.
    """
    with open(f"examples-{router_index}.msgpack", "rb") as fp:
        fp.seek(examples_tables[router_index][expert_id])
        examples = msgpack.Unpacker(fp).unpack()

    token_offsets = routing_tables[router_index]
    table = []
    with open(f"routings-{router_index}.msgpack", "rb") as fp:
        for i, j, _ in examples:
            start = max(j - CONTEXT_WINDOW, 0)
            end = min(j + CONTEXT_WINDOW, len(token_offsets[i]))

            # Seek to the first token of the window and decode one routing
            # record per token; a record without this expert scores 0.
            fp.seek(token_offsets[i][start])
            unpacker = msgpack.Unpacker(fp, strict_map_key=False)
            activated = [
                unpacker.unpack().get(expert_id, 0) for _ in range(start, end)
            ]

            # Re-tokenize the decoded text to obtain character offsets. The
            # new encoding may differ in length from the stored row, so token
            # positions are shifted by the length difference.
            full_text = tokenizer.decode(input_tokens[i])
            encodings = tokenizer(full_text, add_special_tokens=False)
            offset = len(encodings.input_ids) - input_tokens.size(1)

            spans, lslice = [], None
            for k in range(start, end):
                if offset + k >= 0 and (
                    sslice := encodings.token_to_chars(offset + k)
                ):
                    span, score = full_text[slice(*sslice)], activated[k - start]
                    if lslice == sslice:
                        # Consecutive tokens covering the same characters are
                        # shown once, with the larger of their scores.
                        score = max(spans.pop(-1)[1], score)
                    spans.append((escape(span), score))
                    lslice = sslice

            # The score doubles as the alpha channel of the highlight colour.
            # The closing parenthesis of rgba(...) was missing originally.
            spans = [
                f"<span style='background-color: rgba(144, 238, 144, {score})' title='Routing: {score*100:.2f}%'>{span}</span>"
                for span, score in spans
            ]
            table.append(
                f"""
                <tr>
                    <td align='right'>
                        <span style='font-weight: bold'>
                            {escape(tokenizer.decode(input_tokens[i, j]))} ({activated[j - start] * 100:.2f}%)
                        </span>
                    </td>
                    <td align='left'>
                        (...) {"".join(spans)} (...)
                    </td>
                    <td align='right'>
                        ({i}, {j})
                    </td>
                </tr>
                """
            )

    return f"""
    <div style='background-color: white; color: black; padding: 1em 3em; font-size: 12pt'>
        <h2 style='font-size: 18pt'> Activated Examples of Group {router_index} / Expert {expert_id} </h2>
        <table>
            {"".join(table)}
        </table>
    </div>
    """
|
156 |
+
|
157 |
+
|
158 |
+
@contextmanager
def st_horizontal():
    """Context manager whose body is rendered inside a horizontally laid-out container.

    The HORIZONTAL_STYLE rules are emitted first; a fresh Streamlit container
    is then opened and tagged with a hidden ``horizontal-marker`` span, which
    is what those CSS rules match on. Widgets created inside the ``with``
    block are added to that container.
    """
    st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True)

    marker_html = '<span class="hide-element horizontal-marker"></span>'
    row = st.container()
    with row:
        st.markdown(marker_html, unsafe_allow_html=True)
        yield
|
167 |
+
|
168 |
+
|
169 |
+
# --- Page body: pick a routing group and an expert, then render its examples.
col1, col2 = st.columns(2)
with col1:
    router_groups = [f"Routing Group {i}" for i in range(len(examples_tables))]
    # Group 4 is the preferred default; clamp it so a dataset with fewer
    # groups does not make the selectbox raise on an out-of-range index.
    router_index = st.selectbox(
        "Expert Routing Group",
        router_groups,
        index=min(4, max(len(router_groups) - 1, 0)),
    )
with col2:
    # number_input treats max_value as inclusive, so the largest valid expert
    # index is len - 1. The default is clamped for the same reason as above.
    max_expert_id = max(len(examples_tables[0]) - 1, 0)
    expert_id = st.number_input(
        "Expert Index", 0, max_expert_id, min(54136, max_expert_id)
    )

with st_horizontal():
    show_btn = st.button("Show")
    random_btn = st.button("Random")

if show_btn or random_btn:
    # The selectbox returns the label; convert it back to the group index.
    router_index = router_groups.index(router_index)
    if random_btn:
        if candidates[router_index]:
            expert_id = random.choice(candidates[router_index])
        else:
            # random.choice raises on an empty list; fall back to the
            # expert index entered by the user.
            st.warning(
                f"No candidate experts for {router_groups[router_index]}; "
                "showing the selected expert instead."
            )
    st.html(render_routing_examples_in_html(router_index, expert_id))
|