# --------------------------------------------------------
# Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
# Github source: https://github.com/microsoft/unilm/tree/master/beit3
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import os
import json
import random
import torch
import glob
from collections import defaultdict, Counter

from torchvision import transforms
from torchvision.datasets.folder import default_loader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms import RandomResizedCropAndInterpolation
from timm.data import create_transform

import utils
from glossary import normalize_word
from randaug import RandomAugment


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
            self, data_path, split, transform,
            tokenizer, num_max_bpe_tokens, task=None,
    ):
        index_files = self.get_index_files(split, task=task)
        self.tokenizer = tokenizer
        self.num_max_bpe_tokens = num_max_bpe_tokens
        self.data_path = data_path
        items = []
        self.index_files = index_files

        offset = 0
        for _index_file in index_files:
            index_file = os.path.join(data_path, _index_file)
            with open(index_file, mode="r", encoding="utf-8") as reader:
                for line in reader:
                    data = json.loads(line)
                    items.append(data)
                print("Loaded %d image-text pairs from %s." % (len(items) - offset, index_file))
                offset = len(items)

        self.items = items
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.loader = default_loader
        self.transform = transform
        self.split = split

    @staticmethod
    def get_index_files(split, task=None):
        raise NotImplementedError()

    def _get_image(self, image_path: str):
        image_path = os.path.join(self.data_path, image_path)
        image = self.loader(image_path)
        return self.transform(image)

    def _get_text_segment(self, text_segment, max_len=None):
        if isinstance(text_segment, str):
            tokens = self.tokenizer.tokenize(text_segment)
        else:
            tokens = text_segment[:]
        if len(tokens) == 0:
            raise RuntimeError("The text segment should contain at least one token!")
        if max_len is None:
            max_len = self.num_max_bpe_tokens

        if len(tokens) > max_len - 2:
            tokens = tokens[:max_len - 2]

        tokens = [self.bos_token_id] + tokens[:] + [self.eos_token_id]
        num_tokens = len(tokens)
        padding_mask = [0] * num_tokens + [1] * (max_len - num_tokens)
        return tokens + [self.pad_token_id] * (max_len - num_tokens), padding_mask, num_tokens

    def _get_image_text_example(self, index: int, data: dict):
        item = self.items[index]
        img_path = item["image_path"]
        img = self._get_image(img_path)
        data["image"] = img

        text_segment = item["text_segment"]
        language_tokens, padding_mask, _ = self._get_text_segment(text_segment)
        data["language_tokens"] = language_tokens
        data["padding_mask"] = padding_mask

    def __getitem__(self, index: int):
        data = dict()
        self._get_image_text_example(index, data)
        return data

    def __len__(self) -> int:
        return len(self.items)

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = '{' + "\n  Number of items: %s," % self.__len__()
        body += "\n  data root = %s," % self.data_path
        body += "\n  split = %s," % self.split
        body += "\n  dataset index files = %s" % str(self.index_files)
        body += "\n  num max bpe tokens = %s" % self.num_max_bpe_tokens
        body += "\n  transforms = ["
        for t in self.transform.transforms:
            body += "\n    %s" % str(t)
        body += "\n  ]"
        body += "\n}"
        return head + body
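
# A worked sketch of `_get_text_segment`, assuming the XLM-RoBERTa special-token
# ids used by this codebase (bos=0, eos=2, pad=1) and num_max_bpe_tokens=8.
# For input token ids [52, 53, 54]:
#   tokens       -> [0, 52, 53, 54, 2, 1, 1, 1]  # <s> + tokens + </s>, then padding
#   padding_mask -> [0, 0, 0, 0, 0, 1, 1, 1]     # 1 marks pad positions
#   num_tokens   -> 5                            # counted before padding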
def _write_data_into_jsonl(items, jsonl_file):
    with open(jsonl_file, mode="w", encoding="utf-8") as writer:
        for data in items:
            writer.write(json.dumps(data, indent=None))
            writer.write('\n')
    print("Wrote %s with %d items!" % (jsonl_file, len(items)))


def _make_retrieval_coco_karpathy_dataset_index(
        data_path,
        tokenizer,
        split=("train", "restval"),
        split_name="train",
):
    coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
    items = []
    image_counter = set()
    print("read %s" % coco_karpathy_split_json_file)
    with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
        data = json.loads(reader.read())
        for item in data["images"]:
            if item["split"] in split:
                image_path = os.path.join(item["filepath"], item["filename"])
                for sent in item["sentences"]:
                    tokens = tokenizer.tokenize(sent["raw"])
                    token_ids = tokenizer.convert_tokens_to_ids(tokens)
                    items.append({
                        "image_path": image_path,
                        "text_segment": token_ids,
                        "image_id": len(image_counter),
                    })
                if image_path not in image_counter:
                    image_counter.add(image_path)
    print("Found %d images and %d image-text pairs for karpathy dataset %s split!" %
          (len(image_counter), len(items), split_name))
    index_file = os.path.join(data_path, "coco_retrieval.%s.jsonl" % split_name)
    _write_data_into_jsonl(items, index_file)


def _make_captioning_coco_karpathy_dataset_index(
        data_path,
        tokenizer,
        split=("train", "restval"),
        split_name="train",
):
    coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
    items = []
    image_counter = set()
    print("read %s" % coco_karpathy_split_json_file)
    with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
        data = json.loads(reader.read())
        for item in data["images"]:
            if item["split"] in split:
                image_path = os.path.join(item["filepath"], item["filename"])
                if item["split"] in ["train", "restval"]:
                    for sent in item["sentences"]:
                        tokens = tokenizer.tokenize(sent["raw"])
                        token_ids = tokenizer.convert_tokens_to_ids(tokens)
                        items.append({
                            "image_path": image_path,
                            "text_segment": token_ids,
                            "image_id": item["cocoid"],
                        })
                else:
                    items.append({
                        "image_path": image_path,
                        "text_segment": None,
                        "image_id": item["cocoid"],
                    })
                if image_path not in image_counter:
                    image_counter.add(image_path)
    print("Found %d images and %d image-text pairs for karpathy dataset %s split!" %
          (len(image_counter), len(items), split_name))
    index_file = os.path.join(data_path, "coco_captioning.%s.jsonl" % split_name)
    _write_data_into_jsonl(items, index_file)


def _make_nocaps_dataset_index(
        data_path,
        split="val",
):
    if split == "val":
        json_file = "nocaps_val_4500_captions.json"
    elif split == "test":
        json_file = "nocaps_test_image_info.json"
    else:
        raise RuntimeError("split %s is not found!" % split)
    nocaps_split_json_file = os.path.join(data_path, json_file)
    items = []
    image_counter = set()
    print("read %s" % nocaps_split_json_file)
    with open(nocaps_split_json_file, mode="r", encoding="utf-8") as reader:
        data = json.loads(reader.read())
        for item in data["images"]:
            image_path = os.path.join(split, item["file_name"])
            items.append({
                "image_path": image_path,
                "text_segment": None,
                "image_id": item["id"],
            })
            if image_path not in image_counter:
                image_counter.add(image_path)
    print("Found %d images and %d image-text pairs for nocaps dataset %s split!" %
          (len(image_counter), len(items), split))
    index_file = os.path.join(data_path, "nocaps.%s.jsonl" % split)
    _write_data_into_jsonl(items, index_file)
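
# Each helper above emits one JSON object per line. A hypothetical line from
# coco_captioning.train.jsonl (token ids are illustrative, not real vocab ids):
#
#   {"image_path": "train2014/COCO_train2014_000000000009.jpg", "text_segment": [52, 53, 54], "image_id": 9}
#
# BaseDataset.__init__ reads these lines back with json.loads, so any field
# written here is available on self.items at load time.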
class NLVR2Dataset(BaseDataset):
    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("nlvr2.train.index.jsonl", )
        elif split == "val":
            return ("nlvr2.dev.index.jsonl", )
        elif split == "test":
            return ("nlvr2.test-P.index.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = super().__getitem__(index)
        item = self.items[index]
        img_path = item["image2_path"]
        img = self._get_image(img_path)
        data["image2"] = img
        data["label"] = self.items[index]["label"]
        return data

    @staticmethod
    def __preprocess_json(prefix, json_file, tokenizer, index_file):
        items = []
        with open(json_file, mode="r", encoding="utf-8") as reader:
            for line in reader:
                data = json.loads(line)
                path = os.path.join(prefix, str(data["directory"])) if "directory" in data else prefix
                path = os.path.join(path, "-".join(data["identifier"].split("-")[:-1]))
                tokens = tokenizer.tokenize(data["sentence"])
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                items.append({
                    "image_path": path + "-img0.png",
                    "image2_path": path + "-img1.png",
                    "text_segment": token_ids,
                    "label": 1 if data["label"] == "True" else 0,
                    "identifier": data["identifier"],
                })
        _write_data_into_jsonl(items, index_file)

    @classmethod
    def make_dataset_index(cls, data_path, tokenizer, nlvr_repo_path):
        cls.__preprocess_json(
            prefix="images/train",
            json_file=os.path.join(nlvr_repo_path, "nlvr2/data/train.json"),
            tokenizer=tokenizer,
            index_file=os.path.join(data_path, cls.get_index_files("train")[0]),
        )
        cls.__preprocess_json(
            prefix="dev",
            json_file=os.path.join(nlvr_repo_path, "nlvr2/data/dev.json"),
            tokenizer=tokenizer,
            index_file=os.path.join(data_path, cls.get_index_files("val")[0]),
        )
        cls.__preprocess_json(
            prefix="test1",
            json_file=os.path.join(nlvr_repo_path, "nlvr2/data/test1.json"),
            tokenizer=tokenizer,
            index_file=os.path.join(data_path, cls.get_index_files("test")[0]),
        )


class ImageNetDataset(BaseDataset):
    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("imagenet.train.index.jsonl", )
        elif split == "val":
            return ("imagenet.val.index.jsonl", )
        elif split == "test":
            return ("imagenet.val.index.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = dict()
        item = self.items[index]
        img_path = item["image_path"]
        img = self._get_image(img_path)
        data["image"] = img
        data["label"] = item["label"]
        return data

    @staticmethod
    def _find_classes(dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    @staticmethod
    def _make_imagenet_index(data_path, index_path, data_path_prefix, class_to_idx, split):
        items = []
        index_file = os.path.join(index_path, f"imagenet.{split}.index.jsonl")
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(data_path, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    path = path.replace(data_path_prefix, "")
                    items.append({
                        "image_path": path,
                        "label": class_index,
                    })
        _write_data_into_jsonl(items, index_file)

    @classmethod
    def make_dataset_index(cls, train_data_path, val_data_path, index_path):
        # Common leading path shared by the train and val directories; stripping
        # it makes the stored image paths relative to the dataset root.
        data_path_prefix = os.path.commonprefix([train_data_path, val_data_path])
        classes, class_to_idx = cls._find_classes(train_data_path)
        cls._make_imagenet_index(
            data_path=train_data_path,
            index_path=index_path,
            data_path_prefix=data_path_prefix,
            class_to_idx=class_to_idx,
            split="train",
        )
        cls._make_imagenet_index(
            data_path=val_data_path,
            index_path=index_path,
            data_path_prefix=data_path_prefix,
            class_to_idx=class_to_idx,
            split="val",
        )
class VQAv2Dataset(BaseDataset):
    def __init__(self, data_path, **kwargs):
        super().__init__(data_path=data_path, **kwargs)
        ans2label_file = os.path.join(data_path, "answer2label.txt")
        ans2label = {}
        label2ans = []
        with open(ans2label_file, mode="r", encoding="utf-8") as reader:
            for i, line in enumerate(reader):
                data = json.loads(line)
                ans = data["answer"]
                label = data["label"]
                label = int(label)
                assert label == i
                ans2label[ans] = i
                label2ans.append(ans)

        self.ans2label = ans2label
        self.label2ans = label2ans

    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("vqa.train.jsonl", "vqa.trainable_val.jsonl")
        elif split == "val":
            return ("vqa.rest_val.jsonl", )
        elif split == "test":
            return ("vqa.test.jsonl", )
        elif split == "test-dev":
            return ("vqa.test-dev.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = super().__getitem__(index)
        if "labels" in self.items[index] and len(self.items[index]["labels"]) > 0:
            labels = [0.] * len(self.label2ans)
            for l, s in zip(self.items[index]["labels"], self.items[index]["scores"]):
                labels[l] = s
            data["labels"] = torch.FloatTensor(labels)
        else:
            data["qid"] = self.items[index]["qid"]
        return data

    @staticmethod
    def get_score(occurences):
        if occurences == 0:
            return 0.0
        elif occurences == 1:
            return 0.3
        elif occurences == 2:
            return 0.6
        elif occurences == 3:
            return 0.9
        else:
            return 1.0

    @classmethod
    def make_dataset_index(cls, data_path, tokenizer, annotation_data_path):
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_train2014_questions.json"), "r") as fp:
            questions_train2014 = json.load(fp)["questions"]
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_val2014_questions.json"), "r") as fp:
            questions_val2014 = json.load(fp)["questions"]
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test2015_questions.json"), "r") as fp:
            questions_test2015 = json.load(fp)["questions"]
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test-dev2015_questions.json"), "r") as fp:
            questions_test_dev2015 = json.load(fp)["questions"]

        with open(os.path.join(annotation_data_path, "v2_mscoco_train2014_annotations.json"), "r") as fp:
            annotations_train2014 = json.load(fp)["annotations"]
        with open(os.path.join(annotation_data_path, "v2_mscoco_val2014_annotations.json"), "r") as fp:
            annotations_val2014 = json.load(fp)["annotations"]

        annotations = dict()

        for split, questions in zip(
                ["train", "val", "test", "test-dev"],
                [questions_train2014, questions_val2014, questions_test2015, questions_test_dev2015],
        ):
            _annot = defaultdict(dict)
            for q in questions:
                question_text = q["question"]
                tokens = tokenizer.tokenize(question_text)
                token_ids = tokenizer.convert_tokens_to_ids(tokens)

                assert q["question_id"] not in _annot[q["image_id"]]
                _annot[q["image_id"]][q["question_id"]] = {
                    "question": question_text,
                    "token_ids": token_ids,
                }

            annotations[split] = _annot

        all_major_answers = list()

        for split, annots in zip(
                ["train", "val"], [annotations_train2014, annotations_val2014],
        ):
            for q in annots:
                all_major_answers.append(q["multiple_choice_answer"])

        all_major_answers = [normalize_word(word) for word in all_major_answers]
        # Keep only answers appearing at least 9 times across train+val,
        # following ViLT's answer-vocabulary construction.
        counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
        ans2label = {k: i for i, k in enumerate(counter.keys())}
        label2ans = list(counter.keys())

        for split, annots in zip(
                ["train", "val"], [annotations_train2014, annotations_val2014],
        ):
            _annot = annotations[split]
            for q in annots:
                answers = q["answers"]
                answer_count = {}
                for answer in answers:
                    answer_ = answer["answer"]
                    answer_count[answer_] = answer_count.get(answer_, 0) + 1

                labels = []
                scores = []
                for answer in answer_count:
                    if answer not in ans2label:
                        continue
                    labels.append(ans2label[answer])
                    score = cls.get_score(answer_count[answer])
                    scores.append(score)

                assert "labels" not in _annot[q["image_id"]][q["question_id"]]
                assert "question" in _annot[q["image_id"]][q["question_id"]]
                _annot[q["image_id"]][q["question_id"]]["labels"] = labels
                _annot[q["image_id"]][q["question_id"]]["scores"] = scores

        for split in ["train", "val"]:
            filtered_annot = dict()
            for ik, iv in annotations[split].items():
                new_q = dict()
                for qk, qv in iv.items():
                    if len(qv["labels"]) != 0:
                        new_q[qk] = qv
                if len(new_q) != 0:
                    filtered_annot[ik] = new_q
            annotations[split] = filtered_annot

        split2items = {}
        for split in ["train", "val", "test", "test-dev"]:
            annot = annotations[split]
            split_name = {
                "train": "train2014",
                "val": "val2014",
                "test": "test2015",
                "test-dev": "test2015",
            }[split]
            paths = list(glob.glob(f"{data_path}/{split_name}/*.jpg"))
            random.shuffle(paths)
            annot_paths = [path for path in paths
                           if int(path.split("/")[-1].split("_")[-1][:-4]) in annot]

            if len(paths) == len(annot_paths):
                print("all images have question annotations")
            else:
                print("not all images have question annotations")
                print(len(paths), len(annot_paths), len(annot))

            items = []
            for path in annot_paths:
                iid = int(path.split("/")[-1].split("_")[-1][:-4])
                _annot = annotations[split][iid]
                for qid in _annot:
                    q = _annot[qid]
                    if split in ["train", "val"]:
                        labels = q["labels"]
                        scores = q["scores"]
                    else:
                        labels, scores = [], []

                    items.append({
                        "image_path": os.path.join(split_name, path.split('/')[-1]),
                        "text_segment": q["token_ids"],
                        "labels": labels,
                        "scores": scores,
                        "qid": qid,
                    })
            split2items[split] = items

            _write_data_into_jsonl(items=items, jsonl_file=os.path.join(data_path, "vqa.%s.jsonl" % split))

        # Following ViLT, we use 1000 images of the original val set as the final val set
        val_image2items = defaultdict(list)
        for item in split2items["val"]:
            val_image2items[item["image_path"]].append(item)

        print("Contains %d images and %d pairs for the val set!" % (len(val_image2items), len(split2items["val"])))

        val_images = list(val_image2items.keys())
        random.shuffle(val_images)
        trainable_val = []
        rest_val = []
        for i, image_id in enumerate(val_images):
            if i < 1000:
                rest_val += val_image2items[image_id]
            else:
                trainable_val += val_image2items[image_id]

        _write_data_into_jsonl(items=trainable_val, jsonl_file=os.path.join(data_path, "vqa.trainable_val.jsonl"))
        _write_data_into_jsonl(items=rest_val, jsonl_file=os.path.join(data_path, "vqa.rest_val.jsonl"))

        with open(os.path.join(data_path, "answer2label.txt"), mode="w", encoding="utf-8") as writer:
            for ans in ans2label:
                to_json = {
                    "answer": ans,
                    "label": ans2label[ans],
                }
                writer.write("%s\n" % json.dumps(to_json))
class RetrievalDataset(BaseDataset):
    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return (f"{task}.train.jsonl", )
        elif split == "val":
            return (f"{task}.val.jsonl", )
        elif split == "test":
            return (f"{task}.test.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = super().__getitem__(index)
        data["image_id"] = self.items[index]["image_id"]
        return data

    @staticmethod
    def make_flickr30k_dataset_index(data_path, tokenizer, karpathy_path):
        with open(os.path.join(karpathy_path, "dataset_flickr30k.json"), "r") as reader:
            captions = json.loads(reader.read())

        captions = captions["images"]
        split2items = defaultdict(list)
        split2images = defaultdict(set)

        for each_item in captions:
            image_path = os.path.join("flickr30k-images", each_item["filename"])
            split = each_item["split"]

            for text_segment in each_item["sentences"]:
                tokens = tokenizer.tokenize(text_segment["raw"])
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                split2items[split].append({
                    "image_path": image_path,
                    "text_segment": token_ids,
                    "image_id": len(split2images[split]),
                })

            assert each_item["filename"] not in split2images[split]
            split2images[split].add(each_item["filename"])

        for split in split2items:
            print("%d images and %d image-text pairs!" % (len(split2images[split]), len(split2items[split])))
            _write_data_into_jsonl(split2items[split], os.path.join(data_path, "flickr30k.%s.jsonl" % split))

    @staticmethod
    def make_coco_dataset_index(data_path, tokenizer):
        _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
        _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
        _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")


class CaptioningDataset(BaseDataset):

    def __init__(self, data_path, split, transform,
                 tokenizer, num_max_bpe_tokens, task, mask_prob):
        super().__init__(
            data_path=data_path, split=split,
            transform=transform, tokenizer=tokenizer,
            num_max_bpe_tokens=num_max_bpe_tokens, task=task,
        )
        self.mask_token_id = tokenizer.mask_token_id
        self.language_vocab_size = tokenizer.vocab_size
        self.mask_prob = mask_prob

    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("coco_captioning.train.jsonl", )
        elif split == "val":
            return (f"{task}.val.jsonl", )
        elif split == "test":
            return (f"{task}.test.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def _get_mask_token(self, token):
        # BERT-style replacement: 80% mask token, 10% keep, 10% random token.
        p = random.random()
        if p < 0.8:
            return self.mask_token_id
        elif p < 0.9:
            return token
        else:
            return random.randint(3, self.language_vocab_size - 1)

    def _masking_on_text_tokens(self, tokens, num_tokens, mask_prob):
        bool_masked_pos = [0] * len(tokens)
        to_mask = min(int(num_tokens * mask_prob + 0.5), num_tokens - 1)
        to_mask = max(to_mask, 1)
        num_masked_tokens = 0
        while num_masked_tokens < to_mask:
            i = random.randint(1, num_tokens - 1)
            if bool_masked_pos[i] == 0:
                bool_masked_pos[i] = 1
                tokens[i] = self._get_mask_token(tokens[i])
                num_masked_tokens += 1

        return tokens, bool_masked_pos

    def __getitem__(self, index: int):
        data = dict()
        item = self.items[index]
        img_path = item["image_path"]
        img = self._get_image(img_path)
        data["image"] = img
        data["image_id"] = item["image_id"]

        text_segment = item["text_segment"]
        if text_segment is not None:
            language_tokens, padding_mask, num_tokens = self._get_text_segment(text_segment)
            masked_tokens = language_tokens[:]
            masked_tokens, language_masked_pos = \
                self._masking_on_text_tokens(masked_tokens, num_tokens, self.mask_prob)
            data["language_tokens"] = language_tokens
            data["masked_tokens"] = masked_tokens
            data["language_masked_pos"] = language_masked_pos
            data["padding_mask"] = padding_mask
        return data

    @staticmethod
    def make_coco_captioning_dataset_index(data_path, tokenizer):
        _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
        _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
        _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")

    @staticmethod
    def make_nocaps_captioning_dataset_index(data_path):
        _make_nocaps_dataset_index(data_path, split="val")
        _make_nocaps_dataset_index(data_path, split="test")


task2dataset = {
    "nlvr2": NLVR2Dataset,
    "vqav2": VQAv2Dataset,
    "flickr30k": RetrievalDataset,
    "coco_retrieval": RetrievalDataset,
    "coco_captioning": CaptioningDataset,
    "nocaps": CaptioningDataset,
    "imagenet": ImageNetDataset,
}
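
# Usage sketch: resolving a dataset class from the task name and wrapping it
# with create_dataloader (defined below); argument values are illustrative:
#
#   dataset_class = task2dataset["coco_captioning"]   # -> CaptioningDataset
#   loader = create_dataloader(dataset, is_train=True, batch_size=64,
#                              num_workers=8, pin_mem=True)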
def create_dataloader(dataset, is_train, batch_size, num_workers, pin_mem, dist_eval=False):
    if is_train or dist_eval:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()

        if not is_train and dist_eval and len(dataset) % num_tasks != 0:
            print(
                'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                'This will slightly alter validation results as extra duplicate entries are added to achieve '
                'equal num of samples per-process.'
            )

        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train
        )
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset, sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=is_train,
        collate_fn=utils.merge_batch_tensors_by_dict_key,
    )


def build_transform(is_train, args):
    if args.task in ["imagenet"]:
        return build_imagenet_transform(is_train, args)

    if is_train:
        t = [
            RandomResizedCropAndInterpolation(args.input_size, scale=(0.5, 1.0), interpolation=args.train_interpolation),
            transforms.RandomHorizontalFlip(),
        ]
        if args.randaug:
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        t += [
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
        ]
        t = transforms.Compose(t)
    else:
        t = transforms.Compose([
            transforms.Resize((args.input_size, args.input_size), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
        ])

    return t


def build_imagenet_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        if args.crop_pct is None:
            args.crop_pct = 1.0
        size = int(args.input_size / args.crop_pct)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
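
# Example of the eval-path arithmetic above: with input_size=224 and
# crop_pct=0.875 (a common ImageNet eval setting), images are resized so the
# short side is int(224 / 0.875) = 256, then center-cropped to 224x224. With
# the default crop_pct of 1.0 used here, resize and crop are both 224.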
def get_sentencepiece_model_for_beit3(args):
    from transformers import XLMRobertaTokenizer
    return XLMRobertaTokenizer(args.sentencepiece_model)


def create_dataset_by_split(args, split, is_train=True):
    transform = build_transform(is_train=is_train, args=args)
    dataset_class = task2dataset[args.task]
    tokenizer = get_sentencepiece_model_for_beit3(args)

    opt_kwargs = {}
    if args.task in ["coco_captioning", "nocaps"]:
        opt_kwargs["mask_prob"] = args.captioning_mask_prob

    dataset = dataset_class(
        data_path=args.data_path, split=split,
        transform=transform, tokenizer=tokenizer,
        num_max_bpe_tokens=args.num_max_bpe_tokens,
        task=args.task, **opt_kwargs,
    )
    if is_train:
        batch_size = args.batch_size
    elif hasattr(args, "eval_batch_size") and args.eval_batch_size is not None:
        batch_size = args.eval_batch_size
    else:
        batch_size = int(args.batch_size * 1.5)

    return create_dataloader(
        dataset, is_train=is_train, batch_size=batch_size,
        num_workers=args.num_workers, pin_mem=args.pin_mem, dist_eval=args.dist_eval,
    )


def create_downstream_dataset(args, is_eval=False):
    if is_eval:
        return create_dataset_by_split(args, split="test", is_train=False)
    else:
        return \
            create_dataset_by_split(args, split="train", is_train=True), \
            create_dataset_by_split(args, split="val", is_train=True)
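
# A minimal end-to-end sketch (hypothetical paths; `args` fields mirror the
# command-line flags this module expects):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       task="coco_captioning", data_path="/datasets/coco", input_size=224,
#       train_interpolation="bicubic", randaug=True, num_max_bpe_tokens=64,
#       sentencepiece_model="/models/beit3.spm", captioning_mask_prob=0.6,
#       batch_size=64, eval_batch_size=None, num_workers=8, pin_mem=True,
#       dist_eval=False,
#   )
#   train_loader, val_loader = create_downstream_dataset(args)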