""" | |
Utility routines | |
""" | |
from collections.abc import Mapping | |
from copy import deepcopy | |
import json | |
import itertools | |
import re | |
import sys | |
import traceback | |
import warnings | |
from typing import Callable, TypeVar, Any | |
import jsonschema | |
import pandas as pd | |
import numpy as np | |
from altair.utils.schemapi import SchemaBase | |
if sys.version_info >= (3, 10): | |
from typing import ParamSpec | |
else: | |
from typing_extensions import ParamSpec | |
try: | |
from pandas.api.types import infer_dtype as _infer_dtype | |
except ImportError: | |
# Import for pandas < 0.20.0 | |
from pandas.lib import infer_dtype as _infer_dtype # type: ignore[no-redef] | |
_V = TypeVar("_V") | |
_P = ParamSpec("_P") | |

def infer_dtype(value):
    """Infer the dtype of the value.

    This is a compatibility function for pandas infer_dtype,
    with skipna=False regardless of the pandas version.
    """
    if not hasattr(infer_dtype, "_supports_skipna"):
        try:
            _infer_dtype([1], skipna=False)
        except TypeError:
            # pandas < 0.21.0 doesn't support the skipna keyword
            infer_dtype._supports_skipna = False
        else:
            infer_dtype._supports_skipna = True
    if infer_dtype._supports_skipna:
        return _infer_dtype(value, skipna=False)
    else:
        return _infer_dtype(value)
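
# Illustrative sketch (not part of the library): because skipna is always
# False here, missing values influence the inferred dtype, e.g.:
#
#     infer_dtype(pd.Series([1, 2, 3]))    # -> "integer"
#     infer_dtype(pd.Series([1.5, None]))  # -> "floating" (NaN is a float)
#     infer_dtype(pd.Series(["a", None]))  # -> "mixed" (None is not skipped)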

TYPECODE_MAP = {
    "ordinal": "O",
    "nominal": "N",
    "quantitative": "Q",
    "temporal": "T",
    "geojson": "G",
}

INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()}
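
# Illustrative sketch (not part of the library): these maps translate between
# long-form encoding type names and the one-letter codes used in shorthand
# strings such as "column:Q":
#
#     TYPECODE_MAP["quantitative"]  # -> "Q"
#     INV_TYPECODE_MAP["Q"]         # -> "quantitative"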

# aggregates from vega-lite version 4.6.0
AGGREGATES = [
    "argmax",
    "argmin",
    "average",
    "count",
    "distinct",
    "max",
    "mean",
    "median",
    "min",
    "missing",
    "product",
    "q1",
    "q3",
    "ci0",
    "ci1",
    "stderr",
    "stdev",
    "stdevp",
    "sum",
    "valid",
    "values",
    "variance",
    "variancep",
]

# window aggregates from vega-lite version 4.6.0
WINDOW_AGGREGATES = [
    "row_number",
    "rank",
    "dense_rank",
    "percent_rank",
    "cume_dist",
    "ntile",
    "lag",
    "lead",
    "first_value",
    "last_value",
    "nth_value",
]

# timeUnits from vega-lite version 4.17.0
TIMEUNITS = [
    "year",
    "quarter",
    "month",
    "week",
    "day",
    "dayofyear",
    "date",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "yearquarter",
    "yearquartermonth",
    "yearmonth",
    "yearmonthdate",
    "yearmonthdatehours",
    "yearmonthdatehoursminutes",
    "yearmonthdatehoursminutesseconds",
    "yearweek",
    "yearweekday",
    "yearweekdayhours",
    "yearweekdayhoursminutes",
    "yearweekdayhoursminutesseconds",
    "yeardayofyear",
    "quartermonth",
    "monthdate",
    "monthdatehours",
    "monthdatehoursminutes",
    "monthdatehoursminutesseconds",
    "weekday",
"weeksdayhours", | |
"weekdayhoursminutes", | |
"weekdayhoursminutesseconds", | |
"dayhours", | |
"dayhoursminutes", | |
"dayhoursminutesseconds", | |
"hoursminutes", | |
"hoursminutesseconds", | |
"minutesseconds", | |
"secondsmilliseconds", | |
"utcyear", | |
"utcquarter", | |
"utcmonth", | |
"utcweek", | |
"utcday", | |
"utcdayofyear", | |
"utcdate", | |
"utchours", | |
"utcminutes", | |
"utcseconds", | |
"utcmilliseconds", | |
"utcyearquarter", | |
"utcyearquartermonth", | |
"utcyearmonth", | |
"utcyearmonthdate", | |
"utcyearmonthdatehours", | |
"utcyearmonthdatehoursminutes", | |
"utcyearmonthdatehoursminutesseconds", | |
"utcyearweek", | |
"utcyearweekday", | |
"utcyearweekdayhours", | |
"utcyearweekdayhoursminutes", | |
"utcyearweekdayhoursminutesseconds", | |
"utcyeardayofyear", | |
"utcquartermonth", | |
"utcmonthdate", | |
"utcmonthdatehours", | |
"utcmonthdatehoursminutes", | |
"utcmonthdatehoursminutesseconds", | |
"utcweekday", | |
"utcweeksdayhours", | |
"utcweekdayhoursminutes", | |
"utcweekdayhoursminutesseconds", | |
"utcdayhours", | |
"utcdayhoursminutes", | |
"utcdayhoursminutesseconds", | |
"utchoursminutes", | |
"utchoursminutesseconds", | |
"utcminutesseconds", | |
"utcsecondsmilliseconds", | |
] | |

def infer_vegalite_type(data):
    """
    From an array-like input, infer the correct vega typecode
    ('ordinal', 'nominal', 'quantitative', or 'temporal')

    Parameters
    ----------
    data: Numpy array or Pandas Series
    """
    # Infer based on the dtype of the input
    typ = infer_dtype(data)
    if typ in [
        "floating",
        "mixed-integer-float",
        "integer",
        "mixed-integer",
        "complex",
    ]:
        return "quantitative"
    elif typ == "categorical" and data.cat.ordered:
        return ("ordinal", data.cat.categories.tolist())
    elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]:
        return "nominal"
    elif typ in [
        "datetime",
        "datetime64",
        "timedelta",
        "timedelta64",
        "date",
        "time",
        "period",
    ]:
        return "temporal"
    else:
        warnings.warn(
            "I don't know how to infer vegalite type from '{}'. "
            "Defaulting to nominal.".format(typ),
            stacklevel=1,
        )
        return "nominal"

def merge_props_geom(feat):
    """
    Merge properties with geometry
    * Overwrites 'type' and 'geometry' entries if existing
    """
    geom = {k: feat[k] for k in ("type", "geometry")}
    try:
        feat["properties"].update(geom)
        props_geom = feat["properties"]
    except (AttributeError, KeyError):
        # AttributeError when 'properties' is None
        # KeyError when 'properties' is missing
        props_geom = geom
    return props_geom
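
# Illustrative sketch (not part of the library): a feature's properties are
# hoisted to the top level, alongside its type and geometry:
#
#     feat = {
#         "type": "Feature",
#         "geometry": {"type": "Point", "coordinates": [0, 0]},
#         "properties": {"name": "origin"},
#     }
#     merge_props_geom(feat)
#     # -> {"name": "origin", "type": "Feature",
#     #     "geometry": {"type": "Point", "coordinates": [0, 0]}}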

def sanitize_geo_interface(geo):
    """Sanitize a geo_interface to prepare it for serialization.

    * Make a copy
    * Convert type array or _Array to list
    * Convert tuples to lists (using json.loads/dumps)
    * Merge properties with geometry
    """
    geo = deepcopy(geo)
    # convert type _Array or array to list
    for key in geo.keys():
        if str(type(geo[key]).__name__).startswith(("_Array", "array")):
            geo[key] = geo[key].tolist()
    # convert (nested) tuples to lists
    geo = json.loads(json.dumps(geo))
    # sanitize features
    if geo["type"] == "FeatureCollection":
        geo = geo["features"]
        if len(geo) > 0:
            for idx, feat in enumerate(geo):
                geo[idx] = merge_props_geom(feat)
    elif geo["type"] == "Feature":
        geo = merge_props_geom(geo)
    else:
        geo = {"type": "Feature", "geometry": geo}
    return geo
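
# Illustrative sketch (not part of the library): a bare geometry is wrapped
# into a Feature, and the JSON round-trip turns tuples into lists:
#
#     geo = {"type": "Point", "coordinates": (10.0, 20.0)}
#     sanitize_geo_interface(geo)
#     # -> {"type": "Feature",
#     #     "geometry": {"type": "Point", "coordinates": [10.0, 20.0]}}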

def sanitize_dataframe(df):  # noqa: C901
    """Sanitize a DataFrame to prepare it for serialization.

    * Make a copy
    * Convert RangeIndex columns to strings
    * Raise ValueError if column names are not strings
    * Raise ValueError if it has a hierarchical index.
    * Convert categoricals to strings.
    * Convert np.bool_ dtypes to Python bool objects
    * Convert np.int dtypes to Python int objects
    * Convert floats to objects and replace NaNs/infs with None.
    * Convert DateTime dtypes into appropriate string representations
    * Convert Nullable integers to objects and replace NaN with None
    * Convert Nullable booleans to objects and replace NaN with None
    * Convert dedicated string columns to objects and replace NaN with None
    * Raise a ValueError for TimeDelta dtypes
    """
    df = df.copy()

    if isinstance(df.columns, pd.RangeIndex):
        df.columns = df.columns.astype(str)

    for col in df.columns:
        if not isinstance(col, str):
            raise ValueError(
                "Dataframe contains invalid column name: {0!r}. "
                "Column names must be strings".format(col)
            )

    if isinstance(df.index, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")
    if isinstance(df.columns, pd.MultiIndex):
        raise ValueError("Hierarchical indices not supported")

    def to_list_if_array(val):
        if isinstance(val, np.ndarray):
            return val.tolist()
        else:
            return val

    for col_name, dtype in df.dtypes.items():
        if str(dtype) == "category":
            # Work around bug in to_json for categorical types in older versions of pandas
            # https://github.com/pydata/pandas/issues/10778
            # https://github.com/altair-viz/altair/pull/2170
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif str(dtype) == "string":
            # dedicated string datatype (since 1.0)
            # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif str(dtype) == "bool":
            # convert numpy bools to objects; np.bool is not JSON serializable
            df[col_name] = df[col_name].astype(object)
        elif str(dtype) == "boolean":
            # dedicated boolean datatype (since 1.0)
            # https://pandas.io/docs/user_guide/boolean.html
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif str(dtype).startswith("datetime"):
            # Convert datetimes to strings. This needs to be a full ISO string
            # with time, which is why we cannot use ``col.astype(str)``.
            # This is because Javascript parses date-only times in UTC, but
            # parses full ISO-8601 dates as local time, and dates in Vega and
            # Vega-Lite are displayed in local time by default.
            # (see https://github.com/altair-viz/altair/issues/1027)
            df[col_name] = (
                df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
            )
        elif str(dtype).startswith("timedelta"):
            raise ValueError(
                'Field "{col_name}" has type "{dtype}" which is '
                "not supported by Altair. Please convert to "
                "either a timestamp or a numerical value."
                "".format(col_name=col_name, dtype=dtype)
            )
        elif str(dtype).startswith("geometry"):
            # geopandas >=0.6.1 uses the dtype geometry. Continue here
            # otherwise it will give an error on np.issubdtype(dtype, np.integer)
            continue
        elif str(dtype) in {
            "Int8",
            "Int16",
            "Int32",
            "Int64",
            "UInt8",
            "UInt16",
            "UInt32",
            "UInt64",
            "Float32",
            "Float64",
        }:  # nullable integer datatypes (since 0.24) and nullable float datatypes (since 1.2.0)
            # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
            col = df[col_name].astype(object)
            df[col_name] = col.where(col.notnull(), None)
        elif np.issubdtype(dtype, np.integer):
            # convert integers to objects; np.int is not JSON serializable
            df[col_name] = df[col_name].astype(object)
        elif np.issubdtype(dtype, np.floating):
            # For floats, convert to Python float: np.float is not JSON serializable
            # Also convert NaN/inf values to null, as they are not JSON serializable
            col = df[col_name]
            bad_values = col.isnull() | np.isinf(col)
            df[col_name] = col.astype(object).where(~bad_values, None)
        elif dtype == object:
            # Convert numpy arrays saved as objects to lists
            # Arrays are not JSON serializable
            col = df[col_name].apply(to_list_if_array, convert_dtype=False)
            df[col_name] = col.where(col.notnull(), None)
    return df
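
# Illustrative sketch (not part of the library): values that JSON cannot
# represent are replaced before serialization:
#
#     df = pd.DataFrame({
#         "x": [1.0, np.nan, np.inf],
#         "t": pd.to_datetime(["2012-01-01", "2012-01-02", "2012-01-03"]),
#     })
#     clean = sanitize_dataframe(df)
#     clean["x"].tolist()  # -> [1.0, None, None]
#     clean["t"][0]        # -> "2012-01-01T00:00:00"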

def parse_shorthand(
    shorthand,
    data=None,
    parse_aggregates=True,
    parse_window_ops=False,
    parse_timeunits=True,
    parse_types=True,
):
    """General tool to parse shorthand values

    These are of the form:

    - "col_name"
    - "col_name:O"
    - "average(col_name)"
    - "average(col_name):O"

    Optionally, a dataframe may be supplied, from which the type
    will be inferred if not specified in the shorthand.

    Parameters
    ----------
    shorthand : dict or string
        The shorthand representation to be parsed
    data : DataFrame, optional
        If specified and of type DataFrame, then use these values to infer the
        column type if not provided by the shorthand.
    parse_aggregates : boolean
        If True (default), then parse aggregate functions within the shorthand.
    parse_window_ops : boolean
        If True, then parse window operations within the shorthand (default: False).
    parse_timeunits : boolean
        If True (default), then parse timeUnits from within the shorthand.
    parse_types : boolean
        If True (default), then parse typecodes within the shorthand.

    Returns
    -------
    attrs : dict
        a dictionary of attributes extracted from the shorthand

    Examples
    --------
    >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
    ...                      'bar': [1, 2, 3, 4]})

    >>> parse_shorthand('name') == {'field': 'name'}
    True

    >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'}
    True

    >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'}
    True

    >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'}
    True

    >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'}
    True

    >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'}
    True

    >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'}
    True

    >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'}
    True

    >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'}
    True

    >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'}
    True

    >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
    True
    """
    if not shorthand:
        return {}

    valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

    units = {
        "field": "(?P<field>.*)",
        "type": "(?P<type>{})".format("|".join(valid_typecodes)),
        "agg_count": "(?P<aggregate>count)",
        "op_count": "(?P<op>count)",
        "aggregate": "(?P<aggregate>{})".format("|".join(AGGREGATES)),
        "window_op": "(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
        "timeUnit": "(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
    }

    patterns = []

    if parse_aggregates:
        patterns.extend([r"{agg_count}\(\)"])
        patterns.extend([r"{aggregate}\({field}\)"])
    if parse_window_ops:
        patterns.extend([r"{op_count}\(\)"])
        patterns.extend([r"{window_op}\({field}\)"])
    if parse_timeunits:
        patterns.extend([r"{timeUnit}\({field}\)"])

    patterns.extend([r"{field}"])

    if parse_types:
        patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

    regexps = (
        re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
    )

    # find matches depending on valid fields passed
    if isinstance(shorthand, dict):
        attrs = shorthand
    else:
        attrs = next(
            exp.match(shorthand).groupdict() for exp in regexps if exp.match(shorthand)
        )

    # Handle short form of the type expression
    if "type" in attrs:
        attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"])

    # counts are quantitative by default
    if attrs == {"aggregate": "count"}:
        attrs["type"] = "quantitative"

    # times are temporal by default
    if "timeUnit" in attrs and "type" not in attrs:
        attrs["type"] = "temporal"

    # if data is specified and type is not, infer type from data
    if isinstance(data, pd.DataFrame) and "type" not in attrs:
        # Remove escape sequences so that types can be inferred for columns with special characters
        if "field" in attrs and attrs["field"].replace("\\", "") in data.columns:
            attrs["type"] = infer_vegalite_type(data[attrs["field"].replace("\\", "")])
            # ordered categorical dataframe columns return the type and sort order as a tuple
            if isinstance(attrs["type"], tuple):
                attrs["sort"] = attrs["type"][1]
                attrs["type"] = attrs["type"][0]

    # If an unescaped colon is still present, it's often due to an incorrect data type specification
    # but could also be due to using a column name with ":" in it.
    if (
        "field" in attrs
        and ":" in attrs["field"]
        and attrs["field"][attrs["field"].rfind(":") - 1] != "\\"
    ):
        raise ValueError(
            '"{}" '.format(attrs["field"].split(":")[-1])
            + "is not one of the valid encoding data types: {}.".format(
                ", ".join(TYPECODE_MAP.values())
            )
            + "\nFor more details, see https://altair-viz.github.io/user_guide/encodings/index.html#encoding-data-types. "
            + "If you are trying to use a column name that contains a colon, "
            + 'prefix it with a backslash; for example "column\\:name" instead of "column:name".'
        )
    return attrs
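
# Illustrative sketch (not part of the library): with parse_window_ops
# enabled, window operations are extracted under the "op" key, and a literal
# colon in a column name can be escaped with a backslash:
#
#     parse_shorthand("rank(score)", parse_window_ops=True)
#     # -> {"op": "rank", "field": "score"}
#
#     parse_shorthand("column\\:name")
#     # -> {"field": "column\\:name"}  (no ValueError: the colon is escaped)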

def use_signature(Obj: Callable[_P, Any]):
    """Apply call signature and documentation of Obj to the decorated method"""

    def decorate(f: Callable[..., _V]) -> Callable[_P, _V]:
        # call-signature of f is exposed via __wrapped__.
        # we want it to mimic Obj.__init__
        f.__wrapped__ = Obj.__init__  # type: ignore
        f._uses_signature = Obj  # type: ignore

        # Supplement the docstring of f with information from Obj
        if Obj.__doc__:
            # Patch in a reference to the class this docstring is copied from,
            # to generate a hyperlink.
            doclines = Obj.__doc__.splitlines()
            doclines[0] = f"Refer to :class:`{Obj.__name__}`"
            if f.__doc__:
                doc = f.__doc__ + "\n".join(doclines[1:])
            else:
                doc = "\n".join(doclines)
            try:
                f.__doc__ = doc
            except AttributeError:
                # __doc__ is not modifiable for classes in Python < 3.3
                pass
        return f

    return decorate
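
# Illustrative sketch (not part of the library; Config and configure are
# hypothetical names): the decorator makes a wrapper function advertise the
# signature and docstring of another class:
#
#     class Config:
#         """Config options.
#
#         More details here.
#         """
#         def __init__(self, width=400, height=300): ...
#
#     @use_signature(Config)
#     def configure(*args, **kwargs): ...
#
#     # help(configure) now shows Config.__init__'s signature (via
#     # __wrapped__), and its docstring begins with "Refer to :class:`Config`".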

def update_nested(original, update, copy=False):
    """Update nested dictionaries

    Parameters
    ----------
    original : dict
        the original (nested) dictionary, which will be updated in-place
    update : dict
        the nested dictionary of updates
    copy : bool, default False
        if True, then copy the original dictionary rather than modifying it

    Returns
    -------
    original : dict
        a reference to the (modified) original dict

    Examples
    --------
    >>> original = {'x': {'b': 2, 'c': 4}}
    >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40}
    >>> update_nested(original, update)  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    >>> original  # doctest: +SKIP
    {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40}
    """
    if copy:
        original = deepcopy(original)
    for key, val in update.items():
        if isinstance(val, Mapping):
            orig_val = original.get(key, {})
            if isinstance(orig_val, Mapping):
                original[key] = update_nested(orig_val, val)
            else:
                original[key] = val
        else:
            original[key] = val
    return original
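
# Illustrative sketch (not part of the library): with copy=True the original
# mapping is deep-copied first and therefore left untouched:
#
#     original = {"x": {"b": 2}}
#     merged = update_nested(original, {"x": {"c": 3}}, copy=True)
#     merged    # -> {"x": {"b": 2, "c": 3}}
#     original  # -> {"x": {"b": 2}} (unchanged)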

def display_traceback(in_ipython=True):
    exc_info = sys.exc_info()

    if in_ipython:
        from IPython.core.getipython import get_ipython

        ip = get_ipython()
    else:
        ip = None

    if ip is not None:
        ip.showtraceback(exc_info)
    else:
        traceback.print_exception(*exc_info)

def infer_encoding_types(args, kwargs, channels):
    """Infer typed keyword arguments for args and kwargs

    Parameters
    ----------
    args : tuple
        List of function args
    kwargs : dict
        Dict of function kwargs
    channels : module
        The module containing all altair encoding channel classes.

    Returns
    -------
    kwargs : dict
        All args and kwargs in a single dict, with keys and types
        based on the channels mapping.
    """
    # Construct a dictionary of channel type to encoding name
    # TODO: cache this somehow?
    channel_objs = (getattr(channels, name) for name in dir(channels))
    channel_objs = (
        c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase)
    )
    channel_to_name = {c: c._encoding_name for c in channel_objs}
    name_to_channel = {}
    for chan, name in channel_to_name.items():
        chans = name_to_channel.setdefault(name, {})
        if chan.__name__.endswith("Datum"):
            key = "datum"
        elif chan.__name__.endswith("Value"):
            key = "value"
        else:
            key = "field"
        chans[key] = chan

    # First use the mapping to convert args to kwargs based on their types.
    for arg in args:
        if isinstance(arg, (list, tuple)) and len(arg) > 0:
            type_ = type(arg[0])
        else:
            type_ = type(arg)

        encoding = channel_to_name.get(type_, None)
        if encoding is None:
            raise NotImplementedError("positional of type {}".format(type_))
        if encoding in kwargs:
            raise ValueError("encoding {} specified twice.".format(encoding))
        kwargs[encoding] = arg

    def _wrap_in_channel_class(obj, encoding):
        if isinstance(obj, SchemaBase):
            return obj

        if isinstance(obj, str):
            obj = {"shorthand": obj}

        if isinstance(obj, (list, tuple)):
            return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

        if encoding not in name_to_channel:
            warnings.warn(
                "Unrecognized encoding channel '{}'".format(encoding), stacklevel=1
            )
            return obj

        classes = name_to_channel[encoding]
        cls = classes["value"] if "value" in obj else classes["field"]

        try:
            # Don't force validation here; some objects won't be valid until
            # they're created in the context of a chart.
            return cls.from_dict(obj, validate=False)
        except jsonschema.ValidationError:
            # our attempts at finding the correct class have failed
            return obj

    return {
        encoding: _wrap_in_channel_class(obj, encoding)
        for encoding, obj in kwargs.items()
    }
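
# Illustrative sketch (not part of the library; the channels import path
# below is hypothetical): positional channel objects and keyword shorthands
# are normalized into a single dict of channel instances, keyed by encoding
# name. Channel instances pass through unchanged; strings are wrapped as
# {"shorthand": ...} and instantiated as field channels:
#
#     import altair as alt
#     from altair import channels  # hypothetical import path
#
#     infer_encoding_types(args=(alt.X("bar:Q"),), kwargs={"y": "foo:N"},
#                          channels=channels)
#     # -> {"x": X("bar:Q"), "y": Y(shorthand="foo:N")}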