diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index 931c55ff7..edbd953de 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -2,44 +2,13 @@ #pragma once +#include #include #include #include namespace mlx::core::cu { -/////////////////////////////////////////////////////////////////////////////// -// Missing C++ operator overrides for CUDA 7. -/////////////////////////////////////////////////////////////////////////////// - -#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - -#define MLX_DEFINE_BF16_OP(OP) \ - __forceinline__ __device__ __nv_bfloat16 operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -#define MLX_DEFINE_BF16_CMP(OP) \ - __forceinline__ __device__ bool operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -MLX_DEFINE_BF16_OP(+) -MLX_DEFINE_BF16_OP(-) -MLX_DEFINE_BF16_OP(*) -MLX_DEFINE_BF16_OP(/) -MLX_DEFINE_BF16_CMP(>) -MLX_DEFINE_BF16_CMP(<) -MLX_DEFINE_BF16_CMP(>=) -MLX_DEFINE_BF16_CMP(<=) - -#undef MLX_DEFINE_BF16_OP -#undef MLX_DEFINE_BF16_CMP - -#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index dc6edf606..defdc746a 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -140,6 +140,7 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Inverse) NO_GPU(Cholesky) +NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast {