import datasets
from datasets import load_dataset
from PIL import Image
from pathlib import Path
import pandas as pd
import os
import json
import tqdm
import argparse
import shutil
import numpy as np

np.random.seed(0)

"""
Creates a directory with images and JSON files for VQA examples.
The final JSON is located in metadata_sampled.json.
"""


def download_images_and_create_json(
    dataset_info, cache_dir="~/vqa_examples_cache", base_dir="./vqa_examples"
):
    # Expand "~" so cache directories are created under the home directory,
    # not in a literal "~" folder in the working directory.
    cache_dir = os.path.expanduser(cache_dir)
    for dataset_name, info in dataset_info.items():
        dataset_cache_dir = os.path.join(cache_dir, dataset_name)
        os.makedirs(dataset_cache_dir, exist_ok=True)
        if info["subset"]:
            dataset = load_dataset(
                info["path"],
                info["subset"],
                cache_dir=dataset_cache_dir,
                split=info["split"],
            )
        else:
            dataset = load_dataset(
                info["path"], cache_dir=dataset_cache_dir, split=info["split"]
            )
        dataset_dir = os.path.join(base_dir, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)

        json_data = []
        for i, item in enumerate(tqdm.tqdm(dataset)):
            # Fall back to the running index when the dataset has no stable ID field.
            id_key = i if info["id_key"] == "index" else item[info["id_key"]]
            image_pil = item[info["image_key"]].convert("RGB")
            image_path = os.path.join(dataset_dir, f"{id_key}.jpg")
            image_pil.save(image_path)
            json_entry = {
                "dataset": dataset_name,
                "question": item[info["question_key"]],
                "path": image_path,
            }
            json_data.append(json_entry)

        with open(os.path.join(dataset_dir, "data.json"), "w") as json_file:
            json.dump(json_data, json_file, indent=4)

        # Delete the cache directory for the dataset
        shutil.rmtree(dataset_cache_dir, ignore_errors=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="~/.cache")
    parser.add_argument("--output_dir", type=str, default="./vqa_examples")
    args = parser.parse_args()

    datasets_info = {
        "DocVQA": {
            "path": "lmms-lab/DocVQA",
            "image_key": "image",
            "question_key": "question",
            "id_key": "questionId",
            "subset": "DocVQA",
            "split": "test",
        },
        "ChartQA": {
            "path": "HuggingFaceM4/ChartQA",
            "image_key": "image",
            "question_key": "query",
            "id_key": "index",
            "subset": False,
            "split": "test",
        },
        "realworldqa": {
            "path": "visheratin/realworldqa",
            "image_key": "image",
            "question_key": "question",
            "id_key": "index",
            "subset": False,
            "split": "test",
        },
        "NewYorker": {
            "path": "jmhessel/newyorker_caption_contest",
            "image_key": "image",
            "question_key": "questions",
            "id_key": "index",
            "subset": "explanation",
            "split": "train",
        },
        "WikiArt": {
            "path": "huggan/wikiart",
            "image_key": "image",
            "question_key": "artist",
            "id_key": "index",
            "subset": False,
            "split": "train",
        },
        "TextVQA": {
            "path": "facebook/textvqa",
            "image_key": "image",
            "question_key": "question",
            "id_key": "question_id",
            "subset": False,
            "split": "train",
        },
    }

    download_images_and_create_json(
        datasets_info, cache_dir=args.data_dir, base_dir=args.output_dir
    )

    # Sample 500 entries per dataset (with replacement) into a single metadata file.
    dataset_json = []
    for dataset_name in datasets_info.keys():
        with open(f"{args.output_dir}/{dataset_name}/data.json") as f:
            data = json.load(f)
        dataset_json.extend(np.random.choice(data, 500))

    with open(f"{args.output_dir}/metadata_sampled.json", "w") as f:
        json.dump(dataset_json, f, indent=4)
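
# Example invocation (the filename "create_vqa_examples.py" is illustrative, not taken
# from the source; substitute the actual script name):
#   python create_vqa_examples.py --data_dir ~/.cache --output_dir ./vqa_examples
# This downloads each dataset listed in datasets_info, writes a per-dataset data.json
# alongside the extracted JPEG images, and then samples 500 entries per dataset into
# ./vqa_examples/metadata_sampled.json.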