File size: 2,439 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
    """Dataset for training MiniGPT4.

    MiniGPT4 dataset directory:

        minigpt4_dataset
            ├── image
            │   ├── id0.jpg
            │   ├── id1.jpg
            │   ├── id2.jpg
            │   └── ...
            └── conversation_data.json

    The structure of conversation_data.json:

        [
            // English data
            {
                "id": str(id0),
                "conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
                                ###Answer: [Answer content]"
            },

            // Chinese data
            {
                "id": str(id1),
                "conversation": "###问：<Img><ImageHere></Img> [Ask content]
                                ###答：[Answer content]"
            },

            ...
        ]

    Each entry is loaded as a dict with the keys:

    - ``image_id``: an integer index assigned to each distinct ``id`` in
      order of first appearance (entries sharing an ``id`` share it).
    - ``img_path``: ``<data_root>/image/<id>.jpg``.
    - ``chat_content``: the raw ``conversation`` string.
    - ``lang``: ``'en'`` if the conversation starts with ``'###Ask: '``,
      otherwise ``'zh'``.

    Args:
        data_root (str): The root directory for ``ann_file`` and ``image``.
        ann_file (str): Conversation file path. A relative path is resolved
            against ``data_root`` by :class:`BaseDataset`.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Load the conversation annotation file into a list of data infos.

        Returns:
            List[dict]: One data-info dict per conversation entry, in the
            order they appear in ``ann_file``.
        """
        # ``BaseDataset.__init__`` has already joined a relative ``ann_file``
        # with ``data_root``. Joining again here would duplicate the root
        # for relative roots and build an invalid URI for remote backends,
        # so read the already-resolved path directly (as BaseDataset does).
        conversation = mmengine.load(self.ann_file)

        file_backend = get_file_backend(self.data_root)
        img_root = file_backend.join_path(self.data_root, 'image')

        # Maps each distinct sample id to a dense index, assigned in order
        # of first appearance; repeated ids reuse their existing index.
        img_ids = {}
        data_list = []
        for conv in conversation:
            conv_id = conv['id']
            image_id = img_ids.setdefault(conv_id, len(img_ids))
            img_file = f'{conv_id}.jpg'
            chat_content = conv['conversation']
            # English prompts start with '###Ask: '; anything else is
            # treated as Chinese.
            lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
            data_list.append({
                'image_id': image_id,
                'img_path': file_backend.join_path(img_root, img_file),
                'chat_content': chat_content,
                'lang': lang,
            })

        return data_list