Spaces:
Runtime error
Runtime error
import os | |
from setuptools import setup | |
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension | |
from os.path import join | |
CPU_ONLY = False | |
project_root = 'Correlation_Module' | |
source_files = ['correlation.cpp', 'correlation_sampler.cpp'] | |
cxx_args = ['-std=c++17', '-fopenmp'] | |
def generate_nvcc_args(gpu_archs): | |
nvcc_args = [] | |
for arch in gpu_archs: | |
nvcc_args.extend(['-gencode', f'arch=compute_{arch},code=sm_{arch}']) | |
return nvcc_args | |
gpu_arch = os.environ.get('GPU_ARCH', '').split() | |
nvcc_args = generate_nvcc_args(gpu_arch) | |
with open("README.md", "r") as fh: | |
long_description = fh.read() | |
def launch_setup(): | |
if CPU_ONLY: | |
Extension = CppExtension | |
macro = [] | |
else: | |
Extension = CUDAExtension | |
source_files.append('correlation_cuda_kernel.cu') | |
macro = [("USE_CUDA", None)] | |
sources = [join(project_root, file) for file in source_files] | |
setup( | |
name='spatial_correlation_sampler', | |
version="0.4.0", | |
author="Clément Pinard", | |
author_email="[email protected]", | |
description="Correlation module for pytorch", | |
long_description=long_description, | |
long_description_content_type="text/markdown", | |
url="https://github.com/ClementPinard/Pytorch-Correlation-extension", | |
install_requires=['torch>=1.1', 'numpy'], | |
ext_modules=[ | |
Extension('spatial_correlation_sampler_backend', | |
sources, | |
define_macros=macro, | |
extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}, | |
extra_link_args=['-lgomp']) | |
], | |
package_dir={'': project_root}, | |
packages=['spatial_correlation_sampler'], | |
cmdclass={ | |
'build_ext': BuildExtension | |
}, | |
classifiers=[ | |
"Programming Language :: Python :: 3", | |
"License :: OSI Approved :: MIT License", | |
"Operating System :: POSIX :: Linux", | |
"Intended Audience :: Science/Research", | |
"Topic :: Scientific/Engineering :: Artificial Intelligence" | |
]) | |
if __name__ == '__main__': | |
launch_setup() | |