import json
import os
import random
import hashlib
import warnings
import pandas as pd
from toolz import curried
from typing import Callable

from .core import sanitize_dataframe
from .core import sanitize_geo_interface
from .deprecation import AltairDeprecationWarning
from .plugin_registry import PluginRegistry


# ==============================================================================
# Data transformer registry
# ==============================================================================
DataTransformerType = Callable


class DataTransformerRegistry(PluginRegistry[DataTransformerType]):
    _global_settings = {"consolidate_datasets": True}

    @property
    def consolidate_datasets(self):
        return self._global_settings["consolidate_datasets"]

    @consolidate_datasets.setter
    def consolidate_datasets(self, value):
        self._global_settings["consolidate_datasets"] = value


# ==============================================================================
# Data model transformers
#
# A data model transformer is a pure function that takes a dict or DataFrame
# and returns a transformed version of a dict or DataFrame. The dict objects
# will be the Data portion of the VegaLite schema. The idea is that users can
# pipe a sequence of these data transformers together to prepare the data before
# it hits the renderer.
#
# In this version of Altair, renderers only deal with the dict form of a
# VegaLite spec, after the Data model has been put into a schema-compliant
# form.
#
# A data model transformer has the following type signature:
# DataModelType = Union[dict, pd.DataFrame]
# DataModelTransformerType = Callable[[DataModelType, KwArgs], DataModelType]
# ==============================================================================
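
# A minimal sketch of that piping idea (illustrative only; ``df`` is a hypothetical
# user-supplied DataFrame, and ``curry`` is applied explicitly here because the
# transformers below are plain functions):
#
#     from toolz.curried import curry, pipe
#
#     spec_data = pipe(df, curry(limit_rows, max_rows=10000), to_values)
#     # spec_data is now {"values": [...]} and ready to embed in a VegaLite spec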


class MaxRowsError(Exception):
    """Raised when a data model has too many rows."""

    pass


def limit_rows(data, max_rows=5000):
    """Raise MaxRowsError if the data model has more than max_rows.

    If max_rows is None, then do not perform any check.
    """
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if data.__geo_interface__["type"] == "FeatureCollection":
            values = data.__geo_interface__["features"]
        else:
            values = data.__geo_interface__
    elif isinstance(data, pd.DataFrame):
        values = data
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
        else:
            return data
    elif hasattr(data, "__dataframe__"):
        values = data
    if max_rows is not None and len(values) > max_rows:
        raise MaxRowsError(
            "The number of rows in your dataset is greater "
            f"than the maximum allowed ({max_rows}).\n\n"
            "See https://altair-viz.github.io/user_guide/large_datasets.html "
            "for information on how to plot large datasets, "
            "including how to install third-party data management tools and, "
            "in the right circumstances, how to disable this restriction."
        )
    return data
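

# Illustrative behaviour of the check above (``df`` is a hypothetical DataFrame,
# not defined in this module):
#
#     limit_rows(df, max_rows=None)   # no check is performed; df is returned unchanged
#     limit_rows(df, max_rows=5000)   # raises MaxRowsError if df has more than 5000 rows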


def sample(data, n=None, frac=None):
    """Reduce the size of the data model by sampling without replacement."""
    check_data_type(data)
    if isinstance(data, pd.DataFrame):
        return data.sample(n=n, frac=frac)
    elif isinstance(data, dict):
        if "values" in data:
            values = data["values"]
            n = n if n else int(frac * len(values))
            values = random.sample(values, n)
            return {"values": values}
    elif hasattr(data, "__dataframe__"):
        # experimental interchange dataframe support
        pi = import_pyarrow_interchange()
        pa_table = pi.from_dataframe(data)
        n = n if n else int(frac * len(pa_table))
        indices = random.sample(range(len(pa_table)), n)
        return pa_table.take(indices)


def to_json(
    data,
    prefix="altair-data",
    extension="json",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """Write the data model to a .json file and return a url-based data model."""
    data_json = _data_to_json_string(data)
    data_hash = _compute_data_hash(data_json)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    with open(filename, "w") as f:
        f.write(data_json)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "json"}}


def to_csv(
    data,
    prefix="altair-data",
    extension="csv",
    filename="{prefix}-{hash}.{extension}",
    urlpath="",
):
    """Write the data model to a .csv file and return a url-based data model."""
    data_csv = _data_to_csv_string(data)
    data_hash = _compute_data_hash(data_csv)
    filename = filename.format(prefix=prefix, hash=data_hash, extension=extension)
    with open(filename, "w") as f:
        f.write(data_csv)
    return {"url": os.path.join(urlpath, filename), "format": {"type": "csv"}}


def to_values(data):
    """Replace a DataFrame with a data model that contains the values inline."""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        data = sanitize_geo_interface(data.__geo_interface__)
        return {"values": data}
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return {"values": data.to_dict(orient="records")}
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return data
    elif hasattr(data, "__dataframe__"):
        # experimental interchange dataframe support
        pi = import_pyarrow_interchange()
        pa_table = pi.from_dataframe(data)
        return {"values": pa_table.to_pylist()}


def check_data_type(data):
    """Raise if the data is not a dict, DataFrame, or an object with __geo_interface__ or __dataframe__."""
    if not isinstance(data, (dict, pd.DataFrame)) and not any(
        hasattr(data, attr) for attr in ["__geo_interface__", "__dataframe__"]
    ):
        raise TypeError(
            "Expected dict, DataFrame, or an object with a __geo_interface__ "
            "or __dataframe__ attribute, got: {}".format(type(data))
        )


# ==============================================================================
# Private utilities
# ==============================================================================
def _compute_data_hash(data_str):
    return hashlib.md5(data_str.encode()).hexdigest()


def _data_to_json_string(data):
    """Return a JSON string representation of the input data"""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        if isinstance(data, pd.DataFrame):
            data = sanitize_dataframe(data)
        data = sanitize_geo_interface(data.__geo_interface__)
        return json.dumps(data)
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return data.to_json(orient="records", double_precision=15)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present.")
        return json.dumps(data["values"], sort_keys=True)
    elif hasattr(data, "__dataframe__"):
        # experimental interchange dataframe support
        pi = import_pyarrow_interchange()
        pa_table = pi.from_dataframe(data)
        return json.dumps(pa_table.to_pylist())
    else:
        raise NotImplementedError(
            "to_json only works with data expressed as a DataFrame or as a dict"
        )


def _data_to_csv_string(data):
    """Return a CSV string representation of the input data"""
    check_data_type(data)
    if hasattr(data, "__geo_interface__"):
        raise NotImplementedError(
            "to_csv does not work with data that "
            "contains the __geo_interface__ attribute"
        )
    elif isinstance(data, pd.DataFrame):
        data = sanitize_dataframe(data)
        return data.to_csv(index=False)
    elif isinstance(data, dict):
        if "values" not in data:
            raise KeyError("values expected in data dict, but not present")
        return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
    elif hasattr(data, "__dataframe__"):
        # experimental interchange dataframe support
        pi = import_pyarrow_interchange()
        import pyarrow as pa
        import pyarrow.csv as pa_csv

        pa_table = pi.from_dataframe(data)
        csv_buffer = pa.BufferOutputStream()
        pa_csv.write_csv(pa_table, csv_buffer)
        return csv_buffer.getvalue().to_pybytes().decode()
    else:
        raise NotImplementedError(
            "to_csv only works with data expressed as a DataFrame or as a dict"
        )


def pipe(data, *funcs):
    """Pipe a value through a sequence of functions.

    Deprecated: use toolz.curried.pipe() instead.
    """
    warnings.warn(
        "alt.pipe() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.pipe() instead.",
        AltairDeprecationWarning,
        stacklevel=1,
    )
    return curried.pipe(data, *funcs)


def curry(*args, **kwargs):
    """Curry a callable function.

    Deprecated: use toolz.curried.curry() instead.
    """
    warnings.warn(
        "alt.curry() is deprecated, and will be removed in a future release. "
        "Use toolz.curried.curry() instead.",
        AltairDeprecationWarning,
        stacklevel=1,
    )
    return curried.curry(*args, **kwargs)


def import_pyarrow_interchange():
    import pkg_resources

    try:
        pkg_resources.require("pyarrow>=11.0.0")
        # The package is installed and meets the minimum version requirement
        import pyarrow.interchange as pi

        return pi
    except pkg_resources.DistributionNotFound as err:
        # The package is not installed
        raise ImportError(
            "Usage of the DataFrame Interchange Protocol requires the package "
            "'pyarrow', but it is not installed."
        ) from err
    except pkg_resources.VersionConflict as err:
        # The package is installed but does not meet the minimum version requirement
        raise ImportError(
            "The installed version of 'pyarrow' does not meet the minimum requirement "
            "of version 11.0.0. Please update 'pyarrow' to use the DataFrame "
            "Interchange Protocol."
        ) from err