File size: 1,984 Bytes
d134af5
a9cc2b2
d134af5
 
a9cc2b2
 
d134af5
 
 
 
 
a9cc2b2
 
 
 
 
 
 
 
d134af5
a9cc2b2
 
 
 
 
 
 
01d5d64
a9cc2b2
 
 
 
d134af5
 
a9cc2b2
 
 
 
 
 
 
 
 
 
 
d134af5
 
 
 
 
 
 
 
a9cc2b2
d134af5
 
 
 
 
 
a9cc2b2
d134af5
 
 
 
 
 
 
 
 
a9cc2b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import dash
import plotly.express as px
from dash import dcc, html
from dash.dependencies import Input, Output
from dash.exceptions import PreventUpdate
from datasets import load_dataset

# Create the Dash application object; the layout and the hover callback
# defined later in this file are attached to it.
app = dash.Dash(__name__)


def get_dataset(name, n_items=1000):
    """Load a shuffled sample of a dataset as a perplexity-sorted DataFrame.

    Args:
        name: Short dataset name, interpolated into the hub path
            ``ola13/small-{name}-dedup``.
        n_items: Maximum number of rows to keep after shuffling. The value
            is capped at the size of the ``train`` split.

    Returns:
        A ``(dataset, max_perp)`` tuple. ``dataset`` is a pandas DataFrame
        restricted to the ``text``, ``perplexity`` and ``text_length``
        columns and sorted by ascending ``perplexity``; ``max_perp`` is the
        largest ``perplexity`` value in that sample.
    """
    ola_path = f"ola13/small-{name}-dedup"
    shuffled = load_dataset(ola_path, split="train").shuffle()

    # Cap the sample size so that asking for more rows than the split holds
    # does not fail on out-of-range indices.
    sample_size = min(n_items, len(shuffled))
    dataset = shuffled.select(range(sample_size)).to_pandas()

    # Vectorized character count instead of a row-wise apply().
    dataset["text_length"] = dataset["text"].str.len()

    # Keep only the columns the plot needs, preserving their original order.
    wanted = ("text", "perplexity", "text_length")
    dataset = dataset[[column for column in dataset.columns if column in wanted]]

    dataset = dataset.sort_values("perplexity")

    max_perp = dataset["perplexity"].max()
    return dataset, max_perp


# Dataset shown by the app. The commented list below is presumably the set of
# other names that can be plugged in here -- verify before switching.
# names = ["oscar", "the_pile", "c4", "roots_en"]
name = "c4"
df, max_perplexity = get_dataset(name)

# Scatter plot of perplexity against text length on log-log axes. Each point
# carries its document text as custom data so the hover callback can show it.
fig = (
    px.scatter(df, x="perplexity", y="text_length", custom_data=["text"])
    .update_layout(clickmode="event+select")
    .update_traces(marker_size=3)
    .update_xaxes(title_text="Perplexity (log scale)", type="log")
    .update_yaxes(title_text="Text Length (log scale)", type="log")
)

# Inline CSS (React-style camelCase keys) for the box that displays the text
# of the hovered document.
styles = {
    "textbox": {
        "border": "thin lightgrey solid",
        "overflowX": "scroll",
        # The value must not end with ";": a trailing semicolon makes it an
        # invalid CSS value, so the white-space rule would not be applied.
        "whiteSpace": "pre-wrap",
    }
}

# Page layout: the interactive scatter plot on top, followed by the box that
# the hover callback fills with the selected document's text.
app.layout = html.Div(
    children=[
        dcc.Graph(figure=fig, id="graph_interaction"),
        html.Div(style=styles["textbox"], id="text"),
    ],
)


# Dash callback: show the text of the hovered point in the "text" element.
@app.callback(
    Output("text", "children"),
    Input("graph_interaction", "hoverData"),
)
def open_url(hoverData):
    """Return the custom-data text of the first hovered point.

    Args:
        hoverData: The graph's ``hoverData`` property; falsy when nothing
            is hovered.

    Returns:
        The first ``customdata`` entry of the first hovered point.

    Raises:
        PreventUpdate: When ``hoverData`` is falsy, leaving the text box
            unchanged.
    """
    if not hoverData:
        raise PreventUpdate

    hovered_point = hoverData["points"][0]
    return hovered_point["customdata"][0]


# Start the development server only when this file is run as a script.
if __name__ == "__main__":
    # NOTE(review): debug mode is enabled while listening on all interfaces;
    # confirm this is intended for the deployment target.
    # NOTE(review): newer Dash releases favour app.run() over run_server();
    # verify against the pinned Dash version before changing.
    app.run_server(host="0.0.0.0", port=7860, debug=True)