# data/hf_spark_utils.py
import math
import pickle
import tempfile
from functools import partial
from typing import Iterator, Optional, Union
import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import CommitOperationAdd, HfFileSystem
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema

spark = None


def set_session(session):
    """Set the Spark session used by `read_parquet` and `write_parquet`."""
    global spark
    spark = session
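
# Minimal usage sketch (assumption: the caller creates the SparkSession;
# the "hf-parquet" app name below is purely illustrative):
#
#   from pyspark.sql import SparkSession
#   spark = SparkSession.builder.appName("hf-parquet").getOrCreate()
#   set_session(spark)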


def _read(iterator: Iterator[pa.RecordBatch], columns: Optional[list[str]], filters: Optional[Union[list[tuple], list[list[tuple]]]], **kwargs) -> Iterator[pa.RecordBatch]:
    # Each incoming batch holds a single column of Parquet file paths;
    # open them as a PyArrow dataset and stream record batches back to Spark.
    for batch in iterator:
        paths = batch[0].to_pylist()
        ds = pq.ParquetDataset(paths, **kwargs)
        yield from ds._dataset.to_batches(columns=columns, filter=pq.filters_to_expression(filters) if filters else None)
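
# For reference, `filters` uses PyArrow's DNF filter syntax and is converted
# into a compute expression by `pq.filters_to_expression`. A sketch with
# illustrative values only:
#
#   pq.filters_to_expression([("foo", ">", 2)])
#   # roughly equivalent to: pyarrow.compute.field("foo") > 2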


def read_parquet(
    path: str,
    columns: Optional[list[str]] = None,
    filters: Optional[Union[list[tuple], list[list[tuple]]]] = None,
    **kwargs,
) -> DataFrame:
"""
Loads Parquet files from Hugging Face using PyArrow, returning a PySPark `DataFrame`.
It reads Parquet files in a distributed manner.
Access private or gated repositories using `huggingface-cli login` or passing a token
using the `storage_options` argument: `storage_options={"token": "hf_xxx"}`
Parameters
----------
path : str
Path to the file. Prefix with a protocol like `hf://` to read from Hugging Face.
You can read from multiple files if you pass a globstring.
columns : list, default None
If not None, only these columns will be read from the file.
filters : List[Tuple] or List[List[Tuple]], default None
To filter out data.
Filter syntax: [[(column, op, val), ...],...]
where op is [==, =, >, >=, <, <=, !=, in, not in]
The innermost tuples are transposed into a set of filters applied
through an `AND` operation.
The outer list combines these sets of filters through an `OR`
operation.
A single list of tuples can also be used, meaning that no `OR`
operation between set of filters is to be conducted.
**kwargs
Any additional kwargs are passed to pyarrow.parquet.ParquetDataset.
Returns
-------
DataFrame
DataFrame based on parquet file.
Examples
--------
>>> path = "hf://datasets/username/dataset/data.parquet"
>>> pd.DataFrame({"foo": range(5), "bar": range(5, 10)}).to_parquet(path)
>>> read_parquet(path).show()
+---+---+
|foo|bar|
+---+---+
| 0| 5|
| 1| 6|
| 2| 7|
| 3| 8|
| 4| 9|
+---+---+
>>> read_parquet(path, columns=["bar"]).show()
+---+
|bar|
+---+
| 5|
| 6|
| 7|
| 8|
| 9|
+---+
>>> sel = [("foo", ">", 2)]
>>> read_parquet(path, filters=sel).show()
+---+---+
|foo|bar|
+---+---+
| 3| 8|
| 4| 9|
+---+---+
"""
    filesystem: HfFileSystem = kwargs.pop("filesystem") if "filesystem" in kwargs else HfFileSystem(**kwargs.pop("storage_options", {}))
    paths = filesystem.glob(path)
    if not paths:
        raise FileNotFoundError(f"Couldn't find any file at {path}")
    # One Spark task per file: distribute the file paths, then read them with PyArrow in _read.
    rdd = spark.sparkContext.parallelize([{"path": path} for path in paths], len(paths))
    df = spark.createDataFrame(rdd)
    arrow_schema = pq.read_schema(filesystem.open(paths[0]))
    # The output Spark schema keeps only the requested columns; the full Arrow schema is passed to ParquetDataset.
    schema = pa.schema([field for field in arrow_schema if (columns is None or field.name in columns)], metadata=arrow_schema.metadata)
    return df.mapInArrow(
        partial(_read, columns=columns, filters=filters, filesystem=filesystem, schema=arrow_schema, **kwargs),
        from_arrow_schema(schema),
    )
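
# Usage sketch for a private repo (hypothetical repo path and token shown
# purely for illustration):
#
#   set_session(spark)
#   df = read_parquet(
#       "hf://datasets/username/dataset/*.parquet",
#       columns=["foo"],
#       storage_options={"token": "hf_xxx"},
#   )
#   df.show()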


def _preupload(iterator: Iterator[pa.RecordBatch], path: str, schema: pa.Schema, filesystem: HfFileSystem, row_group_size: Optional[int] = None, **kwargs) -> Iterator[pa.RecordBatch]:
    resolved_path = filesystem.resolve_path(path)
    with tempfile.NamedTemporaryFile(suffix=".parquet") as temp_file:
        # Write this partition to a local temporary Parquet file.
        with pq.ParquetWriter(temp_file.name, schema=schema, **kwargs) as writer:
            for batch in iterator:
                writer.write_batch(batch, row_group_size=row_group_size)
        # Pre-upload the file to the repo's LFS storage; its final path_in_repo is set later in _commit.
        addition = CommitOperationAdd(path_in_repo=temp_file.name, path_or_fileobj=temp_file.name)
        filesystem._api.preupload_lfs_files(repo_id=resolved_path.repo_id, additions=[addition], repo_type=resolved_path.repo_type, revision=resolved_path.revision)
        yield pa.record_batch({"addition": [pickle.dumps(addition)]}, schema=pa.schema({"addition": pa.binary()}))


def _commit(iterator: Iterator[pa.RecordBatch], path: str, filesystem: HfFileSystem, max_operations_per_commit=50) -> Iterator[pa.RecordBatch]:
    resolved_path = filesystem.resolve_path(path)
    # Collect the pickled CommitOperationAdd objects produced by _preupload.
    additions: list[CommitOperationAdd] = [pickle.loads(addition) for addition in pa.Table.from_batches(iterator, schema=pa.schema({"addition": pa.binary()}))[0].to_pylist()]
    num_commits = math.ceil(len(additions) / max_operations_per_commit)
    # Give each shard its final path in the repo (a no-op when writing to a single file).
    for shard_idx, addition in enumerate(additions):
        addition.path_in_repo = resolved_path.path_in_repo.replace("{shard_idx:05d}", f"{shard_idx:05d}")
    # Commit the pre-uploaded files in batches of at most max_operations_per_commit.
    for i in range(0, num_commits):
        operations = additions[i * max_operations_per_commit : (i + 1) * max_operations_per_commit]
        commit_message = "Upload using PySpark" + (f" (part {i:05d}-of-{num_commits:05d})" if num_commits > 1 else "")
        filesystem._api.create_commit(repo_id=resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision, operations=operations, commit_message=commit_message)
        yield pa.record_batch({"path": [addition.path_in_repo for addition in operations]}, schema=pa.schema({"path": pa.string()}))


def write_parquet(df: DataFrame, path: str, **kwargs) -> None:
    """
    Write Parquet files to Hugging Face using PyArrow.

    It uploads Parquet files in a distributed manner in two steps:

    1. Preupload the Parquet files in parallel (one per partition)
    2. Commit the preuploaded files

    Authenticate using `huggingface-cli login` or passing a token
    using the `storage_options` argument: `storage_options={"token": "hf_xxx"}`

    Parameters
    ----------
    df : DataFrame
        DataFrame to save.
    path : str
        Path of the file or directory. Prefix with a protocol like `hf://` to write to Hugging Face.
        It writes Parquet files in the form "part-xxxxx.parquet", or to a single file if `path` ends with ".parquet".
    **kwargs
        Any additional kwargs are passed to pyarrow.parquet.ParquetWriter.

    Returns
    -------
    None

    Examples
    --------
    >>> df = spark.createDataFrame(pd.DataFrame({"foo": range(5), "bar": range(5, 10)}))
    >>> # Save to one file
    >>> write_parquet(df, "hf://datasets/username/dataset/data.parquet")
    >>> # OR save to a directory (possibly in many files)
    >>> write_parquet(df, "hf://datasets/username/dataset")
    """
    filesystem: HfFileSystem = kwargs.pop("filesystem", HfFileSystem(**kwargs.pop("storage_options", {})))
    if path.endswith(".parquet") or path.endswith(".pq"):
        # Writing a single Parquet file: bring all the data into one partition.
        df = df.coalesce(1)
    else:
        path += "/part-{shard_idx:05d}.parquet"
    # Step 1: each partition writes and pre-uploads its own Parquet file.
    # Step 2: a single task gathers the pre-uploaded files and commits them.
    df.mapInArrow(
        partial(_preupload, path=path, schema=to_arrow_schema(df.schema), filesystem=filesystem, **kwargs),
        from_arrow_schema(pa.schema({"addition": pa.binary()})),
    ).repartition(1).mapInArrow(
        partial(_commit, path=path, filesystem=filesystem),
        from_arrow_schema(pa.schema({"path": pa.string()})),
    ).collect()
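
# End-to-end usage sketch (hypothetical dataset repos; requires write access,
# e.g. via `huggingface-cli login` or storage_options={"token": "hf_xxx"}):
#
#   set_session(spark)
#   df = read_parquet("hf://datasets/username/dataset/*.parquet")
#   write_parquet(df, "hf://datasets/username/new_dataset")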