Include cuda_bf16.h for bfloat16 overloads (#2192)

* Include cuda_bf16.h for bfloat16 overloads

* Add NO_GPU_MULTI(Eig) in cuda backend
This commit is contained in:
Cheng 2025-05-16 22:44:42 +09:00 committed by GitHub
parent 7ff5c41e06
commit 7d4b378952
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 32 deletions

View File

@ -2,44 +2,13 @@
#pragma once #pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda/std/limits> #include <cuda/std/limits>
#include <cuda/std/type_traits> #include <cuda/std/type_traits>
namespace mlx::core::cu { namespace mlx::core::cu {
///////////////////////////////////////////////////////////////////////////////
// Missing C++ operator overrides for CUDA 7.
///////////////////////////////////////////////////////////////////////////////
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
#define MLX_DEFINE_BF16_OP(OP) \
__forceinline__ __device__ __nv_bfloat16 operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
#define MLX_DEFINE_BF16_CMP(OP) \
__forceinline__ __device__ bool operator OP( \
__nv_bfloat16 x, __nv_bfloat16 y) { \
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
}
MLX_DEFINE_BF16_OP(+)
MLX_DEFINE_BF16_OP(-)
MLX_DEFINE_BF16_OP(*)
MLX_DEFINE_BF16_OP(/)
MLX_DEFINE_BF16_CMP(>)
MLX_DEFINE_BF16_CMP(<)
MLX_DEFINE_BF16_CMP(>=)
MLX_DEFINE_BF16_CMP(<=)
#undef MLX_DEFINE_BF16_OP
#undef MLX_DEFINE_BF16_CMP
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Additional C++ operator overrides between half types and native types. // Additional C++ operator overrides between half types and native types.
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////

View File

@ -140,6 +140,7 @@ NO_GPU(Tan)
NO_GPU(Tanh) NO_GPU(Tanh)
NO_GPU(Inverse) NO_GPU(Inverse)
NO_GPU(Cholesky) NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh) NO_GPU_MULTI(Eigh)
namespace fast { namespace fast {