Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import gzip | |
import hashlib | |
import os | |
import os.path | |
import shutil | |
import tarfile | |
import tempfile | |
import urllib.error | |
import urllib.request | |
import zipfile | |
from mmengine.fileio import LocalBackend, get_file_backend | |
__all__ = [ | |
'rm_suffix', 'check_integrity', 'download_and_extract_archive', | |
'open_maybe_compressed_file' | |
] | |
def rm_suffix(s, suffix=None): | |
if suffix is None: | |
return s[:s.rfind('.')] | |
else: | |
return s[:s.rfind(suffix)] | |
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024): | |
md5 = hashlib.md5() | |
backend = get_file_backend(fpath, enable_singleton=True) | |
if isinstance(backend, LocalBackend): | |
# Enable chunk update for local file. | |
with open(fpath, 'rb') as f: | |
for chunk in iter(lambda: f.read(chunk_size), b''): | |
md5.update(chunk) | |
else: | |
md5.update(backend.get(fpath)) | |
return md5.hexdigest() | |
def check_md5(fpath, md5, **kwargs): | |
return md5 == calculate_md5(fpath, **kwargs) | |
def check_integrity(fpath, md5=None): | |
if not os.path.isfile(fpath): | |
return False | |
if md5 is None: | |
return True | |
return check_md5(fpath, md5) | |
def download_url_to_file(url, dst, hash_prefix=None, progress=True): | |
"""Download object at the given URL to a local path. | |
Modified from | |
https://pytorch.org/docs/stable/hub.html#torch.hub.download_url_to_file | |
Args: | |
url (str): URL of the object to download | |
dst (str): Full path where object will be saved, | |
e.g. ``/tmp/temporary_file`` | |
hash_prefix (string, optional): If not None, the SHA256 downloaded | |
file should start with ``hash_prefix``. Defaults to None. | |
progress (bool): whether or not to display a progress bar to stderr. | |
Defaults to True | |
""" | |
file_size = None | |
req = urllib.request.Request(url) | |
u = urllib.request.urlopen(req) | |
meta = u.info() | |
if hasattr(meta, 'getheaders'): | |
content_length = meta.getheaders('Content-Length') | |
else: | |
content_length = meta.get_all('Content-Length') | |
if content_length is not None and len(content_length) > 0: | |
file_size = int(content_length[0]) | |
# We deliberately save it in a temp file and move it after download is | |
# complete. This prevents a local file being overridden by a broken | |
# download. | |
dst = os.path.expanduser(dst) | |
dst_dir = os.path.dirname(dst) | |
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) | |
import rich.progress | |
columns = [ | |
rich.progress.DownloadColumn(), | |
rich.progress.BarColumn(bar_width=None), | |
rich.progress.TimeRemainingColumn(), | |
] | |
try: | |
if hash_prefix is not None: | |
sha256 = hashlib.sha256() | |
with rich.progress.Progress(*columns) as pbar: | |
task = pbar.add_task('download', total=file_size, visible=progress) | |
while True: | |
buffer = u.read(8192) | |
if len(buffer) == 0: | |
break | |
f.write(buffer) | |
if hash_prefix is not None: | |
sha256.update(buffer) | |
pbar.update(task, advance=len(buffer)) | |
f.close() | |
if hash_prefix is not None: | |
digest = sha256.hexdigest() | |
if digest[:len(hash_prefix)] != hash_prefix: | |
raise RuntimeError( | |
'invalid hash value (expected "{}", got "{}")'.format( | |
hash_prefix, digest)) | |
shutil.move(f.name, dst) | |
finally: | |
f.close() | |
if os.path.exists(f.name): | |
os.remove(f.name) | |
def download_url(url, root, filename=None, md5=None): | |
"""Download a file from a url and place it in root. | |
Args: | |
url (str): URL to download file from. | |
root (str): Directory to place downloaded file in. | |
filename (str | None): Name to save the file under. | |
If filename is None, use the basename of the URL. | |
md5 (str | None): MD5 checksum of the download. | |
If md5 is None, download without md5 check. | |
""" | |
root = os.path.expanduser(root) | |
if not filename: | |
filename = os.path.basename(url) | |
fpath = os.path.join(root, filename) | |
os.makedirs(root, exist_ok=True) | |
if check_integrity(fpath, md5): | |
print(f'Using downloaded and verified file: {fpath}') | |
else: | |
try: | |
print(f'Downloading {url} to {fpath}') | |
download_url_to_file(url, fpath) | |
except (urllib.error.URLError, IOError) as e: | |
if url[:5] == 'https': | |
url = url.replace('https:', 'http:') | |
print('Failed download. Trying https -> http instead.' | |
f' Downloading {url} to {fpath}') | |
download_url_to_file(url, fpath) | |
else: | |
raise e | |
# check integrity of downloaded file | |
if not check_integrity(fpath, md5): | |
raise RuntimeError('File not found or corrupted.') | |
def _is_tarxz(filename): | |
return filename.endswith('.tar.xz') | |
def _is_tar(filename): | |
return filename.endswith('.tar') | |
def _is_targz(filename): | |
return filename.endswith('.tar.gz') | |
def _is_tgz(filename): | |
return filename.endswith('.tgz') | |
def _is_gzip(filename): | |
return filename.endswith('.gz') and not filename.endswith('.tar.gz') | |
def _is_zip(filename): | |
return filename.endswith('.zip') | |
def extract_archive(from_path, to_path=None, remove_finished=False): | |
if to_path is None: | |
to_path = os.path.dirname(from_path) | |
if _is_tar(from_path): | |
with tarfile.open(from_path, 'r') as tar: | |
tar.extractall(path=to_path) | |
elif _is_targz(from_path) or _is_tgz(from_path): | |
with tarfile.open(from_path, 'r:gz') as tar: | |
tar.extractall(path=to_path) | |
elif _is_tarxz(from_path): | |
with tarfile.open(from_path, 'r:xz') as tar: | |
tar.extractall(path=to_path) | |
elif _is_gzip(from_path): | |
to_path = os.path.join( | |
to_path, | |
os.path.splitext(os.path.basename(from_path))[0]) | |
with open(to_path, 'wb') as out_f, gzip.GzipFile(from_path) as zip_f: | |
out_f.write(zip_f.read()) | |
elif _is_zip(from_path): | |
with zipfile.ZipFile(from_path, 'r') as z: | |
z.extractall(to_path) | |
else: | |
raise ValueError(f'Extraction of {from_path} not supported') | |
if remove_finished: | |
os.remove(from_path) | |
def download_and_extract_archive(url, | |
download_root, | |
extract_root=None, | |
filename=None, | |
md5=None, | |
remove_finished=False): | |
download_root = os.path.expanduser(download_root) | |
if extract_root is None: | |
extract_root = download_root | |
if not filename: | |
filename = os.path.basename(url) | |
download_url(url, download_root, filename, md5) | |
archive = os.path.join(download_root, filename) | |
print(f'Extracting {archive} to {extract_root}') | |
extract_archive(archive, extract_root, remove_finished) | |
def open_maybe_compressed_file(path: str): | |
"""Return a file object that possibly decompresses 'path' on the fly. | |
Decompression occurs when argument `path` is a string and ends with '.gz' | |
or '.xz'. | |
""" | |
if not isinstance(path, str): | |
return path | |
if path.endswith('.gz'): | |
import gzip | |
return gzip.open(path, 'rb') | |
if path.endswith('.xz'): | |
import lzma | |
return lzma.open(path, 'rb') | |
return open(path, 'rb') | |