From 5e71f2f3efb4859cb9fd74f075cbb20af497b694 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 11 May 2025 13:07:15 +0000 Subject: [PATCH] Include cuda_bf16.h for bfloat16 overloads --- mlx/backend/cuda/kernels/fp16_math.cuh | 33 +------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/mlx/backend/cuda/kernels/fp16_math.cuh b/mlx/backend/cuda/kernels/fp16_math.cuh index 931c55ff7c..edbd953de1 100644 --- a/mlx/backend/cuda/kernels/fp16_math.cuh +++ b/mlx/backend/cuda/kernels/fp16_math.cuh @@ -2,44 +2,13 @@ #pragma once +#include #include #include #include namespace mlx::core::cu { -/////////////////////////////////////////////////////////////////////////////// -// Missing C++ operator overrides for CUDA 7. -/////////////////////////////////////////////////////////////////////////////// - -#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - -#define MLX_DEFINE_BF16_OP(OP) \ - __forceinline__ __device__ __nv_bfloat16 operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -#define MLX_DEFINE_BF16_CMP(OP) \ - __forceinline__ __device__ bool operator OP( \ - __nv_bfloat16 x, __nv_bfloat16 y) { \ - return __float2bfloat16(__bfloat162float(x) OP __bfloat162float(y)); \ - } - -MLX_DEFINE_BF16_OP(+) -MLX_DEFINE_BF16_OP(-) -MLX_DEFINE_BF16_OP(*) -MLX_DEFINE_BF16_OP(/) -MLX_DEFINE_BF16_CMP(>) -MLX_DEFINE_BF16_CMP(<) -MLX_DEFINE_BF16_CMP(>=) -MLX_DEFINE_BF16_CMP(<=) - -#undef MLX_DEFINE_BF16_OP -#undef MLX_DEFINE_BF16_CMP - -#endif // CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800 - /////////////////////////////////////////////////////////////////////////////// // Additional C++ operator overrides between half types and native types. ///////////////////////////////////////////////////////////////////////////////