diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
index ad942a406..67ef5d968 100644
--- a/mlx/backend/cuda/arg_reduce.cu
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/iterators/strided_iterator.cuh"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 638d68727..99ccfdb4a 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -264,19 +264,26 @@ void CommandEncoder::commit() {
   graph_key_ += std::to_string(graph_node_count_);
   graph_key_ += ".";
   graph_key_ += std::to_string(empty_node_count_);
 
-  auto [it, _] = graph_cache_.emplace(graph_key_, nullptr);
-  auto& graph_exec = it->second;
-  if (graph_exec != NULL) {
-    cudaGraphExecUpdateResultInfo update_result;
-    cudaGraphExecUpdate(graph_exec, graph_, &update_result);
-    if (update_result.result != cudaGraphExecUpdateSuccess) {
-      cudaGetLastError();
+  cudaGraphExec_t& graph_exec = graph_cache_[graph_key_];
+
+  if (graph_exec != nullptr) {
+    cudaGraphExecUpdateResult update_result;
+#if CUDART_VERSION >= 12000
+    cudaGraphExecUpdateResultInfo info;
+    cudaGraphExecUpdate(graph_exec, graph_, &info);
+    update_result = info.result;
+#else
+    cudaGraphNode_t error_node;
+    cudaGraphExecUpdate(graph_exec, graph_, &error_node, &update_result);
+#endif // CUDART_VERSION >= 12000
+    if (update_result != cudaGraphExecUpdateSuccess) {
+      cudaGetLastError(); // reset error
       CHECK_CUDA_ERROR(cudaGraphExecDestroy(graph_exec));
-      graph_exec = NULL;
+      graph_exec = nullptr;
     }
   }
-  if (graph_exec == NULL) {
+  if (graph_exec == nullptr) {
     CHECK_CUDA_ERROR(
         cudaGraphInstantiate(&graph_exec, graph_, NULL, NULL, 0));
   }
diff --git a/mlx/backend/cuda/device/cast_op.cuh b/mlx/backend/cuda/device/cast_op.cuh
index f15270432..8da19ddf8 100644
--- a/mlx/backend/cuda/device/cast_op.cuh
+++ b/mlx/backend/cuda/device/cast_op.cuh
@@ -3,6 +3,8 @@
 #pragma once
 
 #include <cuComplex.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
 #include <thrust/iterator/transform_iterator.h>
 
 namespace mlx::core::cu {
@@ -17,6 +19,26 @@ struct CastOp {
   }
 };
 
+// Castings between complex and boolean.
+// TODO: Should make a custom complex type.
+template <>
+struct CastOp<cuComplex, bool> {
+  static constexpr bool is_castable = true;
+
+  __device__ bool operator()(cuComplex x) {
+    return x.x != 0 && x.y != 0;
+  }
+};
+
+template <>
+struct CastOp<bool, cuComplex> {
+  static constexpr bool is_castable = true;
+
+  __device__ cuComplex operator()(bool x) {
+    return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
+  }
+};
+
 // Converting a complex number to real number discards the imaginary part.
 template <typename DstT>
 struct CastOp<
@@ -45,6 +67,7 @@ struct CastOp<
   }
 };
 
+// Do nothing when no casting is needed.
 template <typename SrcT, typename DstT>
 struct CastOp<
     SrcT,
     DstT,
@@ -57,9 +80,53 @@ struct CastOp<
   }
 };
 
+// In CUDA 11 the half types do not define conversions between some types,
+// provide fallbacks here.
+#if CUDART_VERSION < 12000
+template <typename SrcT, typename DstT>
+struct CastOp<
+    SrcT,
+    DstT,
+    cuda::std::enable_if_t<
+        !cuda::std::is_convertible_v<SrcT, DstT> &&
+        !cuda::std::is_same_v<SrcT, cuComplex> &&
+        (cuda::std::is_same_v<DstT, __half> ||
+         cuda::std::is_same_v<DstT, __nv_bfloat16>)>> {
+  static constexpr bool is_castable = true;
+
+  __device__ DstT operator()(SrcT x) {
+    return DstT(static_cast<float>(x));
+  }
+};
+
+template <typename SrcT, typename DstT>
+struct CastOp<
+    SrcT,
+    DstT,
+    cuda::std::enable_if_t<
+        !cuda::std::is_convertible_v<SrcT, DstT> &&
+        !cuda::std::is_same_v<DstT, cuComplex> &&
+        !cuda::std::is_same_v<DstT, __half> &&
+        !cuda::std::is_same_v<DstT, __nv_bfloat16> &&
+        (cuda::std::is_same_v<SrcT, __half> ||
+         cuda::std::is_same_v<SrcT, __nv_bfloat16>)>> {
+  static constexpr bool is_castable = true;
+
+  __device__ DstT operator()(SrcT x) {
+    return DstT(static_cast<float>(x));
+  }
+};
+#endif // CUDART_VERSION < 12000
+
+// Helper to deduce the SrcT.
+template <typename DstT, typename SrcT>
+inline __host__ __device__ auto cast_to(SrcT x) {
+  return CastOp<SrcT, DstT>{}(x);
+}
+
 // Return an iterator that cast the value to DstT using CastOp.
 template <typename DstT, typename Iterator>
-__host__ __device__ auto make_cast_iterator(Iterator it) {
+inline __host__ __device__ auto make_cast_iterator(Iterator it) {
   using SrcT = typename cuda::std::iterator_traits<Iterator>::value_type;
   if constexpr (std::is_same_v<SrcT, DstT>) {
     return it;
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 89b609c45..83e149165 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -99,20 +99,20 @@ struct Limits<
     return cuda::std::numeric_limits<T>::infinity();
   }
   static constexpr __host__ __device__ T min() {
-#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
-    return -cuda::std::numeric_limits<T>::infinity();
-#else
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
     return -cuda::std::numeric_limits<float>::infinity();
+#else
+    return -cuda::std::numeric_limits<T>::infinity();
 #endif
   }
   static constexpr __host__ __device__ T finite_max() {
     return cuda::std::numeric_limits<T>::max();
   }
   static constexpr __host__ __device__ T finite_min() {
-#if defined(__CUDA_ARCH__) || CUDART_VERSION >= 12000
-    return cuda::std::numeric_limits<T>::lowest();
-#else
+#if CUDART_VERSION < 12000 && __CUDA_ARCH__ < 800
     return cuda::std::numeric_limits<float>::lowest();
+#else
+    return cuda::std::numeric_limits<T>::lowest();
 #endif
   }
 };
diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu
index 3419d61cb..166a11a79 100644
--- a/mlx/backend/cuda/reduce/all_reduce.cu
+++ b/mlx/backend/cuda/reduce/all_reduce.cu
@@ -37,15 +37,15 @@ __global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
   for (; i + block.size() * N <= check; i += block.size() * N) {
     cub::LoadDirectBlockedVectorized(block.thread_rank(), in + i, vals);
     for (int j = 0; j < N; j++) {
-      accs[0] = op(accs[0], __cast<U>(vals[j]));
+      accs[0] = op(accs[0], cast_to<U>(vals[j]));
     }
   }
 
   if (i < check) {
     cub::LoadDirectBlocked(
-        block.thread_rank(), in + i, vals, check - i, __cast<U>(init));
+        block.thread_rank(), in + i, vals, check - i, cast_to<U>(init));
     for (int i = 0; i < N; i++) {
-      accs[0] = op(accs[0], __cast<U>(vals[i]));
+      accs[0] = op(accs[0], cast_to<U>(vals[i]));
     }
   }
 
diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu
index 910fa0379..fec5ca76b 100644
--- a/mlx/backend/cuda/reduce/col_reduce.cu
+++ b/mlx/backend/cuda/reduce/col_reduce.cu
@@ -3,7 +3,6 @@
 #include <numeric>
 
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
 
 #include <cooperative_groups.h>
@@ -128,7 +127,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
       T vals[N_READS];
       cub::LoadDirectBlockedVectorized(thread_x, in + loop.location(), vals);
       for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast<U>(vals[i]));
+        totals[i] = op(totals[i], cast_to<U>(vals[i]));
       }
       loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
@@ -137,7 +136,7 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
       T vals[N_READS];
       cub::LoadDirectBlocked(thread_x, in + loop.location(), vals);
       for (int i = 0; i < N_READS; i++) {
-        totals[i] = op(totals[i], __cast<U>(vals[i]));
+        totals[i] = op(totals[i], cast_to<U>(vals[i]));
       }
       loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
     }
@@ -150,9 +149,9 @@ col_reduce_looped(T* in, U* out, const __grid_constant__ ColReduceArgs args) {
         in + loop.location(),
         vals,
         args.reduction_stride - tile_x * BN,
-        __cast<U>(ReduceInit<Op, T>::value()));
+        cast_to<U>(ReduceInit<Op, T>::value()));
     for (int i = 0; i < N_READS; i++) {
-      totals[i] = op(totals[i], __cast<U>(vals[i]));
+      totals[i] = op(totals[i], cast_to<U>(vals[i]));
     }
     loop.next(BM, args.reduce_shape.data(), args.reduce_strides.data());
   }
diff --git a/mlx/backend/cuda/reduce/reduce_ops.cuh b/mlx/backend/cuda/reduce/reduce_ops.cuh
index b40d2bd4e..bc4dce33e 100644
--- a/mlx/backend/cuda/reduce/reduce_ops.cuh
+++ b/mlx/backend/cuda/reduce/reduce_ops.cuh
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include "mlx/backend/cuda/device/atomic_ops.cuh"
+#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 #include "mlx/backend/cuda/reduce/reduce_utils.cuh"
 
@@ -40,15 +42,15 @@ struct Sum {
   }
 
   __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }
 
   __device__ void atomic_update(int* x, int y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }
 
   __device__ void atomic_update(float* x, float y) {
-    atomicAdd(x, y);
+    atomic_add(x, y);
   }
 };
 
@@ -152,7 +154,7 @@ struct ReduceInit<Sum, T> {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
       return T{0, 0};
     } else {
-      return typename ReduceResult<Sum, T>::type{0};
+      return cast_to<typename ReduceResult<Sum, T>::type>(0);
     }
   }
 };
@@ -163,7 +165,7 @@ struct ReduceInit<Prod, T> {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
       return T{1, 0};
     } else {
-      return typename ReduceResult<Prod, T>::type{1};
+      return cast_to<typename ReduceResult<Prod, T>::type>(1);
     }
   }
 };
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index d4670503a..ccd7ae48d 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -55,22 +55,6 @@ __device__ void atomic_reduce(T* x, T y) {
   }
 }
 
-// TODO: Should make a custom complex type
-template <typename U, typename T>
-inline __device__ U __cast(T x) {
-  return static_cast<U>(x);
-}
-
-template <>
-inline __device__ bool __cast<bool, cuComplex>(cuComplex x) {
-  return x.x != 0 && x.y != 0;
-}
-
-template <>
-inline __device__ cuComplex __cast<cuComplex, bool>(bool x) {
-  return x ? make_cuFloatComplex(1, 1) : make_cuFloatComplex(0, 0);
-}
-
 template <typename T, int N, typename Block, typename Warp, typename Op>
 inline __device__ void
 block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index e57f18668..61838ddd3 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -3,7 +3,6 @@
 #include <numeric>
 
 #include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/device/cast_op.cuh"
 #include "mlx/backend/cuda/reduce/reduce.cuh"
 
 #include <cooperative_groups.h>
@@ -113,7 +112,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + r * (block.size() * N),
           vals[k]);
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
   }
@@ -125,7 +124,7 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + r * (block.size() * N),
           vals[k]);
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
   }
@@ -138,9 +137,9 @@ __global__ void row_reduce_simple(T* in, U* out, size_t n_rows, int size) {
           in + k * size + final_offset,
           vals[k],
           size,
-          __cast<U>(init));
+          cast_to<U>(init));
       for (int j = 0; j < N; j++) {
-        accs[k] = op(accs[k], __cast<U>(vals[k][j]));
+        accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
       }
     }
   }
@@ -199,7 +198,7 @@ __global__ void row_reduce_looped(
         in + loop.location() + r * BLOCK_DIM * N_READS,
         vals);
     for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U>(vals[i]));
+      total[0] = op(total[0], cast_to<U>(vals[i]));
     }
   }
   if (final_offset < args.row_size) {
@@ -209,9 +208,9 @@ __global__ void row_reduce_looped(
         in + loop.location() + final_offset,
         vals,
         args.row_size - final_offset,
-        __cast<U>(init));
+        cast_to<U>(init));
     for (int i = 0; i < N_READS; i++) {
-      total[0] = op(total[0], __cast<U>(vals[i]));
+      total[0] = op(total[0], cast_to<U>(vals[i]));
     }
   }
   // TODO: Maybe block.sync() here?
diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu
index 5ee1d3386..964bd7d98 100644
--- a/mlx/backend/cuda/rms_norm.cu
+++ b/mlx/backend/cuda/rms_norm.cu
@@ -74,7 +74,7 @@ __global__ void rms_norm(
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
     auto index = r * BLOCK_DIM + block.thread_rank();
     T xn[N_READS];
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     for (int i = 0; i < N_READS; ++i) {
       float t = static_cast<float>(xn[i]);
       normalizer += t * t;
@@ -130,7 +130,7 @@ __global__ void rms_norm_vjp(
     T wn[N_READS] = {};
     T gn[N_READS] = {};
     auto index = r * BLOCK_DIM + block.thread_rank();
-    cub::LoadDirectBlocked(index, x, xn, axis_size, 0);
+    cub::LoadDirectBlocked(index, x, xn, axis_size, cast_to<T>(0));
     cub::LoadDirectBlocked(index, g, gn, axis_size);
     cub::LoadDirectBlocked(index, strided_iterator(w, w_stride), wn, axis_size);
     for (int i = 0; i < N_READS; i++) {
diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu
index fd807bd8d..56f67d7f3 100644
--- a/mlx/backend/cuda/softmax.cu
+++ b/mlx/backend/cuda/softmax.cu
@@ -43,7 +43,7 @@ __global__ void softmax(const T* in, T* out, int axis_size) {
   // Thread reduce.
   AccT prevmax;
   AccT maxval = Limits<AccT>::finite_min();
-  AccT normalizer = 0;
+  AccT normalizer = cast_to<AccT>(0);
   for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); r++) {
     AccT vals[N_READS];
     cub::LoadDirectBlocked(