# atlas-map/build_map.py
import os
import time

import nomic
import pandas as pd
import pyarrow as pa
from dateutil.parser import parse
from tqdm import tqdm
from datasets import (
    load_dataset,
    get_dataset_split_names,
    get_dataset_config_names,
    ClassLabel,
    utils,
)

utils.logging.set_verbosity_error()
def get_datum_fields(dataset_dict, n_samples = 100, unique_cutoff=20):
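    """Infer a type bucket for every field in the dataset schema from a small sample.

    Fields are classified as numeric, string, bool, list, label (ClassLabel),
    categorical (fewer than `unique_cutoff` unique values), datetime, or
    uncategorized. The string field with the most words in the sample is
    returned as the default field to index.
    """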
# take a sample of points
dataset = dataset_dict["first_split_dataset"]
sample = pd.DataFrame(dataset.shuffle(seed=42).take(n_samples))
features = dataset.features
indexable_field = None
numeric_fields = []
string_fields = []
bool_fields = []
list_fields = []
label_fields = []
categorical_fields = []
datetime_fields = []
uncategorized_fields = []
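    # A unique_cutoff below 1 is interpreted as a fraction of the sample size.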
if unique_cutoff < 1:
unique_cutoff = unique_cutoff*len(sample)
for field, dtype in dataset_dict["schema"].items():
        try:
            num_unique = sample[field].nunique()
        except Exception:
            # Unhashable values (e.g. lists or dicts) cannot be counted; assume all unique.
            num_unique = len(sample)
if dtype == "string":
if num_unique < unique_cutoff:
categorical_fields.append(field)
else:
                is_datetime = True
                for value in sample[field]:
                    try:
                        parse(value, fuzzy=False)
                    except Exception:
                        is_datetime = False
                        break
if is_datetime:
datetime_fields.append(field)
else:
string_fields.append(field)
elif dtype in ("float"):
numeric_fields.append(field)
elif dtype in ("int64", "int32", "int16", "int8"):
if features is not None and field in features and isinstance(features[field], ClassLabel):
label_fields.append(field)
elif num_unique < unique_cutoff:
categorical_fields.append(field)
else:
numeric_fields.append(field)
elif dtype == "bool":
bool_fields.append(field)
elif "list" == dtype[0:4]:
list_fields.append(field)
else:
uncategorized_fields.append(field)
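    # Choose the string field with the largest total word count in the sample
    # as the default field to index.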
    longest_length = 0
    for field in string_fields:
        length = 0
        for value in sample[field]:
            if value:
                length += len(str(value).split())
        if length > longest_length:
            longest_length = length
            indexable_field = field
    return (features,
            numeric_fields,
            string_fields,
            bool_fields,
            list_fields,
            label_fields,
            categorical_fields,
            datetime_fields,
            uncategorized_fields,
            indexable_field)
def load_dataset_and_metadata(dataset_name,
config=None,
streaming=True):
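    """Stream the first split of a Hugging Face dataset and collect its metadata.

    Uses the first config when none is given. Returns a dict holding the
    streamed dataset, dataset name, config, split names, a schema mapping
    field names to Arrow type strings, and the small head table the schema
    was inferred from.
    """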
configs = get_dataset_config_names(dataset_name)
if config is None:
config = configs[0]
splits = get_dataset_split_names(dataset_name, config)
dataset = load_dataset(dataset_name, config, split = splits[0], streaming=streaming)
head = pa.Table.from_pydict(dataset._head())
schema_dict = {field.name: str(field.type) for field in head.schema}
dataset_dict = {
"first_split_dataset": dataset,
"name": dataset_name,
"config": config,
"splits": splits,
"schema": schema_dict,
"head": head
}
return dataset_dict
def upload_dataset_to_atlas(dataset_dict,
atlas_api_token: str,
project_name = None,
unique_id_field_name=None,
indexed_field = None,
modality=None,
organization_name=None,
wait_for_map=True,
datum_limit=30000):
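    """Upload the dataset split by split to a Nomic Atlas project and build a map.

    Field categories from get_datum_fields decide how each value is serialized;
    examples are uploaded in batches and the upload stops once `datum_limit`
    examples have been sent. Returns a link to the generated map.
    """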
nomic.login(atlas_api_token)
if modality is None:
modality = "text"
if unique_id_field_name is None:
unique_id_field_name = "atlas_datum_id"
if project_name is None:
project_name = dataset_dict["name"].replace("/", "--") + "--hf-atlas-map"
desc = f"Config: {dataset_dict['config']}"
    (features,
     numeric_fields,
     string_fields,
     bool_fields,
     list_fields,
     label_fields,
     categorical_fields,
     datetime_fields,
     uncategorized_fields,
     indexable_field) = get_datum_fields(dataset_dict)
if indexed_field is None:
indexed_field = indexable_field
topic_label_field = None
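    # For embedding projects the chosen text field is used as the topic label
    # instead of being indexed.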
if modality == "embedding":
topic_label_field = indexed_field
indexed_field = None
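    # These field types are uploaded directly as their string representation.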
easy_fields = string_fields + bool_fields + list_fields + categorical_fields
proj = nomic.AtlasProject(name=project_name,
modality=modality,
unique_id_field=unique_id_field_name,
organization_name=organization_name,
description=desc,
reset_project_if_exists=True)
colorable_fields = ["split"]
batch_size = 1000
batched_texts = []
allow_upload = True
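    # Stream each split and upload examples in batches, pausing periodically
    # to avoid overwhelming the API.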
for split in dataset_dict["splits"]:
if not allow_upload:
break
dataset = load_dataset(dataset_dict["name"], dataset_dict["config"], split = split, streaming=True)
for i, ex in tqdm(enumerate(dataset)):
if i % 10000 == 0:
time.sleep(2)
            if i == datum_limit:
                print(f"Datum upload limited to {datum_limit} points. Stopping upload...")
allow_upload = False
break
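            # Each datum gets a synthetic unique id of the form "<split>_<index>".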
data_to_add = {"split": split, unique_id_field_name: f"{split}_{i}"}
for field in numeric_fields:
data_to_add[field] = ex[field]
for field in easy_fields:
val = ""
if ex[field]:
val = str(ex[field])
data_to_add[field] = val
for field in datetime_fields:
try:
data_to_add[field] = parse(ex[field], fuzzy=False)
            except Exception:
data_to_add[field] = None
for field in label_fields:
label_name = ""
if ex[field] is not None:
index = ex[field]
# NOTE: THIS MAY BREAK if -1 is ACTUALLY NO LABEL
if index != -1:
label_name = features[field].names[ex[field]]
data_to_add[field] = str(ex[field])
data_to_add[field + "_name"] = label_name
                if field + "_name" not in colorable_fields:
                    colorable_fields.append(field + "_name")
            for field in list_fields:
                list_str = ""
                if ex[field]:
                    try:
                        list_str = str(ex[field])
                    except Exception:
                        pass  # keep the empty string if the value cannot be stringified
                data_to_add[field] = list_str
batched_texts.append(data_to_add)
if len(batched_texts) >= batch_size:
proj.add_text(batched_texts)
batched_texts = []
if len(batched_texts) > 0:
proj.add_text(batched_texts)
colorable_fields = colorable_fields + \
categorical_fields + label_fields + bool_fields + datetime_fields
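    # Build the map index: `indexed_field` is embedded and `colorable_fields`
    # become selectable colorings in the Atlas UI.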
projection = proj.create_index(name=project_name + " index",
indexed_field=indexed_field,
colorable_fields=colorable_fields,
topic_label_field = topic_label_field,
build_topic_model=True)
if wait_for_map:
with proj.wait_for_project_lock():
time.sleep(1)
return projection.map_link
# Run test
if __name__ == "__main__":
dataset_name = "databricks/databricks-dolly-15k"
#dataset_name = "fka/awesome-chatgpt-prompts"
project_name = "huggingface_auto_upload_test-dolly-15k"
dataset_dict = load_dataset_and_metadata(dataset_name)
    # Supply your Atlas API token via the ATLAS_API_TOKEN environment variable.
    api_token = os.environ.get("ATLAS_API_TOKEN", "")
print(upload_dataset_to_atlas(dataset_dict, api_token, project_name=project_name))