wimbd / app.py
yanaiela's picture
more docs, and default string value
38d6db9
raw
history blame
4.86 kB
import os
from functools import lru_cache
import gradio as gr
import plotly.graph_objects as go
from wimbd.es import es_init, count_documents_containing_phrases
# Elasticsearch clients: ``es`` for the general LM-dataset indices and
# ``es_dolma`` for the Dolma indices. Credentials come from environment
# variables; an unset variable is passed to es_init as None.
es = es_init(
    None,
    os.environ.get("lm_datasets_cloud_id"),
    os.environ.get("lm_datasets_api_key"),
)
es_dolma = es_init(
    None,
    os.environ.get("dolma_cloud_id"),
    os.environ.get("dolma_api_key"),
)
# Elasticsearch index name (or index pattern) queried for each dataset offered
# in the UI, keyed by display name. Key order is significant: the dataset
# checkboxes are created by iterating these keys, and process_input zips the
# map against the checkbox values in that same order.
dataset_es_map = {
    "OSCAR": "re_oscar",
    "LAION-2B-en": "re_laion2b-en-*",
    "LAION-5B": "*laion2b*",
    "OpenWebText": "openwebtext",
    "The Pile": "re_pile",
    "C4": "c4",
    "Dolma v1.5": "docs_v1.5_2023-11-02",
    "Dolma v1.7": "docs_v1.7_2024-06-04",
    "Tulu v2": "tulu-v2-sft-mixture",
}
# Display names of all queryable datasets, derived from dataset_es_map so it
# always matches the datasets actually offered instead of drifting out of sync.
datasets = list(dataset_es_map)
# Datasets whose checkboxes start ticked. Each entry must be a key of
# dataset_es_map; anything else is silently ignored by the checkbox builder.
default_checked = ["C4", "The Pile", "Dolma v1.7"]
@lru_cache(maxsize=128)
def get_counts(index_name, phrase, es):
    """Return how many documents in ``index_name`` contain ``phrase``.

    Results are memoized on the call arguments (``index_name``, ``phrase`` and
    the client ``es``), so re-submitting the same query from the UI does not
    hit Elasticsearch again. Because ``es`` is part of the cache key, the
    client object must be hashable.
    """
    document_count = count_documents_containing_phrases(index_name, phrase, es=es)
    return document_count
def process_input(phrases, *dataset_choices):
    """Count documents containing each phrase in every selected dataset.

    Args:
        phrases: Contents of the input textbox, one phrase per line. Leading
            and trailing whitespace is stripped, blank lines are skipped, and a
            phrase repeated on several lines is queried only once. ``None`` is
            treated as empty input.
        *dataset_choices: One checkbox value per dataset, in the same order as
            ``dataset_es_map``; a truthy value selects that dataset.

    Returns:
        ``(table_data, fig)``: ``table_data`` is a list of
        ``[dataset, phrase, count]`` rows (count rendered as ``str``) for the
        Dataframe output, and ``fig`` is a plotly figure with one bar trace
        per phrase, grouped by dataset.
    """
    # Parse the phrases once rather than once per selected dataset.
    # dict.fromkeys drops duplicates while preserving the user's line order.
    phrase_list = list(
        dict.fromkeys(
            line.strip() for line in (phrases or "").split("\n") if line.strip()
        )
    )

    results = []
    for (dataset_name, index_name), is_selected in zip(
        dataset_es_map.items(), dataset_choices
    ):
        if not is_selected:
            continue
        # Dolma indices are queried through the separate es_dolma client.
        client = es_dolma if "dolma" in dataset_name.lower() else es
        for phrase in phrase_list:
            count = get_counts(index_name, phrase, es=client)
            results.append((dataset_name, phrase, count))

    # Format results for the table output (counts shown as strings).
    table_data = [[dataset, phrase, str(count)] for dataset, phrase, count in results]

    # Group bars by phrase in a single pass. A dict (not a set) keeps the
    # trace/legend order equal to the input order on every run; set iteration
    # order over strings varies with hash randomization.
    bars_by_phrase = {}
    for dataset_name, phrase, count in results:
        names, counts = bars_by_phrase.setdefault(phrase, ([], []))
        names.append(dataset_name)
        counts.append(count)

    fig = go.Figure()
    for phrase, (names, counts) in bars_by_phrase.items():
        fig.add_trace(go.Bar(x=names, y=counts, name=phrase))
    fig.update_layout(
        title="Document Counts by Dataset and Phrase",
        xaxis_title="Dataset",
        yaxis_title="Count",
        barmode="group",
    )
    return table_data, fig
# Markdown rendered below the app (passed to the Interface as ``article``).
# The closing code fence must sit on its own line: CommonMark only recognizes
# a closing fence at the start of a line, so "}```" would leave the block
# unterminated and show the backticks as literal code.
citation_text = """If you find this tool useful, please kindly cite our paper:
```bibtex
@inproceedings{elazar2023s,
title={What's In My Big Data?},
author={Elazar, Yanai and Bhagia, Akshita and Magnusson, Ian Helgi and Ravichander, Abhilasha and Schwenk, Dustin and Suhr, Alane and Walsh, Evan Pete and Groeneveld, Dirk and Soldaini, Luca and Singh, Sameer and Hajishirzi, Hanna and Smith, Noah A. and Dodge, Jesse},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}
```"""
def custom_layout(input_components, output_components, citation):
    """Flatten the UI components into one ordered list.

    The order is: the phrase textbox, every dataset checkbox, the results
    table, the results plot, and finally ``citation``. Any output components
    beyond the first two are dropped.

    NOTE(review): gr.Interface has no layout callback, and its ``theme=``
    parameter expects a theme object or name rather than a function, so this
    helper does not actually control the app's layout. Confirm before relying
    on it.
    """
    textbox = input_components[0]
    checkboxes = list(input_components[1:])
    dataframe = output_components[0]
    plot = output_components[1]
    return [textbox] + checkboxes + [dataframe, plot, citation]
# Gradio UI: a phrase textbox plus one checkbox per dataset as inputs, and a
# counts table plus a grouped bar chart as outputs (matching the
# ``(table_data, fig)`` tuple that process_input returns).
iface = gr.Interface(
    fn=process_input,
    inputs=[
        gr.Textbox(
            label="Enter phrases (one per line)",
            lines=5,
            value="let's think step by step\nhello world",
        ),
        # One checkbox per dataset, in dataset_es_map key order; process_input
        # maps the checkbox values back to indices by that same order.
        *[
            gr.Checkbox(label=dataset, value=(dataset in default_checked))
            for dataset in dataset_es_map
        ],
    ],
    outputs=[
        gr.Dataframe(headers=["Dataset", "Phrase", "Count"], label="Counts Table"),
        gr.Plot(label="Results Chart"),
    ],
    title="What's In My Big Data? String Counts Demo",
    description="""This app connects to the WIMBD Elasticsearch instance and counts the number of documents containing a given string in the various indexed datasets.\\
The app uses the wimbd pypi package, which can be installed by simply running `pip install wimbd`.\\
Access to the indices requires an API key, due to the sensitive nature of the data, but can be accessed by filling up the following [form](https://forms.gle/Mk9uwJibR9H4hh9Y9).\\
This app was created by [Yanai Elazar](https://yanaiela.github.io/), and for bugs, improvements, or feature requests, please open an issue on the [GitHub repository](https://github.com/allenai/wimbd), or send me an email.
The indices were set up as part of the WIMBD project, which you can read about in our [ICLR paper](https://arxiv.org/abs/2310.20707).
The returned counts are the number of documents that contain each string per dataset.""",
    article=citation_text,  # Markdown rendered below the interface.
    # No ``theme=`` here: Gradio's theme parameter takes a gr.themes theme
    # object or a theme name, not a layout function. A callable is not a valid
    # theme (Gradio warns and falls back to the default theme).
)
iface.launch()