-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsetup.py
56 lines (52 loc) · 2.59 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#
# Nathan Lay
# AI Resource at National Cancer Institute
# National Institutes of Health
# November 2020
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR(S) ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE AUTHOR(S) BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
# THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os

import torch.cuda
from setuptools import setup, Extension
from torch.utils import cpp_extension
# Build script for the hingetree_cpp PyTorch extension.
#
# The C++ sources below are always compiled. The CUDA kernels (and the
# WITH_CUDA=1 preprocessor define) are added only when a CUDA build is
# selected, in which case a CUDAExtension is built instead of a CppExtension.

# Sources shared by the CPU-only and the CUDA builds.
sourceFiles = [ 'hingetree.cpp', 'hingetree_sparse.cpp', 'hingetrie.cpp', 'ImageToMatrix.cpp', 'hingetree_conv.cpp', 'hingetree_fused_linear.cpp', 'hingetree_fusion.cpp', 'expand.cpp', 'Timer.cpp' ]

# CUDA kernels; only compiled when useCuda is true.
cudaSourceFiles = [
    'hingetree_gpu.cu',
    'hingetree_sparse_gpu.cu',
    'ImageToMatrix_gpu.cu',
    'hingetree_conv_gpu.cu',
    'hingetree_fused_linear_gpu.cu',
    'hingetree_fusion_gpu.cu',
    'expand_gpu.cu',
]

extraCflags = [ '-O2' ]
extraCudaFlags = [ '-O2' ]

# torch.cuda.is_available() requires a visible GPU at *build* time, which is
# frequently missing on build hosts/containers that nonetheless have the CUDA
# toolkit installed. FORCE_CUDA=1 lets such environments request the CUDA
# build explicitly; when it is unset the original detection is used unchanged.
useCuda = torch.cuda.is_available() or os.environ.get('FORCE_CUDA', '0') == '1'

if useCuda:
    sourceFiles.extend(cudaSourceFiles)
    # Both the host compiler and nvcc must see the define so the C++ and CUDA
    # translation units agree on whether GPU code paths are compiled in.
    extraCflags.append('-DWITH_CUDA=1')
    extraCudaFlags.append('-DWITH_CUDA=1')
    extensionClass = cpp_extension.CUDAExtension
else:
    extensionClass = cpp_extension.CppExtension

# Single setup() call so package metadata cannot drift between the CPU and
# CUDA configurations; only the extension class and source list differ.
setup(name='hingetree_cpp',
      version='1.1.0',
      description='Port of random hinge forest for PyTorch.',
      author='Nathan Lay',
      author_email='[email protected]',
      url='https://github.com/nslay/HingeTreeForTorch/',
      packages=["HingeTree", "RandomHingeForest"],
      ext_modules=[extensionClass(name = 'hingetree_cpp', sources = sourceFiles, extra_compile_args = {'cxx': extraCflags, 'nvcc': extraCudaFlags})],
      cmdclass={'build_ext': cpp_extension.BuildExtension})