Spaces:
Runtime error
Runtime error
import logging | |
from unimernet.common.registry import registry | |
from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder | |
from unimernet.datasets.datasets.formula import Im2LatexDataset | |
from unimernet.datasets.datasets.formula_multi_scale import MultiScaleIm2LatexDataset | |
class FormulaRecTrainBuilder(BaseDatasetBuilder):
    """Builder that assembles the ``train`` split with ``Im2LatexDataset``."""

    # Dataset class instantiated for the "train" split.
    train_dataset_cls = Im2LatexDataset

    # Config file path keyed by variant name; consumed outside this class.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/formula_train.yaml"
    }
    # Label interpolated into the build log message below.
    LOG_INFO = "Formula Recgnition Train"

    def build_datasets(self):
        """Create the training dataset described by ``self.config.build_info``.

        ``build_info.annotation`` and ``build_info.images`` are each wrapped in
        a one-element list when they are a plain string; any other value is
        passed through unchanged.

        Returns:
            dict: a mapping with the single key ``"train"`` holding the
            constructed dataset instance.
        """
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        # No trailing comma here: ``x = value,`` builds a 1-tuple, which would
        # make the ``str`` check below dead code and would nest a list-valued
        # annotation entry inside an extra tuple.
        anno_path = build_info.annotation
        vis_root = build_info.images

        # Normalize a single string to a one-element list.
        if isinstance(anno_path, str):
            anno_path = [anno_path]
        if isinstance(vis_root, str):
            vis_root = [vis_root]

        datasets = {}

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            vis_root=vis_root,
            anno_path=anno_path,
        )
        # Writes the first sample to stdout; this indexes the dataset once at
        # build time.
        print(datasets['train'][0])

        return datasets
class MultiScaleFormulaRecTrainBuilder(BaseDatasetBuilder):
    """Builder that assembles the ``train`` split with ``MultiScaleIm2LatexDataset``."""

    # Dataset class instantiated for the "train" split.
    train_dataset_cls = MultiScaleIm2LatexDataset

    # Config file path keyed by variant name; consumed outside this class.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/multi_scale_formula_train.yaml"
    }
    # Label interpolated into the build log message below.
    LOG_INFO = "Multi Scale Formula Recgnition Train"

    def build_datasets(self):
        """Create the multi-scale training dataset from ``self.config.build_info``.

        ``build_info.annotation`` and ``build_info.images`` are each wrapped in
        a one-element list when they are a plain string; any other value is
        passed through unchanged.

        Returns:
            dict: a mapping with the single key ``"train"`` holding the
            constructed dataset instance.
        """
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        # No trailing comma here: ``x = value,`` builds a 1-tuple, which would
        # make the ``str`` check below dead code and would nest a list-valued
        # annotation entry inside an extra tuple.
        anno_path = build_info.annotation
        vis_root = build_info.images

        # Normalize a single string to a one-element list.
        if isinstance(anno_path, str):
            anno_path = [anno_path]
        if isinstance(vis_root, str):
            vis_root = [vis_root]

        datasets = {}

        # create datasets
        dataset_cls = self.train_dataset_cls
        datasets['train'] = dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            vis_root=vis_root,
            anno_path=anno_path,
        )
        # Writes the first sample to stdout; this indexes the dataset once at
        # build time.
        print(datasets['train'][0])

        return datasets
class FormulaRecEvalBuilder(BaseDatasetBuilder):
    """Builder that assembles the ``eval`` split with ``Im2LatexDataset``."""

    # Dataset class instantiated for the "eval" split.
    eval_dataset_cls = Im2LatexDataset

    # Config file path keyed by variant name; consumed outside this class.
    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/formula/formula_eval.yaml"
    }
    # Label interpolated into the build log message below.
    LOG_INFO = "Formula Recgnition Eval"

    def build_datasets(self):
        """Create the evaluation dataset described by ``self.config.build_info``.

        ``build_info.annotation`` and ``build_info.images`` are each wrapped in
        a one-element list when they are a plain string; any other value is
        passed through unchanged.

        Returns:
            dict: a mapping with the single key ``"eval"`` holding the
            constructed dataset instance.
        """
        logging.info(f"Building {self.LOG_INFO} datasets ...")
        self.build_processors()

        build_info = self.config.build_info
        # No trailing comma here: ``x = value,`` builds a 1-tuple, which would
        # make the ``str`` check below dead code and would nest a list-valued
        # annotation entry inside an extra tuple.
        anno_path = build_info.annotation
        vis_root = build_info.images

        # Normalize a single string to a one-element list.
        if isinstance(anno_path, str):
            anno_path = [anno_path]
        if isinstance(vis_root, str):
            vis_root = [vis_root]

        datasets = {}

        # create datasets
        dataset_cls = self.eval_dataset_cls
        datasets['eval'] = dataset_cls(
            vis_processor=self.vis_processors["eval"],
            text_processor=self.text_processors["eval"],
            vis_root=vis_root,
            anno_path=anno_path,
        )
        # Writes the first sample to stdout; this indexes the dataset once at
        # build time.
        print(datasets['eval'][0])

        return datasets