ROCm support to build py-torch (#31115)
* rocm support to build py-torch * rocm support to build py-torch
This commit is contained in:
parent
1ceee714db
commit
38a8d4d2fe
@ -158,6 +158,20 @@ class PyTorch(PythonPackage, CudaPackage):
|
||||
depends_on('numactl', when='+numa')
|
||||
depends_on('llvm-openmp', when='%apple-clang +openmp')
|
||||
depends_on('valgrind', when='+valgrind')
|
||||
with when("+rocm"):
|
||||
depends_on('hsa-rocr-dev')
|
||||
depends_on('hip')
|
||||
depends_on('rccl')
|
||||
depends_on('rocprim')
|
||||
depends_on('hipcub')
|
||||
depends_on('rocthrust')
|
||||
depends_on('roctracer-dev')
|
||||
depends_on('rocrand')
|
||||
depends_on('hipsparse')
|
||||
depends_on('hipfft')
|
||||
depends_on('rocfft')
|
||||
depends_on('rocblas')
|
||||
depends_on('miopen-hip')
|
||||
# https://github.com/pytorch/pytorch/issues/60332
|
||||
# depends_on('xnnpack@2021-02-22', when='@1.8:+xnnpack')
|
||||
# depends_on('xnnpack@2020-03-23', when='@1.6:1.7+xnnpack')
|
||||
@ -332,6 +346,22 @@ def enable_or_disable(variant, keyword='USE', var=None, newer=False):
|
||||
env.set('CMAKE_CUDA_FLAGS', '=-Xcompiler={0}'.format(flag))
|
||||
|
||||
enable_or_disable('rocm')
|
||||
if '+rocm' in self.spec:
|
||||
env.set('HSA_PATH', self.spec['hsa-rocr-dev'].prefix)
|
||||
env.set('ROCBLAS_PATH', self.spec['rocblas'].prefix)
|
||||
env.set('ROCFFT_PATH', self.spec['rocfft'].prefix)
|
||||
env.set('HIPFFT_PATH', self.spec['hipfft'].prefix)
|
||||
env.set('HIPSPARSE_PATH', self.spec['hipsparse'].prefix)
|
||||
env.set('THRUST_PATH', self.spec['rocthrust'].prefix.include)
|
||||
env.set('HIP_PATH', self.spec['hip'].prefix)
|
||||
env.set('HIPRAND_PATH', self.spec['rocrand'].prefix)
|
||||
env.set('ROCRAND_PATH', self.spec['rocrand'].prefix)
|
||||
env.set('MIOPEN_PATH', self.spec['miopen-hip'].prefix)
|
||||
env.set('RCCL_PATH', self.spec['rccl'].prefix)
|
||||
env.set('ROCPRIM_PATH', self.spec['rocprim'].prefix)
|
||||
env.set('HIPCUB_PATH', self.spec['hipcub'].prefix)
|
||||
env.set('ROCTHRUST_PATH', self.spec['rocthrust'].prefix)
|
||||
env.set('ROCTRACER_PATH', self.spec['roctracer-dev'].prefix)
|
||||
|
||||
enable_or_disable('cudnn')
|
||||
if '+cudnn' in self.spec:
|
||||
|
Loading…
Reference in New Issue
Block a user