Spaces:
Runtime error
Runtime error
File size: 2,439 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class MiniGPT4Dataset(BaseDataset):
    """Dataset for training MiniGPT4.

    MiniGPT4 dataset directory::

        minigpt4_dataset
        ├── image
        │   ├── id0.jpg
        │   ├── id1.jpg
        │   ├── id2.jpg
        │   └── ...
        └── conversation_data.json

    The structure of conversation_data.json::

        [
            // English data
            {
                "id": str(id0),
                "conversation": "###Ask: <Img><ImageHere></Img> [Ask content]
                                 ###Answer: [Answer content]"
            },
            // Chinese data
            {
                "id": str(id1),
                "conversation": "###问：<Img><ImageHere></Img> [Ask content]
                                 ###答：[Answer content]"
            },
            ...
        ]

    Args:
        data_root (str): The root directory for ``ann_file`` and ``image``.
        ann_file (str): Conversation file path, relative to ``data_root``
            unless it is absolute.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def load_data_list(self) -> List[dict]:
        """Load the conversation file and build one data info per entry.

        Returns:
            List[dict]: One dict per conversation entry, in file order,
            with keys:

            - ``image_id`` (int): dense index of the entry's ``id``,
              assigned in order of first appearance; entries sharing an
              ``id`` share an ``image_id``.
            - ``img_path`` (str): ``<data_root>/image/<id>.jpg``.
            - ``chat_content`` (str): the raw ``conversation`` string.
            - ``lang`` (str): ``'en'`` if the conversation starts with
              ``'###Ask: '``, otherwise ``'zh'``.
        """
        file_backend = get_file_backend(self.data_root)
        # ``BaseDataset.__init__`` has already joined a relative
        # ``ann_file`` with ``data_root``; joining it again here would
        # duplicate the prefix whenever ``data_root`` is a relative path.
        conversation = mmengine.load(self.ann_file)

        img_root = file_backend.join_path(self.data_root, 'image')
        img_ids = {}
        data_list = []
        for conv in conversation:
            conv_id = conv['id']
            # First occurrence of an id takes the next free index; later
            # entries with the same id reuse it.
            image_id = img_ids.setdefault(conv_id, len(img_ids))
            img_file = '{}.jpg'.format(conv_id)
            chat_content = conv['conversation']
            lang = 'en' if chat_content.startswith('###Ask: ') else 'zh'
            data_list.append({
                'image_id': image_id,
                'img_path': file_backend.join_path(img_root, img_file),
                'chat_content': chat_content,
                'lang': lang,
            })
        return data_list
|