From 7d4b378952489b5c19b8d3ca5c028bf46a6ae86c Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 16 May 2025 22:44:42 +0900 Subject: [PATCH] Include cuda_bf16.h for bfloat16 overloads (#2192) * Include cuda_bf16.h for bfloat16 overloads * Add NO_GPU_MULTI(Eig) in cuda backend --- mlx/backend/cuda/kernels/fp16_math.cuh | 33 +------------------------- mlx/backend/cuda/primitives.cu | 1 + 2 files changed, 2 insertions(+), 32 deletions(-) diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index 931c55ff7..edbd953de 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -2,44 +2,13 @@ #pragma once +#include <cuda_bf16.h> #include #include #include namespace mlx::core::cu { -/////////////////////////////////////////////////////////////////////////////// -// Missing C++ operator overrides for CUDA 7. -/////////////////////////////////////////////////////////////////////////////// - -#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - -#define MLX_DEFINE_BF16_OP(OP) \ - __forceinline__ __device__ __nv_bfloat16 operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -#define MLX_DEFINE_BF16_CMP(OP) \ - __forceinline__ __device__ bool operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -MLX_DEFINE_BF16_OP(+) -MLX_DEFINE_BF16_OP(-) -MLX_DEFINE_BF16_OP(*) -MLX_DEFINE_BF16_OP(/) -MLX_DEFINE_BF16_CMP(>) -MLX_DEFINE_BF16_CMP(<) -MLX_DEFINE_BF16_CMP(>=) -MLX_DEFINE_BF16_CMP(<=) - -#undef MLX_DEFINE_BF16_OP -#undef MLX_DEFINE_BF16_CMP - -#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. 
/////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index dc6edf606..defdc746a 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -140,6 +140,7 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Inverse) NO_GPU(Cholesky) +NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) namespace fast {