wufan's picture
Upload 111 files
18e4b60 verified
raw
history blame
3.55 kB
import logging
from unimernet.common.registry import registry
from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
from unimernet.datasets.datasets.formula import Im2LatexDataset
from unimernet.datasets.datasets.formula_multi_scale import MultiScaleIm2LatexDataset
@registry.register_builder("formula_rec_train")
class FormulaRecTrainBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``formula_rec_train``.

    Builds a single ``train`` split backed by ``Im2LatexDataset`` from the
    ``annotation`` and ``images`` entries of ``self.config.build_info``.
    """

    # Dataset class instantiated for the "train" split.
    train_dataset_cls = Im2LatexDataset

    # Default YAML config path for this builder.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/formula_train.yaml"
    }
    # Label interpolated into the log line emitted by build_datasets().
    LOG_INFO = "Formula Recgnition Train"

    def build_datasets(self):
        """Build the processors and the training dataset.

        Reads ``annotation`` and ``images`` from ``self.config.build_info``.
        Either value may be a single string or an already-built sequence;
        a single string is wrapped in a one-element list before being
        handed to the dataset class.

        Returns:
            dict: ``{"train": <dataset instance>}``.
        """
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        # NOTE: the original line ended with a stray comma, which wrapped the
        # value in a 1-tuple and made the str check below unreachable; a list
        # of annotation files was therefore passed on as ``([...],)``.
        anno_path = build_info.annotation
        vis_root = build_info.images

        # Normalise a single path to a one-element list; leave sequences as-is.
        if isinstance(anno_path, str):
            anno_path = [anno_path]
        if isinstance(vis_root, str):
            vis_root = [vis_root]

        # create datasets
        datasets = {
            "train": self.train_dataset_cls(
                vis_processor=self.vis_processors["train"],
                text_processor=self.text_processors["train"],
                vis_root=vis_root,
                anno_path=anno_path,
            )
        }
        # Indexing sample 0 forces one item to be loaded at build time.
        print(datasets["train"][0])

        return datasets
@registry.register_builder("multi_scale_formula_rec_train")
class MultiScaleFormulaRecTrainBuilder(BaseDatasetBuilder):
    """Dataset builder registered as ``multi_scale_formula_rec_train``.

    Builds a single ``train`` split backed by ``MultiScaleIm2LatexDataset``
    from the ``annotation`` and ``images`` entries of
    ``self.config.build_info``.
    """

    # Dataset class instantiated for the "train" split.
    train_dataset_cls = MultiScaleIm2LatexDataset

    # Default YAML config path for this builder.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/multi_scale_formula_train.yaml"
    }
    # Label interpolated into the log line emitted by build_datasets().
    LOG_INFO = "Multi Scale Formula Recgnition Train"

    def build_datasets(self):
        """Build the processors and the multi-scale training dataset.

        Reads ``annotation`` and ``images`` from ``self.config.build_info``.
        Either value may be a single string or an already-built sequence;
        a single string is wrapped in a one-element list before being
        handed to the dataset class.

        Returns:
            dict: ``{"train": <dataset instance>}``.
        """
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        # NOTE: the original line ended with a stray comma, which wrapped the
        # value in a 1-tuple and made the str check below unreachable; a list
        # of annotation files was therefore passed on as ``([...],)``.
        anno_path = build_info.annotation
        vis_root = build_info.images

        # Normalise a single path to a one-element list; leave sequences as-is.
        if isinstance(anno_path, str):
            anno_path = [anno_path]
        if isinstance(vis_root, str):
            vis_root = [vis_root]

        # create datasets
        datasets = {
            "train": self.train_dataset_cls(
                vis_processor=self.vis_processors["train"],
                text_processor=self.text_processors["train"],
                vis_root=vis_root,
                anno_path=anno_path,
            )
        }
        # Indexing sample 0 forces one item to be loaded at build time.
        print(datasets["train"][0])

        return datasets
@registry.register_builder("formula_rec_eval")
class FormulaRecEvalBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``formula_rec_eval``.

    Builds a single ``eval`` split backed by ``Im2LatexDataset`` from the
    ``annotation`` and ``images`` entries of ``self.config.build_info``.
    """

    # Dataset class instantiated for the "eval" split.
    eval_dataset_cls = Im2LatexDataset

    # Default YAML config path for this builder.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/formula_eval.yaml"
    }
    # Label interpolated into the log line emitted by build_datasets().
    LOG_INFO = "Formula Recgnition Eval"

    def build_datasets(self):
        """Build the processors and the evaluation dataset.

        Reads ``annotation`` and ``images`` from ``self.config.build_info``.
        Either value may be a single string or an already-built sequence;
        a single string is wrapped in a one-element list before being
        handed to the dataset class.

        Returns:
            dict: ``{"eval": <dataset instance>}``.
        """
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        # NOTE: the original line ended with a stray comma, which wrapped the
        # value in a 1-tuple and made the str check below unreachable; a list
        # of annotation files was therefore passed on as ``([...],)``.
        anno_path = build_info.annotation
        vis_root = build_info.images

        # Normalise a single path to a one-element list; leave sequences as-is.
        if isinstance(anno_path, str):
            anno_path = [anno_path]
        if isinstance(vis_root, str):
            vis_root = [vis_root]

        # create datasets
        datasets = {
            "eval": self.eval_dataset_cls(
                vis_processor=self.vis_processors["eval"],
                text_processor=self.text_processors["eval"],
                vis_root=vis_root,
                anno_path=anno_path,
            )
        }
        # Indexing sample 0 forces one item to be loaded at build time.
        print(datasets["eval"][0])

        return datasets