import copy
import os
from functools import lru_cache, partial
from typing import Callable, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import uvicorn
from datasets import Audio, Features, Image, Sequence
from fastapi import FastAPI, Response
from hffs.fs import HfFileSystem
from tqdm.contrib.concurrent import thread_map

from comm_utils import save_to_file, send_msg_to_server
from utils import clean_up_df, create_plot, create_statistic, ndarray_to_base64


class AppError(RuntimeError):
    pass


APP_URL = "http://127.0.0.1:7860" if os.getenv("DEV") else "https://Kamarov-lotsa-explorer.hf.space"
PAGE_SIZE = 1
MAX_CACHED_BLOBS = PAGE_SIZE * 10
TIME_PLOTS_NUM = 1
_blobs_cache = {}


#####################################################
# Define routes for image and audio files
#####################################################

app = FastAPI()


@app.get(
    "/image",
    responses={200: {"content": {"image/png": {}}}},
    response_class=Response,
)
def image(id: str):
    blob = get_blob(id)
    return Response(content=blob, media_type="image/png")


@app.get(
    "/audio",
    responses={200: {"content": {"audio/wav": {}}}},
    response_class=Response,
)
def audio(id: str):
    blob = get_blob(id)
    return Response(content=blob, media_type="audio/wav")


def push_blob(blob: bytes, blob_id: str) -> str:
    """Store a blob in the bounded in-memory cache, evicting the oldest entry if needed."""
    global _blobs_cache
    if blob_id in _blobs_cache:
        del _blobs_cache[blob_id]
    _blobs_cache[blob_id] = blob
    if len(_blobs_cache) > MAX_CACHED_BLOBS:
        del _blobs_cache[next(iter(_blobs_cache))]
    return blob_id


def get_blob(blob_id: str) -> bytes:
    global _blobs_cache
    return _blobs_cache[blob_id]


def blobs_to_urls(blobs: List[bytes], type: str, prefix: str) -> List[str]:
    image_blob_ids = [push_blob(blob, f"{prefix}-{i}") for i, blob in enumerate(blobs)]
    return [APP_URL + f"/{type}?id={blob_id}" for blob_id in image_blob_ids]


#####################################################
# List configs, splits and parquet files
#####################################################


@lru_cache(maxsize=128)
def get_parquet_fs(dataset: str) -> HfFileSystem:
    try:
        fs = HfFileSystem(dataset, repo_type="dataset", revision="refs/convert/parquet")
        if any(fs.isfile(path) for path in fs.ls("") if not path.startswith(".")):
            raise AppError(f"Parquet export doesn't exist for '{dataset}'.")
        return fs
    except Exception:
        raise AppError(f"Parquet export doesn't exist for '{dataset}'.")


@lru_cache(maxsize=128)
def get_parquet_configs(dataset: str) -> List[str]:
    fs = get_parquet_fs(dataset)
    return [path for path in fs.ls("") if fs.isdir(path)]


def _sorted_split_key(split: str) -> str:
    return split if not split.startswith("train") else chr(0) + split  # always "train" first


@lru_cache(maxsize=128)
def get_parquet_splits(dataset: str, config: str) -> List[str]:
    fs = get_parquet_fs(dataset)
    return [path.split("/")[1] for path in fs.ls(config) if fs.isdir(path)]


#####################################################
# Index and query Parquet data
#####################################################

RowGroupReaders = List[Callable[[], pa.Table]]


@lru_cache(maxsize=128)
def index(dataset: str, config: str, split: str) -> Tuple[np.ndarray, RowGroupReaders, int, Features]:
    fs = get_parquet_fs(dataset)
    sources = fs.glob(f"{config}/{split}/*.parquet")
    if not sources:
        if config not in get_parquet_configs(dataset):
            raise AppError(f"Invalid config {config}. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
        else:
            raise AppError(f"Invalid split {split}. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
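    # Open every Parquet file's metadata concurrently and record cumulative row-group
    # sizes, so a requested page can later be mapped to the row groups that contain it.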
    desc = f"{dataset}/{config}/{split}"
    all_pf: List[pq.ParquetFile] = thread_map(partial(pq.ParquetFile, filesystem=fs), sources, desc=desc, unit="pq")
    features = Features.from_arrow_schema(all_pf[0].schema.to_arrow_schema())
    rg_offsets = np.cumsum([pf.metadata.row_group(i).num_rows for pf in all_pf for i in range(pf.metadata.num_row_groups)])
    rg_readers = [partial(pf.read_row_group, i) for pf in all_pf for i in range(pf.metadata.num_row_groups)]
    max_page = 1 + (rg_offsets[-1] - 1) // PAGE_SIZE
    return rg_offsets, rg_readers, max_page, features


def query(page: int, page_size: int, rg_offsets: np.ndarray, rg_readers: RowGroupReaders) -> pd.DataFrame:
    start_row, end_row = (page - 1) * page_size, min(page * page_size, rg_offsets[-1] - 1)  # both included
    # rg_offsets[start_rg - 1] <= start_row < rg_offsets[start_rg]
    # rg_offsets[end_rg - 1] <= end_row < rg_offsets[end_rg]
    start_rg, end_rg = np.searchsorted(rg_offsets, [start_row, end_row], side="right")  # both included
    pa_table = pa.concat_tables([rg_readers[i]() for i in range(start_rg, end_rg + 1)])
    offset = start_row - (rg_offsets[start_rg - 1] if start_rg > 0 else 0)
    pa_table = pa_table.slice(offset, page_size)
    return pa_table.to_pandas()


def sanitize_inputs(dataset: str, config: str, split: str, page: str) -> Tuple[str, str, str, int]:
    try:
        page = int(page)
        assert page > 0
    except Exception:
        raise AppError(f"Bad page: {page}")
    if not dataset:
        raise AppError("Empty dataset name")
    if not config:
        raise AppError(f"Empty config. Available configs are: {', '.join(get_parquet_configs(dataset))}.")
    if not split:
        raise AppError(f"Empty split. Available splits are: {', '.join(get_parquet_splits(dataset, config))}.")
    return dataset, config, split, int(page)


@lru_cache(maxsize=128)
def get_page_df(dataset: str, config: str, split: str, page: str) -> Tuple[pd.DataFrame, int, Features]:
    dataset, config, split, page = sanitize_inputs(dataset, config, split, page)
    rg_offsets, rg_readers, max_page, features = index(dataset, config, split)
    if page > max_page:
        raise AppError(f"Page {page} does not exist")
    df = query(page, PAGE_SIZE, rg_offsets=rg_offsets, rg_readers=rg_readers)
    return df, max_page, features


#####################################################
# Format results
#####################################################


def get_page(dataset: str, config: str, split: str, page: str) -> Tuple[str, int, str]:
    df_, max_page, features = get_page_df(dataset, config, split, page)
    df = copy.deepcopy(df_)
    unsupported_columns = []
    if dataset != 'Salesforce/lotsa_data':
        for column, feature in features.items():
            if isinstance(feature, Image):
                blob_type = "image"  # TODO: support audio - right now it seems that the markdown renderer in gradio doesn't support audio and shows nothing
                blob_urls = blobs_to_urls([item.get("bytes") if isinstance(item, dict) else None for item in df[column]], blob_type, prefix=f"{dataset}-{config}-{split}-{page}-{column}")
                df = df.drop([column], axis=1)
                df[column] = [f"![]({url})" for url in blob_urls]
            elif any(bad_type in str(feature) for bad_type in ["Image(", "Audio(", "'binary'"]):
                unsupported_columns.append(column)
                df = df.drop([column], axis=1)
            elif isinstance(feature, Sequence):
                if feature.feature.dtype == 'float32':
                    # Plot the sequence values directly and embed the figure as a base64-encoded image
                    base64_srcs = [ndarray_to_base64(vec) for vec in df[column]]
                    df = df.drop([column], axis=1)
                    df[column] = [f"![]({src})" for src in base64_srcs]
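        # Report the columns that had to be dropped because their feature type cannot
        # be rendered in the markdown table.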
        info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
        return df.reset_index().to_markdown(index=False), max_page, info
    else:
        # Special handling for the Salesforce/lotsa_data dataset: return the raw DataFrame
        info = "" if not unsupported_columns else f"Some columns are not supported yet: {unsupported_columns}"
        return df, max_page, info


#####################################################
# Gradio app
#####################################################

# App state
# Number of decimal places to keep

with gr.Blocks() as demo:
    # Initialize the common components
    gr.Markdown("A tool for interactive observation of the lotsa dataset, extended from lhoestq/datasets-explorer")
    cp_dataset = gr.Textbox("Salesforce/lotsa_data", label="Pick a dataset", interactive=False)
    cp_go = gr.Button("Explore")
    cp_config = gr.Dropdown(["plain_text"], value="plain_text", label="Config", visible=False)
    cp_split = gr.Dropdown(["train", "validation"], value="train", label="Split", visible=False)
    # cp_goto_next_page = gr.Button("Next page", visible=False)
    cp_error = gr.Markdown("", visible=False)
    cp_info = gr.Markdown("", visible=False)
    cp_result = gr.Markdown("", visible=False)
    tot_samples, tot_targets = 0, 0

    # Initialize the components used to display the Salesforce/lotsa_data dataset
    # components = []
    # for _ in range(TIME_PLOTS_NUM):
    with gr.Row():
        with gr.Column(scale=2):
            select_sample_box = gr.Dropdown(choices=["items"], label="Select some items", multiselect=True, interactive=True)
        with gr.Column(scale=2):
            select_subtarget_box = gr.Dropdown(choices=["subtargets"], label="Select some subtargets", multiselect=True, interactive=True)
        with gr.Column(scale=1):
            select_buttom = gr.Button("Show selected items")
    with gr.Row():
        with gr.Column(scale=2):
            statistics_textbox = gr.DataFrame()
        with gr.Column(scale=3):
            plot = gr.Plot()
    with gr.Row():
        user_input_box = gr.Textbox(placeholder="Type something", label="Input", lines=5, interactive=True)
        user_output_box = gr.Textbox(label="Answer", lines=5, interactive=False)
        user_io_buttom = gr.Button("Send", interactive=True)
    # components.append({"select_sample_box": select_sample_box,
    #                    "statistics_textbox": statistics_textbox,
    #                    "user_input_box": user_input_box,
    #                    "plot": plot})

    # with gr.Row():
    #     cp_page = gr.Textbox("1", label="Page", placeholder="1", visible=False)
    #     cp_goto_page = gr.Button("Go to page", visible=False)

    def show_error(message: str) -> dict:
        return {
            cp_error: gr.update(visible=True, value=f"## ❌ Error:\n\n{message}"),
            cp_info: gr.update(visible=False, value=""),
            cp_result: gr.update(visible=False, value=""),
        }

    def show_dataset_at_config_and_split_and_page(dataset: str, config: str, split: str, page: str | List[str], sub_targets: List[int | str]) -> dict:
        global tot_samples, tot_targets
        try:
            ret = {}
            if dataset != 'Salesforce/lotsa_data':
                markdown_result, max_page, info = get_page(dataset, config, split, page)
                ret[cp_result] = gr.update(visible=True, value=markdown_result)
            else:
                # Special handling for the Salesforce/lotsa_data dataset
                pages = [page] if isinstance(page, str) else page
                df_list, id_list = [], []
                # TODO: wrap the following into a helper function
                for page in pages:
                    df, max_page, info = get_page(dataset, config, split, page)
                    tot_samples = max_page
                    tot_targets = len(df['target'][0]) if isinstance(df['target'][0], np.ndarray) and df['target'][0].dtype == 'O' else 1
                    if 'all' in sub_targets:
                        sub_targets = list(range(tot_targets))
                    df = clean_up_df(df, sub_targets)
                    row = df.iloc[0]
                    id_list.append(row['item_id'])
                    # Expand the single-row DataFrame into a new DataFrame with one row per time step
                    df_without_index = row.drop('item_id').to_frame().T
                    df_expanded = df_without_index.apply(pd.Series.explode).reset_index(drop=True).fillna(0)
                    df_list.append(df_expanded)
                ret[statistics_textbox] = gr.update(value=create_statistic(df_list, id_list))
                ret[plot] = gr.update(value=create_plot(df_list, id_list))
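            # Merge the branch-specific updates with the shared page info and error reset.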
            return {
                **ret,
                cp_info: gr.update(visible=True, value=f"Page {page}/{max_page} {info}"),
                cp_error: gr.update(visible=False, value=""),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config_and_split_and_next_page(dataset: str, config: str, split: str, page: str) -> dict:
        try:
            next_page = str(int(page) + 1)
            return {
                **show_dataset_at_config_and_split_and_page(dataset, config, split, next_page, [0]),
                # cp_page: gr.update(value=next_page, visible=True),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config_and_split(dataset: str, config: str, split: str) -> dict:
        try:
            return {
                **show_dataset_at_config_and_split_and_page(dataset, config, split, "1", [0]),
                select_sample_box: gr.update(choices=[f"{i + 1}" for i in range(tot_samples)], value=["1"]),
                select_subtarget_box: gr.update(choices=list(range(tot_targets)) + ['all'], value=[0]),
                # cp_page: gr.update(value="1", visible=True),
                # cp_goto_page: gr.update(visible=True),
                # cp_goto_next_page: gr.update(visible=True),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset_at_config(dataset: str, config: str) -> dict:
        try:
            splits = get_parquet_splits(dataset, config)
            if not splits:
                raise AppError(f"Dataset {dataset} with config {config} has no splits.")
            else:
                split = splits[0]
            return {
                **show_dataset_at_config_and_split(dataset, config, split),
                cp_split: gr.update(value=split, choices=splits, visible=len(splits) > 1),
            }
        except AppError as err:
            return show_error(str(err))

    def show_dataset(dataset: str) -> dict:
        try:
            configs = get_parquet_configs(dataset)
            if not configs:
                raise AppError(f"Dataset {dataset} has no configs.")
            else:
                config = configs[0]
            return {
                **show_dataset_at_config(dataset, config),
                cp_config: gr.update(value=config, choices=configs, visible=len(configs) > 1),
            }
        except AppError as err:
            return show_error(str(err))

    all_outputs = [
        cp_config, cp_split,
        # cp_page, cp_goto_page, cp_goto_next_page,
        cp_result, cp_info, cp_error,
        select_sample_box, select_subtarget_box, select_buttom,
        statistics_textbox, plot,
    ]
    cp_go.click(show_dataset, inputs=[cp_dataset], outputs=all_outputs)
    cp_config.change(show_dataset_at_config, inputs=[cp_dataset, cp_config], outputs=all_outputs)
    cp_split.change(show_dataset_at_config_and_split, inputs=[cp_dataset, cp_config, cp_split], outputs=all_outputs)
    # cp_goto_page.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
    # cp_goto_next_page.click(show_dataset_at_config_and_split_and_next_page, inputs=[cp_dataset, cp_config, cp_split, cp_page], outputs=all_outputs)
    user_io_buttom.click(send_msg_to_server, inputs=[user_input_box], outputs=[user_output_box])
    select_buttom.click(show_dataset_at_config_and_split_and_page, inputs=[cp_dataset, cp_config, cp_split, select_sample_box, select_subtarget_box], outputs=all_outputs)


if __name__ == "__main__":
    app = gr.mount_gradio_app(app, demo, path="/")
    host = "127.0.0.1" if os.getenv("DEV") else "0.0.0.0"
    import subprocess
    subprocess.Popen(["python", "test_server.py"])
    uvicorn.run(app, host=host, port=7860)