"""Dash app that plots perplexity against text length for a dataset sample.

A random sample of documents is loaded from the Hugging Face Hub, drawn as a
log-log scatter plot, and the text of the point under the mouse cursor is
shown in a box below the graph.
"""

import dash
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from datasets import load_dataset

# Create dash app
app = dash.Dash(__name__)

# Columns retained in the plotting frame; every other column is discarded.
_KEPT_COLUMNS = ("text", "perplexity", "text_length")


def get_dataset(name, n_items=1000):
    """Load a shuffled sample of ``ola13/small-{name}-dedup`` as a DataFrame.

    Args:
        name: Dataset short name, interpolated into the Hub repository id.
        n_items: Maximum number of documents to sample. When the split holds
            fewer rows, all of them are used instead of raising.

    Returns:
        A ``(dataset, max_perp)`` tuple. ``dataset`` is a pandas DataFrame
        restricted to the ``text``, ``perplexity`` and ``text_length`` columns
        and sorted by ascending perplexity; ``max_perp`` is the largest
        perplexity value in the sample.
    """
    ola_path = f"ola13/small-{name}-dedup"
    shuffled = load_dataset(ola_path, split="train").shuffle()
    # Clamp the sample size so a small split does not make select() fail.
    sample_size = min(n_items, len(shuffled))
    dataset = shuffled.select(range(sample_size)).to_pandas()

    # Vectorised character count; avoids a Python-level call per row.
    dataset["text_length"] = dataset["text"].str.len()

    # Keep only the columns the plot needs, preserving their original order.
    kept = [column for column in dataset.columns if column in _KEPT_COLUMNS]
    dataset = dataset[kept].sort_values("perplexity")

    max_perp = dataset["perplexity"].max()
    return dataset, max_perp


# Dataset names listed by the original author: "oscar", "the_pile", "c4",
# "roots_en".
name = "c4"
df, max_perplexity = get_dataset(name)

# Create scatter plot with x and y coordinates. The document text travels
# with each point as custom data so the hover callback can display it.
fig = px.scatter(df, x="perplexity", y="text_length", custom_data=["text"])

# Update layout and update traces
fig.update_layout(clickmode="event+select")
fig.update_traces(marker_size=3)
fig.update_xaxes(title_text="Perplexity (log scale)", type="log")
fig.update_yaxes(title_text="Text Length (log scale)", type="log")

styles = {
    "textbox": {
        "border": "thin lightgrey solid",
        "overflowX": "scroll",
        # Inline style values must not end with ';' -- a trailing semicolon
        # makes the declaration invalid, so it would be ignored.
        "whiteSpace": "pre-wrap",
    }
}

# Create app layout to show dash graph
app.layout = html.Div(
    [
        dcc.Graph(
            id="graph_interaction",
            figure=fig,
        ),
        html.Div(id="text", style=styles["textbox"]),
    ]
)


# Callback that shows the text of the hovered point below the graph.
@app.callback(
    Output("text", "children"),
    Input("graph_interaction", "hoverData"),
)
def open_url(hoverData):
    """Return the document text attached to the hovered scatter point.

    Args:
        hoverData: Hover payload of the ``graph_interaction`` graph; falsy
            when no point is being hovered.

    Returns:
        The first custom-data entry (the ``text`` column) of the first
        hovered point.

    Raises:
        PreventUpdate: When there is no hover data, so the text box keeps
            its current content.
    """
    if not hoverData:
        raise PreventUpdate
    return hoverData["points"][0]["customdata"][0]


if __name__ == "__main__":
    # NOTE(review): debug=True combined with host="0.0.0.0" exposes the
    # development tooling on every network interface -- confirm this is
    # intended before deploying.
    # NOTE(review): recent Dash releases appear to prefer app.run() over
    # app.run_server() -- confirm against the installed Dash version.
    app.run_server(port=7860, host="0.0.0.0", debug=True)