# NeMo / tests/utils/test_utils.py
# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
import pytest
from nemo import __version__ as NEMO_VERSION
from nemo.utils.data_utils import (
ais_binary,
ais_endpoint_to_dir,
bucket_and_object_from_uri,
datastore_path_to_webdataset_url,
is_datastore_path,
resolve_cache_dir,
)
class TestDataUtils:
    """Unit tests for the helpers imported from ``nemo.utils.data_utils``."""

    @pytest.mark.unit
    def test_resolve_cache_dir(self):
        """Test resolving the cache dir from the override environment variable.

        An absolute path is returned unchanged, a relative path is resolved
        against the current working directory, and an empty value falls back
        to the versioned NeMo cache directory under the user's home.
        """
        TEST_NEMO_ENV_CACHE_DIR = 'TEST_NEMO_ENV_CACHE_DIR'

        envar_to_resolved_path = {
            '/path/to/cache': '/path/to/cache',
            'relative/path': os.path.join(os.getcwd(), 'relative/path'),
            '': os.path.expanduser(f'~/.cache/torch/NeMo/NeMo_{NEMO_VERSION}'),
        }

        # Redirect the env-var *name* that resolve_cache_dir reads to a test-only one.
        with mock.patch('nemo.constants.NEMO_ENV_CACHE_DIR', TEST_NEMO_ENV_CACHE_DIR):
            for envar, expected_path in envar_to_resolved_path.items():
                # patch.dict restores os.environ on exit, so the variable does not
                # leak into other tests (the original assigned os.environ directly).
                with mock.patch.dict(os.environ, {TEST_NEMO_ENV_CACHE_DIR: envar}):
                    uut_path = resolve_cache_dir().as_posix()
                assert uut_path == expected_path, f'Expected: {expected_path}, got {uut_path}'

    @pytest.mark.unit
    def test_is_datastore_path(self):
        """Test recognizing ``ais://`` datastore paths versus local/relative paths."""
        # Positive examples
        assert is_datastore_path('ais://positive/example')

        # Negative examples
        assert not is_datastore_path('ais/negative/example')
        assert not is_datastore_path('/negative/example')
        assert not is_datastore_path('negative/example')

    @pytest.mark.unit
    def test_bucket_and_object_from_uri(self):
        """Test splitting an ``ais://`` URI into (bucket, object); non-URIs raise ValueError."""
        # Positive examples
        assert bucket_and_object_from_uri('ais://bucket/object') == ('bucket', 'object')
        assert bucket_and_object_from_uri('ais://bucket_2/object/is/here') == ('bucket_2', 'object/is/here')

        # Negative examples: invalid URI
        with pytest.raises(ValueError):
            bucket_and_object_from_uri('/local/file')

        with pytest.raises(ValueError):
            bucket_and_object_from_uri('local/file')

    @pytest.mark.unit
    def test_ais_endpoint_to_dir(self):
        """Test converting an AIS endpoint URL into a ``host/port`` directory path."""
        assert ais_endpoint_to_dir('http://local:123') == os.path.join('local', '123')
        assert ais_endpoint_to_dir('http://1.2.3.4:567') == os.path.join('1.2.3.4', '567')

        # An endpoint without a scheme is rejected.
        with pytest.raises(ValueError):
            ais_endpoint_to_dir('local:123')

    @pytest.mark.unit
    def test_ais_binary(self):
        """Test locating the AIS CLI binary, and the error raised when it is missing."""
        with mock.patch('shutil.which', lambda x: '/test/path/ais'):
            assert ais_binary() == '/test/path/ais'

        # Negative example: AIS binary cannot be found
        with mock.patch('shutil.which', lambda x: None), mock.patch('os.path.isfile', lambda x: None):
            with pytest.raises(RuntimeError):
                ais_binary()

    @pytest.mark.unit
    def test_datastore_path_to_webdataset_url(self):
        """Test converting a datastore path into a ``pipe:`` URL for WebDataset."""
        assert datastore_path_to_webdataset_url('ais://test/path') == 'pipe:ais get ais://test/path - || true'