graph_spectrum / app.py
Narsil's picture
Narsil HF staff
Spectrum visualizer.
06e7970 verified
raw
history blame
2.1 kB
import gradio as gr
from transformers import pipeline
import numpy as np
import pandas as pd
import re
import torch
# Matches one purely numeric dotted path segment, e.g. ".11." in "h.11.mlp".
# Used to detect per-layer tensors and to group the same weight across layers.
# At least one digit is required: with "*" an empty segment ("..") also
# matched, and extracting the layer index then failed with int("").
number_re = re.compile(r"\.[0-9]+\.")
# State dict (parameter name -> tensor) of the most recently loaded model.
STATE_DICT = {}
# Singular-value spectra table, rebuilt by find_choices(); empty (no columns)
# until a model has been loaded.
DATA = pd.DataFrame()
def scatter_plot_fn(group_name):
    """Build a LinePlot update showing the spectra of one weight group.

    Rows of the module-level ``DATA`` frame (populated by ``find_choices``)
    whose ``group_name`` equals the selection are plotted as normalized
    singular value (``val``) against ``rank``, one colored line per ``layer``.

    Args:
        group_name: a group name from the dropdown (names produced by
            ``find_choices`` contain a ``.{N}.`` placeholder), or None when
            the dropdown has no selection.

    Returns:
        A ``gr.LinePlot.update`` payload. If no spectra have been computed
        yet, the plot is cleared instead of raising.
    """
    # Until a model loads, DATA is a column-less DataFrame; filtering on
    # ``DATA.group_name`` would raise AttributeError.
    if "group_name" not in DATA.columns:
        return gr.LinePlot.update(value=None)
    df = DATA[DATA["group_name"] == group_name]
    return gr.LinePlot.update(
        value=df,
        x="rank",
        y="val",
        color="layer",
        tooltip=["val", "rank", "layer"],
        caption="",
    )
def find_choices(state_dict):
    """Compute normalized singular-value spectra for per-layer weight matrices.

    A tensor qualifies when its key contains a numeric path segment (matched
    by ``number_re``) and it has exactly two dimensions. Every numeric segment
    in the key is replaced with ``.{N}.`` to form a group name, so the same
    weight across layers shares one group. The first numeric segment is taken
    as the layer index.

    Side effect: replaces the module-level ``DATA`` frame with one row per
    (tensor, rank) for the 20 largest singular values. Each value is divided
    by the sum of all of that tensor's singular values. Columns: name, layer
    (category), group_name, rank (int32), val (float).

    Args:
        state_dict: mapping of parameter name to tensor, as returned by
            ``Module.state_dict()``.

    Returns:
        The set of group names found (empty if nothing qualifies).
    """
    global DATA
    choices = set()
    rows = []
    for name, tensor in state_dict.items():
        match = number_re.search(name)
        if match is None or len(tensor.shape) != 2:
            continue
        group_name = number_re.sub(".{N}.", name)
        choices.add(group_name)
        # First numeric segment with its surrounding dots stripped.
        layer = int(match.group()[1:-1])
        # torch.linalg.svdvals only supports float/double/complex inputs;
        # upcast half-precision weights instead of failing.
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        svdvals = torch.linalg.svdvals(tensor)
        total = svdvals.sum()
        if total > 0:  # an all-zero matrix would otherwise yield 0/0 = NaN
            svdvals = svdvals / total
        for rank, val in enumerate(svdvals[:20].tolist()):
            rows.append((name, layer, group_name, rank, val))
    # Build directly from the tuples instead of via np.array. np.array coerced
    # every column to str, so layer categories sorted lexically
    # ("0", "1", "10", "2", ...). An empty np.array is also 1-D and cannot
    # fill a 5-column frame.
    DATA = pd.DataFrame(rows, columns=["name", "layer", "group_name", "rank", "val"])
    DATA["val"] = DATA["val"].astype("float")
    DATA["layer"] = DATA["layer"].astype("category")
    DATA["rank"] = DATA["rank"].astype("int32")
    return choices
def weights_fn(model_id):
    """Load ``model_id`` and refresh the weight-group dropdown.

    The model is loaded through ``transformers.pipeline`` and its state dict is
    stored in the module-level ``STATE_DICT``. ``find_choices`` then computes
    the spectra and returns the group names.

    Args:
        model_id: model identifier passed to ``pipeline(model=...)``.

    Returns:
        A ``gr.Dropdown.update`` with the sorted group names, or with no
        choices if loading failed.
    """
    global STATE_DICT
    try:
        pipe = pipeline(model=model_id)
        STATE_DICT = pipe.model.state_dict()
    except Exception as e:  # UI boundary: a bad id must not crash the app
        print(f"Could not load model {model_id!r}: {e}")
        STATE_DICT = {}
        # Nothing to analyse; skip the SVD pass and just clear the choices.
        return gr.Dropdown.update(choices=[])
    choices = find_choices(STATE_DICT)
    # find_choices returns a set; sort it so the dropdown order is stable.
    return gr.Dropdown.update(choices=sorted(choices))
# Layout: model id input and weight-group selector on the left, plot on the right.
with gr.Blocks() as scatter_plot:
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(value="gpt", label="Model id (press Enter to load)")
            # Choices are filled in by weights_fn once a model has loaded.
            weights = gr.Dropdown(choices=[], label="Weight group")
        with gr.Column():
            plot = gr.LinePlot(show_label=False).style(container=True)
    # Load on Enter only: .change fires on every keystroke, which would start
    # a full pipeline download/load for each partially typed model id.
    model_id.submit(weights_fn, inputs=model_id, outputs=weights)
    weights.change(fn=scatter_plot_fn, inputs=weights, outputs=plot)

if __name__ == "__main__":
    scatter_plot.launch()