File size: 4,074 Bytes
24be7a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import os.path
import random

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


class DeepFashionAttrPoseDataset(data.Dataset):
    """Dataset yielding a DensePose map plus texture and shape attributes.

    Four whitespace-separated annotation files are read at construction
    time. Every line starts with an image file name followed by integer
    labels:

    * ``<texture_ann_dir>/upper_fused.txt`` -- one label per image; this
      file also defines the image list and its order.
    * ``<texture_ann_dir>/lower_fused.txt`` -- one label per image.
    * ``<texture_ann_dir>/outer_fused.txt`` -- one label per image.
    * ``shape_ann_path`` -- one or more labels per image.

    All files must list the same images in the same order.

    Args:
        pose_dir: Directory holding the ``<name>_densepose.png`` files.
        texture_ann_dir: Directory holding the three ``*_fused.txt`` files.
        shape_ann_path: Path to the shape annotation file.
        downsample_factor: Integer divisor applied to the width and height
            of each DensePose image; ``1`` disables resizing.
        xflip: If True, each sample is mirrored horizontally with
            probability 0.5.

    Raises:
        AssertionError: If an annotation file is missing, if the files do
            not list the same images in the same order, or if their row
            counts differ.
    """

    def __init__(self,
                 pose_dir,
                 texture_ann_dir,
                 shape_ann_path,
                 downsample_factor=2,
                 xflip=False):
        super().__init__()
        self._densepose_path = pose_dir
        self._image_fnames_target = []
        self._image_fnames = []
        self.upper_fused_attrs = []
        self.lower_fused_attrs = []
        self.outer_fused_attrs = []
        self.shape_attrs = []

        self.downsample_factor = downsample_factor
        self.xflip = xflip

        # The upper-body file is the reference: it defines which images
        # exist and in which order; all other files are checked against it.
        upper_path = os.path.join(texture_ann_dir, 'upper_fused.txt')
        for annotations in self._read_annotation_rows(upper_path):
            self._image_fnames_target.append(annotations[0])
            # Pose files are looked up by the part of the name before the
            # first '.', with a '.png' extension.
            self._image_fnames.append(f'{annotations[0].split(".")[0]}.png')
            self.upper_fused_attrs.append(int(annotations[1]))

        assert len(self._image_fnames_target) == len(self.upper_fused_attrs)

        self.lower_fused_attrs = self._load_fused_attrs(
            os.path.join(texture_ann_dir, 'lower_fused.txt'))
        self.outer_fused_attrs = self._load_fused_attrs(
            os.path.join(texture_ann_dir, 'outer_fused.txt'))

        for idx, annotations in enumerate(
                self._read_annotation_rows(shape_ann_path)):
            assert self._image_fnames_target[idx] == annotations[0], (
                f'{shape_ann_path}: row {idx} is {annotations[0]!r}, '
                f'expected {self._image_fnames_target[idx]!r}')
            self.shape_attrs.append([int(i) for i in annotations[1:]])

        # Same row-count check the three texture files get; without it a
        # short shape file only fails later, inside __getitem__.
        assert len(self._image_fnames_target) == len(self.shape_attrs), (
            f'{shape_ann_path}: expected {len(self._image_fnames_target)} '
            f'rows, got {len(self.shape_attrs)}')

    @staticmethod
    def _read_annotation_rows(path):
        """Return the whitespace-split fields of every line in ``path``.

        The file is opened in a ``with`` block so the handle is closed as
        soon as the rows have been read.
        """
        assert os.path.exists(path), f'annotation file not found: {path}'
        with open(path, 'r') as f:
            return [row.split() for row in f]

    def _load_fused_attrs(self, path):
        """Read one integer label per image from ``path``.

        Each row's file name must match the reference image list at the
        same position, and the row count must equal the number of images.
        """
        attrs = []
        for idx, annotations in enumerate(self._read_annotation_rows(path)):
            assert self._image_fnames_target[idx] == annotations[0], (
                f'{path}: row {idx} is {annotations[0]!r}, '
                f'expected {self._image_fnames_target[idx]!r}')
            attrs.append(int(annotations[1]))

        assert len(self._image_fnames_target) == len(attrs), (
            f'{path}: expected {len(self._image_fnames_target)} rows, '
            f'got {len(attrs)}')
        return attrs

    def _open_file(self, path_prefix, fname):
        """Open ``fname`` under ``path_prefix`` for binary reading."""
        return open(os.path.join(path_prefix, fname), 'rb')

    def _load_densepose(self, raw_idx):
        """Load the DensePose map of sample ``raw_idx`` as a float32 array.

        The image is optionally downsampled with nearest-neighbour
        resampling, then only the channels from index 2 onward are kept and
        moved to the front, giving a ``[C, H, W]`` array.
        """
        fname = self._image_fnames[raw_idx]
        # Replace the trailing '.png' with the '_densepose.png' suffix.
        fname = f'{fname[:-4]}_densepose.png'
        with self._open_file(self._densepose_path, fname) as f:
            densepose = Image.open(f)
            if self.downsample_factor != 1:
                width, height = densepose.size
                width = width // self.downsample_factor
                height = height // self.downsample_factor
                # Nearest-neighbour resampling never blends neighbouring
                # pixel values.
                densepose = densepose.resize(
                    size=(width, height), resample=Image.NEAREST)
            # np.array() forces the pixel data to be read while the file is
            # still open. The slice keeps channels 2.. only, so a
            # 3-channel image yields [1, H, W].
            # NOTE(review): presumably the kept channel is the DensePose
            # part index -- confirm against the pose files.
            densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
        return densepose.astype(np.float32)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys:
            ``densepose``: float32 array ``[C, H, W]``, computed as
                ``raw / 12 - 1``.
            ``img_name``: image name as written in the annotation files.
            ``shape_attr``: ``torch.LongTensor`` of shape labels.
            ``upper_fused_attr`` / ``lower_fused_attr`` /
            ``outer_fused_attr``: integer texture labels.
        """
        pose = self._load_densepose(index)
        shape_attr = self.shape_attrs[index]
        shape_attr = torch.LongTensor(shape_attr)

        if self.xflip and random.random() > 0.5:
            # Mirror along the width axis; copy() makes the reversed view
            # contiguous again.
            pose = pose[:, :, ::-1].copy()

        upper_fused_attr = self.upper_fused_attrs[index]
        lower_fused_attr = self.lower_fused_attrs[index]
        outer_fused_attr = self.outer_fused_attrs[index]

        # Maps raw values in [0, 24] onto [-1, 1].
        pose = pose / 12. - 1

        return_dict = {
            'densepose': pose,
            'img_name': self._image_fnames_target[index],
            'shape_attr': shape_attr,
            'upper_fused_attr': upper_fused_attr,
            'lower_fused_attr': lower_fused_attr,
            'outer_fused_attr': outer_fused_attr,
        }

        return return_dict

    def __len__(self):
        """Return the number of images listed in ``upper_fused.txt``."""
        return len(self._image_fnames)