""" |
|
ReferIt, UNC, UNC+ and GRef referring image segmentation PyTorch dataset. |
|
|
|
Define and group batches of images, segmentations and queries. |
|
Based on: |
|
https://github.com/chenxi116/TF-phrasecut-public/blob/master/build_batches.py |
|
""" |
|
|
|
import os
import os.path as osp
import re
import sys
import json

import numpy as np
import scipy.io as sio
import torch
import torch.utils.data as data
from PIL import Image
from transformers import AutoTokenizer, AutoModel

sys.path.append('.')

from utils.word_utils import Corpus
from utils.box_utils import sampleNegBBox
from utils.genome_utils import getCLSLabel
|
|
|
|
|
def read_examples(input_line, unique_id):
    """Build a single-element list of `InputExample`s from one input line.

    A line of the form "text_a ||| text_b" yields a sentence pair;
    otherwise the whole line is used as `text_a`.
    """
    examples = []
    line = input_line.strip()
    text_a = None
    text_b = None
    m = re.match(r"^(.*) \|\|\| (.*)$", line)
    if m is None:
        text_a = line
    else:
        text_a = m.group(1)
        text_b = m.group(2)
    examples.append(
        InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
    return examples
|
|
|
|
|
class InputExample(object):
    """A single text (or text-pair) example."""

    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids
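

# NOTE: `convert_examples_to_features` below calls `_truncate_seq_pair`,
# which this file does not define. The version here is the standard helper
# from the BERT reference implementation, reproduced so the text-pair
# branch actually runs.
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to a maximum total length.

    Removes one token at a time from the longer sequence, since the longer
    sequence carries proportionally less information per token than an
    equal-percentage truncation of both would preserve.
    """
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()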
|
|
|
def convert_examples_to_features(examples, seq_length, tokenizer, usemarker=None):
    """Converts a list of `InputExample`s into a list of `InputFeatures`."""
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so the total
            # length fits; account for [CLS], [SEP], [SEP] with "- 3".
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            if usemarker is not None:
                # The phrase of interest is wrapped in a pair of '*' markers;
                # truncate so that both markers survive whenever possible.
                marker_idx = [i for i, x in enumerate(tokens_a) if x == '*']
                if len(marker_idx) < 2:
                    # Guard against fewer than two markers (the unguarded
                    # marker_idx[1] lookup would raise IndexError): fall back
                    # to plain head truncation.
                    if len(tokens_a) > seq_length - 2:
                        tokens_a = tokens_a[0:(seq_length - 2)]
                elif marker_idx[1] > seq_length - 3 and len(tokens_a) - seq_length + 1 < marker_idx[0]:
                    # Both markers fit if we keep the tail of the sequence.
                    tokens_a = tokens_a[-(seq_length - 2):]
                    new_marker_idx = [i for i, x in enumerate(tokens_a) if x == '*']
                    if len(new_marker_idx) < 2:
                        # The leading marker was truncated away; the
                        # sequence is kept as-is.
                        pass
                elif len(tokens_a) - seq_length + 1 >= marker_idx[0]:
                    # The marked span itself is too long: keep as much of it
                    # as fits and close it with a trailing '*'.
                    max_len = min(marker_idx[1] - marker_idx[0] + 1, seq_length - 2)
                    tokens_a = tokens_a[marker_idx[0]:marker_idx[0] + max_len]
                    tokens_a[-1] = '*'
                elif marker_idx[1] - marker_idx[0] < 2:
                    # Empty marked span: strip stray markers and re-wrap the
                    # whole phrase instead.
                    tokens_a = [i for i in tokens_a if i != '*']
                    tokens_a = ['*'] + tokens_a + ['*']
                else:
                    if len(tokens_a) > seq_length - 2:
                        tokens_a = tokens_a[0:(seq_length - 2)]
            else:
                # Account for [CLS] and [SEP] with "- 2".
                if len(tokens_a) > seq_length - 2:
                    tokens_a = tokens_a[0:(seq_length - 2)]

        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length
        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features
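

# Illustration of the text pipeline above (assumes the `bert-base-uncased`
# vocabulary is available; the phrase and marker placement are made up):
#
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   examples = read_examples('left lower lobe * opacity *', unique_id=0)
#   feats = convert_examples_to_features(examples, seq_length=16,
#                                        tokenizer=tokenizer, usemarker='yes')
#   feats[0].tokens
#   # ['[CLS]', 'left', 'lower', 'lobe', '*', 'opacity', '*', '[SEP]']
#   len(feats[0].input_ids) == 16   # zero-padded to seq_length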
|
|
|
class DatasetNotFoundError(Exception):
    pass
|
|
|
class TransVGDataset(data.Dataset):
    SUPPORTED_DATASETS = {
        'referit': {'splits': ('train', 'val', 'trainval', 'test')},
        'unc': {
            'splits': ('train', 'val', 'trainval', 'testA', 'testB'),
            'params': {'dataset': 'refcoco', 'split_by': 'unc'}
        },
        'unc+': {
            'splits': ('train', 'val', 'trainval', 'testA', 'testB'),
            'params': {'dataset': 'refcoco+', 'split_by': 'unc'}
        },
        'gref': {
            'splits': ('train', 'val'),
            'params': {'dataset': 'refcocog', 'split_by': 'google'}
        },
        'gref_umd': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'refcocog', 'split_by': 'umd'}
        },
        'flickr': {
            'splits': ('train', 'val', 'test')
        },
        'MS_CXR': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'MS_CXR', 'split_by': 'MS_CXR'}
        },
        'ChestXray8': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'ChestXray8', 'split_by': 'ChestXray8'}
        },
        'SGH_CXR_V1': {
            'splits': ('train', 'val', 'test'),
            'params': {'dataset': 'SGH_CXR_V1', 'split_by': 'SGH_CXR_V1'}
        }
    }
|
|
|
    def __init__(self, args, data_root, split_root='data', dataset='referit',
                 transform=None, return_idx=False, testmode=False,
                 split='train', max_query_len=128, lstm=False,
                 bert_model='bert-base-uncased'):
        self.images = []
        self.data_root = data_root
        self.split_root = split_root
        self.dataset = dataset
        self.query_len = max_query_len
        self.lstm = lstm
        self.transform = transform
        self.testmode = testmode
        self.split = split
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=True)
        self.return_idx = return_idx
        self.args = args
        self.ID_Categories = {
            1: 'Cardiomegaly', 2: 'Lung Opacity', 3: 'Edema',
            4: 'Consolidation', 5: 'Pneumonia', 6: 'Atelectasis',
            7: 'Pneumothorax', 8: 'Pleural Effusion'}

        assert self.transform is not None

        self.augment = (split == 'train')

        if self.dataset == 'MS_CXR':
            self.dataset_root = osp.join(self.data_root, 'MS_CXR')
            self.im_dir = self.dataset_root
        elif self.dataset == 'ChestXray8':
            self.dataset_root = osp.join(self.data_root, 'ChestXray8')
            self.im_dir = self.dataset_root
        elif self.dataset == 'SGH_CXR_V1':
            self.dataset_root = osp.join(self.data_root, 'SGH_CXR_V1')
            self.im_dir = self.dataset_root
        elif self.dataset == 'referit':
            self.dataset_root = osp.join(self.data_root, 'referit')
            self.im_dir = osp.join(self.dataset_root, 'images')
            self.split_dir = osp.join(self.dataset_root, 'splits')
        elif self.dataset == 'flickr':
            self.dataset_root = osp.join(self.data_root, 'Flickr30k')
            self.im_dir = osp.join(self.dataset_root, 'flickr30k_images')
        else:  # unc, unc+, gref, gref_umd all use MSCOCO train2014 images
            self.dataset_root = osp.join(self.data_root, 'other')
            self.im_dir = osp.join(
                self.dataset_root, 'images', 'mscoco', 'images', 'train2014')
            self.split_dir = osp.join(self.dataset_root, 'splits')

        if not self.exists_dataset():
            raise DatasetNotFoundError(
                'Please download the index cache to the data folder: '
                'https://drive.google.com/open?id=1cZI562MABLtAzM6YU4WmKPFFguuVr0lZ')

        dataset_path = osp.join(self.split_root, self.dataset)
        valid_splits = self.SUPPORTED_DATASETS[self.dataset]['splits']

        if self.lstm:
            corpus_path = osp.join(dataset_path, 'corpus.pth')
            self.corpus = torch.load(corpus_path)

        if split not in valid_splits:
            raise ValueError(
                'Dataset {0} does not have split {1}'.format(
                    self.dataset, split))

        splits = [split]
        if self.dataset != 'referit':
            splits = ['train', 'val'] if split == 'trainval' else [split]
        for split in splits:
            imgset_file = '{0}_{1}.pth'.format(self.dataset, split)
            imgset_path = osp.join(dataset_path, imgset_file)
            self.images += torch.load(imgset_path)
|
|
|
    def exists_dataset(self):
        return osp.exists(osp.join(self.split_root, self.dataset))
|
|
|
    def pull_item(self, idx):
        info = {}
        if self.dataset == 'MS_CXR':
            anno_id, image_id, category_id, img_file, bbox, width, height, phrase = self.images[idx]
            info['anno_id'] = anno_id
            info['category_id'] = category_id
        elif self.dataset == 'ChestXray8':
            anno_id, image_id, category_id, img_file, bbox, phrase, prompt_text = self.images[idx]
            info['anno_id'] = anno_id
            info['category_id'] = category_id
        elif self.dataset == 'SGH_CXR_V1':
            anno_id, image_id, category_id, img_file, bbox, phrase, patient_id = self.images[idx]
            info['anno_id'] = anno_id
            info['category_id'] = category_id
        elif self.dataset == 'flickr':
            img_file, bbox, phrase = self.images[idx]
        else:
            img_file, _, bbox, phrase, attri = self.images[idx]

        # ReferIt and Flickr boxes are already (x1, y1, x2, y2); the other
        # datasets store (x, y, w, h) and are converted here.
        if not (self.dataset == 'referit' or self.dataset == 'flickr'):
            bbox = np.array(bbox, dtype=int)
            bbox[2], bbox[3] = bbox[0] + bbox[2], bbox[1] + bbox[3]
        else:
            bbox = np.array(bbox, dtype=int)

        if self.args.ablation == 'onlyText':
            # Text-only ablation: every sample sees the same fixed image.
            img_file = 'files/p12/p12423759/s53349935/b8c7a778-2f7f712d-5c598645-6aeebbb3-66ffbcc7.jpg'

        img_path = osp.join(self.im_dir, img_file)
        info['img_path'] = img_path
        img = Image.open(img_path).convert("RGB")

        bbox = torch.tensor(bbox)
        bbox = bbox.float()

        return img, phrase, bbox, info
|
|
|
    def tokenize_phrase(self, phrase):
        return self.corpus.tokenize(phrase, self.query_len)

    def untokenize_word_vector(self, words):
        return self.corpus.dictionary[words]

    def __len__(self):
        return len(self.images)
|
|
|
    def __getitem__(self, idx):
        img, phrase, bbox, info = self.pull_item(idx)

        phrase = phrase.lower()
        if hasattr(self.args, 'CATextPoolType') and self.args.CATextPoolType == 'marker':
            # Marker pooling expects the phrase wrapped in '*' markers,
            # prepared upstream and stored in `info`.
            phrase = info['phrase_marker']
        info['phrase_record'] = phrase
        input_dict = {'img': img, 'box': bbox, 'text': phrase}

        if self.args.model_name == 'TransVG_ca' and self.split == 'train':
            # Contrastive-alignment training: sample negative boxes for the
            # ground-truth box.
            NegBBoxs = sampleNegBBox(bbox, self.args.CAsampleType, self.args.CAsampleNum)
            input_dict = {'img': img, 'box': bbox, 'text': phrase, 'NegBBoxs': NegBBoxs}
        if self.args.model_name == 'TransVG_gn' and self.split == 'train':
            # Scene-graph supervision: load the per-image scene-graph JSON
            # and derive a classification label for this box.
            json_name = os.path.splitext(os.path.basename(info['img_path']))[0] + '_SceneGraph.json'
            json_name = os.path.join(self.args.GNpath, json_name)
            gnLabel = getCLSLabel(json_name, bbox)
            info['gnLabel'] = gnLabel

        input_dict = self.transform(input_dict)
        img = input_dict['img']
        bbox = input_dict['box']
        phrase = input_dict['text']
        img_mask = input_dict['mask']
        if self.args.model_name == 'TransVG_ca' and self.split == 'train':
            info['NegBBoxs'] = [np.array(negBBox, dtype=np.float32) for negBBox in input_dict['NegBBoxs']]

        if self.lstm:
            phrase = self.tokenize_phrase(phrase)
            word_id = phrase
            word_mask = np.array(word_id > 0, dtype=int)
        else:
            # Encode the phrase with the BERT tokenizer.
            examples = read_examples(phrase, idx)
            if hasattr(self.args, 'CATextPoolType') and self.args.CATextPoolType == 'marker':
                use_marker = 'yes'
            else:
                use_marker = None
            features = convert_examples_to_features(
                examples=examples, seq_length=self.query_len, tokenizer=self.tokenizer, usemarker=use_marker)
            word_id = features[0].input_ids
            word_mask = features[0].input_mask
            if self.args.ablation == 'onlyImage':
                # Image-only ablation: zero out the text attention mask.
                word_mask = [0] * len(word_mask)

        if self.testmode:
            # NOTE: `ratio`, `dw` and `dh` are the letterbox scale/padding
            # values from the original TransVG test pipeline; they are not
            # defined anywhere in this file, so this branch assumes a
            # transform that provides them.
            return img, np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
                np.array(bbox, dtype=np.float32), np.array(ratio, dtype=np.float32), \
                np.array(dw, dtype=np.float32), np.array(dh, dtype=np.float32), self.images[idx][0]
        else:
            return img, np.array(img_mask), np.array(word_id, dtype=int), np.array(word_mask, dtype=int), \
                np.array(bbox, dtype=np.float32), info
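

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). Assumptions are
    # flagged inline: the paths and dataset name below are placeholders, the
    # transform must return a dict with 'img', 'box', 'text' and 'mask' keys
    # as __getitem__ expects, and `args` must carry the attributes read above
    # ('ablation', 'model_name', and optionally 'CATextPoolType',
    # 'CAsampleType', 'CAsampleNum', 'GNpath').
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(ablation='none', model_name='TransVG')
    # transform = ...  # e.g. this repo's transform factory for the 'val' split
    # dataset = TransVGDataset(args, data_root='./ln_data', split_root='data',
    #                          dataset='MS_CXR', split='val', transform=transform,
    #                          max_query_len=128)
    # loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)
    # img, img_mask, word_id, word_mask, bbox, info = dataset[0]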