# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Union

from mmcv.transforms import BaseTransform
# A pipeline is a list whose entries are either transform config dicts or
# transform objects derived from ``BaseTransform``.
PIPELINE_TYPE = List[Union[dict, BaseTransform]]
def get_transform_idx(pipeline: PIPELINE_TYPE, target: str) -> int:
    """Return the index of the first transform in a pipeline matching a name.

    Each pipeline entry is either a dict, whose ``'type'`` value is a class
    or a class-name string, or a transform object. A dict entry matches when
    its ``'type'`` is a class named ``target`` or compares equal to
    ``target``; any other entry matches when its class is named ``target``.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.

    Returns:
        int: The index of the first matching transform. Returns -1 if not
        found.
    """
    for i, transform in enumerate(pipeline):
        if isinstance(transform, dict):
            # A dict without a 'type' key cannot name a transform, so it is
            # treated as a non-match rather than raising KeyError.
            transform_type = transform.get('type')
            if isinstance(transform_type, type):
                name = transform_type.__name__
            else:
                name = transform_type
        else:
            name = transform.__class__.__name__
        if name == target:
            return i
    return -1
def remove_transform(pipeline: PIPELINE_TYPE,
                     target: str,
                     inplace: bool = False) -> PIPELINE_TYPE:
    """Remove every transform of the target type from the pipeline.

    Args:
        pipeline (List[dict] | List[BaseTransform]): The transforms list.
        target (str): The target transform class name.
        inplace (bool): Whether to modify the pipeline inplace. When False,
            a deep copy is modified and the input list is left untouched.
            Defaults to False.

    Returns:
        The modified pipeline: the input list itself when ``inplace`` is
        True, otherwise the deep copy.
    """
    if not inplace:
        # Work on a deep copy so the caller's list and entries stay intact.
        pipeline = copy.deepcopy(pipeline)
    # get_transform_idx only reports the first match, so keep popping until
    # it signals "not found" with -1.
    idx = get_transform_idx(pipeline, target)
    while idx >= 0:
        pipeline.pop(idx)
        idx = get_transform_idx(pipeline, target)
    return pipeline