mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Include cuda_bf16.h for bfloat16 overloads (#2192)
* Include cuda_bf16.h for bfloat16 overloads * Add NO_GPU_MULTI(Eig) in cuda backend
This commit is contained in:
parent
7ff5c41e06
commit
7d4b378952
@ -2,44 +2,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cuda/std/type_traits>
|
||||
|
||||
namespace mlx::core::cu {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Missing C++ operator overrides for CUDA 7.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
#define MLX_DEFINE_BF16_OP(OP) \
|
||||
__forceinline__ __device__ __nv_bfloat16 operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
#define MLX_DEFINE_BF16_CMP(OP) \
|
||||
__forceinline__ __device__ bool operator OP( \
|
||||
__nv_bfloat16 x, __nv_bfloat16 y) { \
|
||||
return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \
|
||||
}
|
||||
|
||||
MLX_DEFINE_BF16_OP(+)
|
||||
MLX_DEFINE_BF16_OP(-)
|
||||
MLX_DEFINE_BF16_OP(*)
|
||||
MLX_DEFINE_BF16_OP(/)
|
||||
MLX_DEFINE_BF16_CMP(>)
|
||||
MLX_DEFINE_BF16_CMP(<)
|
||||
MLX_DEFINE_BF16_CMP(>=)
|
||||
MLX_DEFINE_BF16_CMP(<=)
|
||||
|
||||
#undef MLX_DEFINE_BF16_OP
|
||||
#undef MLX_DEFINE_BF16_CMP
|
||||
|
||||
#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Additional C++ operator overrides between half types and native types.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -140,6 +140,7 @@ NO_GPU(Tan)
|
||||
NO_GPU(Tanh)
|
||||
NO_GPU(Inverse)
|
||||
NO_GPU(Cholesky)
|
||||
NO_GPU_MULTI(Eig)
|
||||
NO_GPU_MULTI(Eigh)
|
||||
|
||||
namespace fast {
|
||||
|
Loading…
Reference in New Issue
Block a user