# Source: FIRE/src/serve/vision/create_vqa_examples_dir.py (4.09 kB)
# Last commit 6dc0c9c by zhangbofei: "feat: change to fstchat"
import datasets
from datasets import load_dataset
from PIL import Image
from pathlib import Path
import pandas as pd
import os
import json
import tqdm
import argparse
import shutil
import numpy as np
# Fix the global NumPy seed so that the per-dataset random sampling done in
# the __main__ block (np.random) produces the same selection on every run.
np.random.seed(0)

# Creates a directory with images and JSON files for VQA examples.
# The final sampled JSON is written to <output_dir>/metadata_sampled.json.
def download_images_and_create_json(
    dataset_info, cache_dir="~/vqa_examples_cache", base_dir="./vqa_examples"
):
    """Export each configured dataset split as JPEG images plus a JSON index.

    For every ``(dataset_name, info)`` pair in ``dataset_info`` this:

    1. loads split ``info["split"]`` of ``info["path"]`` (and the
       ``info["subset"]`` config, when one is given) with ``load_dataset``,
       using ``<cache_dir>/<dataset_name>`` as the download cache;
    2. converts each row's ``info["image_key"]`` image to RGB and saves it as
       ``<base_dir>/<dataset_name>/<id>.jpg``, where ``<id>`` is the row index
       when ``info["id_key"] == "index"`` and ``row[info["id_key"]]`` otherwise;
    3. writes ``<base_dir>/<dataset_name>/data.json``: a list with one
       ``{"dataset", "question", "path"}`` record per row;
    4. deletes ``<cache_dir>/<dataset_name>``.

    Args:
        dataset_info: Mapping of dataset name to a config dict with the keys
            ``path``, ``subset`` (falsy when there is no subset), ``split``,
            ``image_key``, ``question_key`` and ``id_key``.
        cache_dir: Root directory for the per-dataset download caches. A
            leading ``~`` is expanded. Each ``<cache_dir>/<dataset_name>``
            directory is removed after that dataset has been exported.
        base_dir: Root output directory. A leading ``~`` is expanded.

    Returns:
        None. All results are written to disk.
    """
    # os.makedirs / os.path.join / shutil.rmtree treat "~" literally, so an
    # unexpanded path such as "~/.cache" would create (and later delete) a
    # "./~/..." directory under the CWD instead of one in the home directory.
    cache_dir = os.path.expanduser(cache_dir)
    base_dir = os.path.expanduser(base_dir)

    for dataset_name, info in dataset_info.items():
        dataset_cache_dir = os.path.join(cache_dir, dataset_name)
        os.makedirs(dataset_cache_dir, exist_ok=True)

        # The subset is passed as load_dataset's second positional argument,
        # and only when one is configured.
        if info["subset"]:
            load_args = [info["path"], info["subset"]]
        else:
            load_args = [info["path"]]
        dataset = load_dataset(
            *load_args, cache_dir=dataset_cache_dir, split=info["split"]
        )

        dataset_dir = os.path.join(base_dir, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)

        json_data = []
        for i, item in enumerate(tqdm.tqdm(dataset)):
            id_key = i if info["id_key"] == "index" else item[info["id_key"]]
            # Normalise to RGB, presumably so RGBA/palette images can be
            # written as JPEG.
            image_pil = item[info["image_key"]].convert("RGB")
            image_path = os.path.join(dataset_dir, f"{id_key}.jpg")
            image_pil.save(image_path)
            json_data.append(
                {
                    "dataset": dataset_name,
                    "question": item[info["question_key"]],
                    "path": image_path,
                }
            )

        with open(os.path.join(dataset_dir, "data.json"), "w") as json_file:
            json.dump(json_data, json_file, indent=4)

        # Delete the cache directory for the dataset
        shutil.rmtree(dataset_cache_dir, ignore_errors=True)
def sample_entries(entries, num_samples=500):
    """Return up to ``num_samples`` distinct entries, drawn uniformly at random.

    Sampling is without replacement, so no record is duplicated. When
    ``entries`` has fewer than ``num_samples`` items, all of them are returned
    (in random order). The global ``np.random`` state is used, so results are
    reproducible under ``np.random.seed``.

    Args:
        entries: Sequence of records (e.g. the dicts loaded from data.json).
        num_samples: Maximum number of records to return.

    Returns:
        A plain Python list of the selected records.
    """
    if not entries or num_samples <= 0:
        return []
    size = min(num_samples, len(entries))
    # Sample indices rather than the records themselves, so NumPy never has
    # to coerce the records into an (object) array.
    chosen = np.random.choice(len(entries), size=size, replace=False)
    return [entries[int(idx)] for idx in chosen]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Build a directory of VQA example images and JSON metadata."
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="~/.cache",
        help="Root directory for temporary per-dataset download caches.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./vqa_examples",
        help="Directory that receives the images and JSON files.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=500,
        help="Maximum number of examples sampled per dataset.",
    )
    args = parser.parse_args()

    # Expand "~" here so the paths used below (open/os.path.join) resolve
    # to the home directory rather than a literal "./~" directory.
    data_dir = os.path.expanduser(args.data_dir)
    output_dir = os.path.expanduser(args.output_dir)

    datasets_info = {
        "DocVQA": {
            "path": "lmms-lab/DocVQA",
            "image_key": "image",
            "question_key": "question",
            "id_key": "questionId",
            "subset": "DocVQA",
            "split": "test",
        },
        "ChartQA": {
            "path": "HuggingFaceM4/ChartQA",
            "image_key": "image",
            "question_key": "query",
            "id_key": "index",
            "subset": False,
            "split": "test",
        },
        "realworldqa": {
            "path": "visheratin/realworldqa",
            "image_key": "image",
            "question_key": "question",
            "id_key": "index",
            "subset": False,
            "split": "test",
        },
        "NewYorker": {
            "path": "jmhessel/newyorker_caption_contest",
            "image_key": "image",
            "question_key": "questions",
            "id_key": "index",
            "subset": "explanation",
            "split": "train",
        },
        "WikiArt": {
            "path": "huggan/wikiart",
            "image_key": "image",
            "question_key": "artist",
            "id_key": "index",
            "subset": False,
            "split": "train",
        },
        "TextVQA": {
            "path": "facebook/textvqa",
            "image_key": "image",
            "question_key": "question",
            "id_key": "question_id",
            "subset": False,
            "split": "train",
        },
    }

    download_images_and_create_json(
        datasets_info, cache_dir=data_dir, base_dir=output_dir
    )

    # Combine a random, duplicate-free sample from every dataset into one file.
    dataset_json = []
    for dataset_name in datasets_info:
        with open(os.path.join(output_dir, dataset_name, "data.json")) as f:
            data = json.load(f)
        dataset_json.extend(sample_entries(data, args.num_samples))

    with open(os.path.join(output_dir, "metadata_sampled.json"), "w") as f:
        json.dump(dataset_json, f, indent=4)