# --- Web file-viewer residue captured with the source (not program code) ---
# teven's picture
# switching to c4
# 01d5d64
# raw
# history blame contribute delete
# No virus
# 1.98 kB
import dash
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from datasets import load_dataset
# Create the Dash application object; the layout and the hover callback
# defined further down in this file are attached to this instance.
app = dash.Dash(__name__)
def get_dataset(name, n_items=1000):
    """Load a random sample of a dataset as a pandas DataFrame.

    The ``train`` split of ``ola13/small-{name}-dedup`` is loaded, shuffled
    and truncated to at most ``n_items`` rows. A ``text_length`` column
    (number of characters in ``text``) is added, and every column other than
    ``text``, ``perplexity`` and ``text_length`` is discarded.

    Args:
        name: Short dataset name interpolated into the repository id.
        n_items: Maximum number of rows to sample. If the split holds fewer
            rows, all of them are returned.

    Returns:
        A ``(dataset, max_perp)`` tuple: the DataFrame sorted by ascending
        ``perplexity``, and the largest ``perplexity`` value in the sample.
    """
    ola_path = f"ola13/small-{name}-dedup"
    shuffled = load_dataset(ola_path, split="train").shuffle()

    # Clamp the sample size so a split with fewer than n_items rows does not
    # make select() fail on an out-of-range index.
    sample_size = min(n_items, len(shuffled))
    dataset = shuffled.select(range(sample_size)).to_pandas()

    # Vectorised character count instead of a Python-level row-wise apply.
    dataset["text_length"] = dataset["text"].str.len()

    # Keep only the columns that are used, preserving their original order.
    wanted = {"text", "perplexity", "text_length"}
    dataset = dataset[[column for column in dataset.columns if column in wanted]]

    dataset = dataset.sort_values("perplexity")
    max_perp = dataset["perplexity"].max()
    return dataset, max_perp
# Dataset to visualise. Names noted by the author as candidates:
# "oscar", "the_pile", "c4", "roots_en".
name = "c4"
df, max_perplexity = get_dataset(name)

# Scatter plot of perplexity against text length. Each point carries its
# document text as custom data so it can be read back from hover events.
# Every update_* call returns the figure, which lets the configuration be
# written as a single chain: selectable points, small markers, log-log axes.
fig = (
    px.scatter(df, x="perplexity", y="text_length", custom_data=["text"])
    .update_layout(clickmode="event+select")
    .update_traces(marker_size=3)
    .update_xaxes(title_text="Perplexity (log scale)", type="log")
    .update_yaxes(title_text="Text Length (log scale)", type="log")
)
# Inline CSS, keyed by component role, written with camelCase property names
# as expected by Dash's `style` argument.
styles = {
    "textbox": {
        "border": "thin lightgrey solid",
        "overflowX": "scroll",
        # The value must not end with a semicolon: "pre-wrap;" is not a valid
        # value for an inline style property, so the declaration is dropped
        # and line breaks in the displayed text collapse.
        "whiteSpace": "pre-wrap",
    }
}
# Page layout: the interactive scatter plot, followed by the box that the
# hover callback fills with text.
app.layout = html.Div(
    children=[
        dcc.Graph(id="graph_interaction", figure=fig),
        html.Div(id="text", style=styles["textbox"]),
    ],
)
# Dash callback: when the hovered point of the graph changes, write that
# point's attached text into the "text" element.
@app.callback(
    Output("text", "children"),
    Input("graph_interaction", "hoverData"),
)
def open_url(hoverData):
    """Return the custom-data text of the hovered point.

    Args:
        hoverData: The graph's ``hoverData`` property; falsy when nothing
            is hovered.

    Returns:
        The first custom-data entry of the first hovered point.

    Raises:
        PreventUpdate: If ``hoverData`` is falsy, so the output is left
            as it is.
    """
    if not hoverData:
        raise PreventUpdate
    first_point = hoverData["points"][0]
    return first_point["customdata"][0]
# Script entry point: serve the app on all interfaces, port 7860.
# NOTE(review): debug=True together with host "0.0.0.0" makes the debug
# tooling reachable from other machines — confirm this is intended.
if __name__ == "__main__":
    app.run_server(host="0.0.0.0", port=7860, debug=True)