|
import nomic |
|
import pandas as pd |
|
from tqdm import tqdm |
|
from datasets import load_dataset, \ |
|
get_dataset_split_names, \ |
|
get_dataset_config_names, \ |
|
ClassLabel, utils |
|
|
|
utils.logging.set_verbosity_error() |
|
import pyarrow as pa |
|
from dateutil.parser import parse |
|
import time |
|
|
|
|
|
# pyarrow type names (as produced by ``str(pa.DataType)``) grouped by how a
# column is treated when building the Atlas map.
_STRING_DTYPES = ("string", "large_string")
_FLOAT_DTYPES = ("halffloat", "float", "double")
_INT_DTYPES = ("int64", "int32", "int16", "int8",
               "uint64", "uint32", "uint16", "uint8")
_LIST_DTYPE_PREFIXES = ("list", "large_list", "fixed_size_list")


def _column_looks_like_datetime(values):
    """Return True if every non-missing value in *values* parses as a date.

    Purely numeric strings (e.g. ``"42"``) are rejected: dateutil interprets
    bare numbers as dates, which would misclassify ID-like string columns.
    At least one non-missing value is required.
    """
    seen_value = False
    for value in values:
        if value is None:
            continue
        if not isinstance(value, str) or value.strip().isdigit():
            return False
        try:
            parse(value, fuzzy=False)
        except (ValueError, OverflowError):
            # dateutil raises ParserError (a ValueError) for unparseable
            # text and OverflowError for out-of-range numbers.
            return False
        seen_value = True
    return seen_value


def get_datum_fields(dataset_dict, n_samples=100, unique_cutoff=20):
    """Classify each schema field of the first split by how it is uploaded.

    A shuffled sample of ``n_samples`` examples from
    ``dataset_dict["first_split_dataset"]`` is used to count distinct values
    per field. ``dataset_dict["schema"]`` maps field names to pyarrow type
    names (e.g. as built by ``load_dataset_and_metadata``).

    Args:
        dataset_dict: Mapping with ``"first_split_dataset"`` and ``"schema"``.
        n_samples: Number of examples to sample for the statistics.
        unique_cutoff: Fields with fewer distinct sampled values than this
            are treated as categorical. Values below 1 are interpreted as a
            fraction of the sample size.

    Returns:
        A 9-tuple ``(features, numeric_fields, string_fields, bool_fields,
        list_fields, label_fields, categorical_fields, datetime_fields,
        uncategorized_fields)``. ``features`` is the dataset's ``features``
        attribute; every other item is a list of field names.
    """
    dataset = dataset_dict["first_split_dataset"]
    # Materialise a small shuffled sample so per-column statistics are cheap.
    sample = pd.DataFrame(list(dataset.shuffle(seed=42).take(n_samples)))
    features = dataset.features

    numeric_fields = []
    string_fields = []
    bool_fields = []
    list_fields = []
    label_fields = []
    categorical_fields = []
    datetime_fields = []
    uncategorized_fields = []

    if unique_cutoff < 1:
        unique_cutoff = unique_cutoff * len(sample)

    for field, dtype in dataset_dict["schema"].items():
        try:
            num_unique = sample[field].nunique()
        except (KeyError, TypeError):
            # Column absent from the sample, or unhashable values
            # (lists/dicts): treat every sampled value as distinct.
            num_unique = len(sample)

        if dtype in _STRING_DTYPES:
            if num_unique < unique_cutoff:
                categorical_fields.append(field)
            elif field in sample.columns and _column_looks_like_datetime(sample[field]):
                # Iterate the column's values, not the DataFrame (which
                # would yield column names).
                datetime_fields.append(field)
            else:
                string_fields.append(field)

        elif dtype in _FLOAT_DTYPES:
            numeric_fields.append(field)

        elif dtype in _INT_DTYPES:
            if features is not None and field in features and isinstance(features[field], ClassLabel):
                label_fields.append(field)
            elif num_unique < unique_cutoff:
                categorical_fields.append(field)
            else:
                numeric_fields.append(field)

        elif dtype == "bool":
            bool_fields.append(field)

        elif dtype.startswith(_LIST_DTYPE_PREFIXES):
            list_fields.append(field)

        else:
            uncategorized_fields.append(field)

    return (features,
            numeric_fields,
            string_fields,
            bool_fields,
            list_fields,
            label_fields,
            categorical_fields,
            datetime_fields,
            uncategorized_fields)
|
|
|
|
|
def load_dataset_and_metadata(dataset_name,
                              config=None,
                              streaming=True):
    """Load the first split of a Hugging Face dataset plus upload metadata.

    Args:
        dataset_name: Hub identifier of the dataset.
        config: Dataset config name; the first config reported by the Hub is
            used when omitted.
        streaming: Passed through to ``load_dataset``.

    Returns:
        A dict with keys ``"first_split_dataset"`` (the loaded first split),
        ``"name"``, ``"config"``, ``"splits"`` (all split names for the
        config), ``"schema"`` (field name -> pyarrow type name string) and
        ``"head"`` (a pyarrow Table of the first few examples).
    """
    # Only ask the Hub for config names when the caller did not choose one.
    if config is None:
        config = get_dataset_config_names(dataset_name)[0]

    splits = get_dataset_split_names(dataset_name, config)
    dataset = load_dataset(dataset_name, config, split=splits[0], streaming=streaming)

    # A streamed IterableDataset exposes `_head()` (a column dict of the first
    # few examples); a map-style Dataset gives the same shape via slicing.
    head_columns = dataset._head() if streaming else dataset[:5]
    head = pa.Table.from_pydict(head_columns)

    schema_dict = {field.name: str(field.type) for field in head.schema}

    return {
        "first_split_dataset": dataset,
        "name": dataset_name,
        "config": config,
        "splits": splits,
        "schema": schema_dict,
        "head": head,
    }
|
|
|
|
|
def _stringify(value):
    """Return *value* as display text, mapping missing/empty values to ``""``.

    Falsy scalars such as ``False`` and ``0`` keep their text form so that
    boolean and integer-categorical fields do not lose their value.
    """
    if value is None:
        return ""
    if isinstance(value, (bool, int, float)):
        return str(value)
    return str(value) if value else ""


def _parse_datetime_or_none(value):
    """Parse *value* with dateutil, returning None when it is not a date."""
    try:
        return parse(value, fuzzy=False)
    except (ValueError, OverflowError, TypeError):
        # ParserError is a ValueError; TypeError covers None/non-str input.
        return None


def _pick_longest_text_field(head, string_fields):
    """Return the field in *string_fields* with the most words in *head*.

    *head* is the pyarrow Table stored under ``dataset_dict["head"]``.
    Returns None when no field contains any words.
    """
    best_field = None
    best_word_count = 0
    for field in string_fields:
        word_count = sum(
            len(value.split())
            for value in head.column(field).to_pylist()
            if isinstance(value, str)
        )
        if word_count > best_word_count:
            best_field, best_word_count = field, word_count
    return best_field


def _example_to_datum(example, split, index, unique_id_field_name, features,
                      numeric_fields, easy_fields, datetime_fields,
                      label_fields):
    """Convert one dataset example into the flat dict uploaded to Atlas."""
    datum = {"split": split, unique_id_field_name: f"{split}_{index}"}

    for field in numeric_fields:
        datum[field] = example[field]

    for field in easy_fields:
        datum[field] = _stringify(example[field])

    for field in datetime_fields:
        datum[field] = _parse_datetime_or_none(example[field])

    for field in label_fields:
        value = example[field]
        label_name = ""
        # -1 is skipped as a label index (it would index from the end).
        if value is not None and value != -1:
            label_name = features[field].names[value]
        datum[field] = str(value)
        datum[field + "_name"] = label_name

    return datum


def upload_dataset_to_atlas(dataset_dict,
                            atlas_api_token: str,
                            project_name=None,
                            unique_id_field_name=None,
                            indexed_field=None,
                            modality=None,
                            organization_name=None,
                            wait_for_map=True,
                            datum_limit=30000):
    """Stream every split of a dataset into a new Nomic Atlas project and index it.

    Args:
        dataset_dict: Mapping as built by ``load_dataset_and_metadata``.
        atlas_api_token: Token passed to ``nomic.login``.
        project_name: Atlas project name; derived from the dataset name if None.
        unique_id_field_name: Name of the generated id field
            (default ``"atlas_datum_id"``); ids are ``"<split>_<index>"``.
        indexed_field: Field to index; if None, the string field with the most
            words in ``dataset_dict["head"]`` is chosen.
        modality: Atlas modality (default ``"text"``). For ``"embedding"`` the
            chosen field is used as the topic label field instead.
        organization_name: Passed to ``nomic.AtlasProject``.
        wait_for_map: Block until the project lock is released.
        datum_limit: Maximum number of examples uploaded across all splits;
            ``None`` disables the limit.

    Returns:
        ``map_link`` of the index created with ``proj.create_index``.
    """
    nomic.login(atlas_api_token)

    if modality is None:
        modality = "text"
    if unique_id_field_name is None:
        unique_id_field_name = "atlas_datum_id"
    if project_name is None:
        project_name = dataset_dict["name"].replace("/", "--") + "--hf-atlas-map"

    desc = f"Config: {dataset_dict['config']}"

    (features, numeric_fields, string_fields, bool_fields, list_fields,
     label_fields, categorical_fields, datetime_fields,
     _uncategorized_fields) = get_datum_fields(dataset_dict)

    if indexed_field is None:
        indexed_field = _pick_longest_text_field(dataset_dict["head"], string_fields)

    topic_label_field = None
    if modality == "embedding":
        topic_label_field = indexed_field
        indexed_field = None

    # All of these are uploaded as their string form.
    easy_fields = string_fields + bool_fields + list_fields + categorical_fields

    proj = nomic.AtlasProject(name=project_name,
                              modality=modality,
                              unique_id_field=unique_id_field_name,
                              organization_name=organization_name,
                              description=desc,
                              reset_project_if_exists=True)

    batch_size = 1000
    batched_texts = []
    uploaded_count = 0
    limit_reached = False

    for split in dataset_dict["splits"]:
        if limit_reached:
            break

        dataset = load_dataset(dataset_dict["name"], dataset_dict["config"],
                               split=split, streaming=True)

        for i, ex in tqdm(enumerate(dataset)):
            # Pause every 10,000 examples (including the first); presumably to
            # throttle Hub streaming / Atlas uploads — confirm.
            if i % 10000 == 0:
                time.sleep(2)

            # The limit counts examples across all splits, not per split.
            if datum_limit is not None and uploaded_count >= datum_limit:
                print(f"Datum upload limited to {datum_limit:,} points. Stopping upload...")
                limit_reached = True
                break

            batched_texts.append(_example_to_datum(
                ex, split, i, unique_id_field_name, features,
                numeric_fields, easy_fields, datetime_fields, label_fields))
            uploaded_count += 1

            if len(batched_texts) >= batch_size:
                proj.add_text(batched_texts)
                batched_texts = []

    if batched_texts:
        proj.add_text(batched_texts)

    colorable_fields = (["split"]
                        + [field + "_name" for field in label_fields]
                        + categorical_fields + label_fields
                        + bool_fields + datetime_fields)

    projection = proj.create_index(name=project_name + " index",
                                   indexed_field=indexed_field,
                                   colorable_fields=colorable_fields,
                                   topic_label_field=topic_label_field,
                                   build_topic_model=True)

    if wait_for_map:
        with proj.wait_for_project_lock():
            time.sleep(1)

    return projection.map_link
|
|
|
|
|
if __name__ == "__main__":
    import os

    dataset_name = "databricks/databricks-dolly-15k"
    project_name = "huggingface_auto_upload_test-dolly-15k"

    # upload_dataset_to_atlas requires an API token as its second positional
    # argument; read it from the environment instead of hard-coding a secret.
    atlas_api_token = os.environ.get("ATLAS_API_TOKEN")
    if not atlas_api_token:
        raise SystemExit("Set the ATLAS_API_TOKEN environment variable to your Nomic Atlas API token.")

    dataset_dict = load_dataset_and_metadata(dataset_name)
    print(upload_dataset_to_atlas(dataset_dict, atlas_api_token, project_name=project_name))
|
|