"""Periodically re-download a Hugging Face dataset and push a fresh copy to the Hub."""

import logging
import os
import shutil
import time

import huggingface_hub  # noqa: F401  (not used directly here)
import pretty_errors  # noqa: F401  (import alone enables prettier tracebacks)
import schedule
from datasets import Dataset, load_dataset, disable_caching

# Avoid reusing cached dataset files between runs.
disable_caching()

# Console logger: timestamped, level-tagged messages at INFO and above.
logger = logging.getLogger("basic_logger")
logger.setLevel(logging.INFO)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

DS_NAME = "amaye15/object-segmentation"
DATA_DIR = "data"


def get_data():
    """Stream rows from the source dataset, forcing a fresh download each run."""
    ds = load_dataset(
        DS_NAME,
        cache_dir=os.path.join(os.getcwd(), DATA_DIR),
        streaming=True,
        download_mode="force_redownload",
    )
    for row in ds["train"]:
        yield row


def process_and_push_data():
    """Rebuild the dataset from the stream and push the refreshed copy to the Hub."""
    p = os.path.join(os.getcwd(), DATA_DIR)

    # Start from an empty cache directory so stale files don't carry over.
    if os.path.exists(p):
        shutil.rmtree(p)
    os.mkdir(p)

    logger.info("Rebuilding dataset from %s", DS_NAME)
    ds_processed = Dataset.from_generator(get_data)
    ds_processed.push_to_hub("amaye15/tmp")
    logger.info("Pushed refreshed dataset to amaye15/tmp")
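
# Note: push_to_hub assumes Hub write credentials are already available
# locally, e.g. from `huggingface-cli login` or the HF_TOKEN environment
# variable; nothing in this script sets them up.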


# Refresh the dataset once a minute; sleep between checks so the loop
# doesn't busy-wait.
schedule.every(1).minute.do(process_and_push_data)

while True:
    schedule.run_pending()
    time.sleep(1)