diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 8130d396f..87f4cb4ae 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -35,6 +35,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
 
 target_compile_options(mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Enable calling host constexpr functions from device. This is needed because
+# the constexpr version of isnan is host only.
+target_compile_options(
+  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
+
 # CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
 # Explicitly pass this flag to suppress the warning, it is safe to set it to
 # true but the warning wouldn't be suppressed.
diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index dc4f8e7bb..644786a92 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -1,10 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
-#include "mlx/backend/cuda/device/fp16_math.cuh"
-#include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/device/unary_ops.cuh"
 
-#include
 #include
 
 namespace mlx::core::cu {
@@ -114,36 +111,38 @@ struct LessEqual {
 struct LogAddExp {
   template <typename T>
   __device__ T operator()(T x, T y) {
-    if (isnan(x) || isnan(y)) {
-      return cuda::std::numeric_limits<T>::quiet_NaN();
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
+          isnan(cuCimagf(y))) {
+        return {
+            cuda::std::numeric_limits<float>::quiet_NaN(),
+            cuda::std::numeric_limits<float>::quiet_NaN()};
+      }
+      auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
+      auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
+      auto min_real = cuCrealf(min);
+      auto max_real = cuCrealf(max);
+      if (!isfinite(min_real) && (min_real == max_real)) {
+        if (min_real < 0) {
+          return min;
+        } else {
+          return Log{}(Exp{}(min) + Exp{}(max));
+        }
+      } else {
+        return Log1p{}(Exp{}(min - max)) + max;
+      }
+    } else {
+      if (isnan(x) || isnan(y)) {
+        return cuda::std::numeric_limits<T>::quiet_NaN();
+      }
+      T maxval = max(x, y);
+      T minval = min(x, y);
+      return (minval == -cuda::std::numeric_limits<T>::infinity() ||
+              maxval == cuda::std::numeric_limits<T>::infinity())
+          ? maxval
+          : T(float(maxval) + log1p(expf(minval - maxval)));
     }
-    T maxval = max(x, y);
-    T minval = min(x, y);
-    return (minval == -cuda::std::numeric_limits<T>::infinity() ||
-            maxval == cuda::std::numeric_limits<T>::infinity())
-        ? maxval
-        : T(float(maxval) + log1p(expf(minval - maxval)));
   };
-
-  __device__ cuComplex operator()(cuComplex x, cuComplex y) {
-    if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
-        isnan(cuCimagf(y))) {
-      return {
-          cuda::std::numeric_limits<float>::quiet_NaN(),
-          cuda::std::numeric_limits<float>::quiet_NaN()};
-    }
-    float inf = cuda::std::numeric_limits<float>::infinity();
-    auto maxval = x > y ? x : y;
-    auto minval = x < y ? x : y;
-    if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
-      return maxval;
-    float m = exp(cuCrealf(minval) - cuCrealf(maxval));
-    cuComplex dexp{
-        m * cos(cuCimagf(minval) - cuCimagf(maxval)),
-        m * sin(cuCimagf(minval) - cuCimagf(maxval)),
-    };
-    return maxval + log1p(dexp);
-  }
 };
 
 struct Maximum {
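Editor's note: the complex LogAddExp above computes log(exp(x) + exp(y)) as max + log1p(exp(min - max)), ordering by real part so the exponential never overflows. A quick host-side illustration of that identity (plain std::complex, not part of the patch; std::log(1 + w) stands in for the complex log1p used in the kernel):

// Stable vs. naive complex logaddexp, illustration only.
#include <complex>
#include <cstdio>

int main() {
  std::complex<float> x{1.0f, 1.0f}, y{-2.0f, 0.5f};
  auto maxv = x.real() > y.real() ? x : y;
  auto minv = x.real() > y.real() ? y : x;
  // Stable form: exp(min - max) has a non-positive real exponent, so it
  // cannot overflow even when exp(x) or exp(y) would.
  auto stable = maxv + std::log(1.0f + std::exp(minv - maxv));
  // Naive form for comparison.
  auto naive = std::log(std::exp(x) + std::exp(y));
  std::printf("stable=(%f,%f) naive=(%f,%f)\n",
              stable.real(), stable.imag(), naive.real(), naive.imag());
}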
diff --git a/mlx/backend/cuda/device/cexpf.cuh b/mlx/backend/cuda/device/cexpf.cuh
new file mode 100644
index 000000000..61c94c00f
--- /dev/null
+++ b/mlx/backend/cuda/device/cexpf.cuh
@@ -0,0 +1,138 @@
+// Copyright © 2025 Apple Inc.
+// Copyright © 2008-2013 NVIDIA Corporation
+// Copyright © 2013 Filipe RNC Maia
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Forked from
+// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
+
+// TODO: We should use thrust::exp but the thrust header in old CUDA versions
+// can not be used in JIT.
+
+#pragma once
+
+#include <cuComplex.h>
+#include <cuda/std/cstdint>
+
+namespace mlx::core::cu::detail {
+
+using ieee_float_shape_type = union {
+  float value;
+  uint32_t word;
+};
+
+inline __device__ void get_float_word(uint32_t& i, float d) {
+  ieee_float_shape_type gf_u;
+  gf_u.value = (d);
+  (i) = gf_u.word;
+}
+
+inline __device__ void get_float_word(int32_t& i, float d) {
+  ieee_float_shape_type gf_u;
+  gf_u.value = (d);
+  (i) = gf_u.word;
+}
+
+inline __device__ void set_float_word(float& d, uint32_t i) {
+  ieee_float_shape_type sf_u;
+  sf_u.word = (i);
+  (d) = sf_u.value;
+}
+
+inline __device__ float frexp_expf(float x, int* expt) {
+  const uint32_t k = 235;
+  const float kln2 = 162.88958740F;
+
+  float exp_x;
+  uint32_t hx;
+
+  exp_x = expf(x - kln2);
+  get_float_word(hx, exp_x);
+  *expt = (hx >> 23) - (0x7f + 127) + k;
+  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
+  return exp_x;
+}
+
+inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
+  float x, y, exp_x, scale1, scale2;
+  int ex_expt, half_expt;
+
+  x = cuCrealf(z);
+  y = cuCimagf(z);
+  exp_x = frexp_expf(x, &ex_expt);
+  expt += ex_expt;
+
+  half_expt = expt / 2;
+  set_float_word(scale1, (0x7f + half_expt) << 23);
+  half_expt = expt - half_expt;
+  set_float_word(scale2, (0x7f + half_expt) << 23);
+
+  return cuComplex{
+      cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
+}
+
+inline __device__ cuComplex cexpf(const cuComplex& z) {
+  float x, y, exp_x;
+  uint32_t hx, hy;
+
+  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
+
+  x = cuCrealf(z);
+  y = cuCimagf(z);
+
+  get_float_word(hy, y);
+  hy &= 0x7fffffff;
+
+  /* cexp(x + I 0) = exp(x) + I 0 */
+  if (hy == 0) {
+    return cuComplex{expf(x), y};
+  }
+  get_float_word(hx, x);
+  /* cexp(0 + I y) = cos(y) + I sin(y) */
+  if ((hx & 0x7fffffff) == 0) {
+    return cuComplex{cosf(y), sinf(y)};
+  }
+  if (hy >= 0x7f800000) {
+    if ((hx & 0x7fffffff) != 0x7f800000) {
+      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
+      return cuComplex{y - y, y - y};
+    } else if (hx & 0x80000000) {
+      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
+      return cuComplex{0.0, 0.0};
+    } else {
+      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
+      return cuComplex{x, y - y};
+    }
+  }
+
+  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
+    /*
+     * x is between 88.7 and 192, so we must scale to avoid
+     * overflow in expf(x).
+     */
+    return ldexp_cexpf(z, 0);
+  } else {
+    /*
+     * Cases covered here:
+     *  - x < exp_ovfl and exp(x) won't overflow (common case)
+     *  - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
+     *  - x = +-Inf (generated by exp())
+     *  - x = NaN (spurious inexact exception from y)
+     */
+    exp_x = expf(x);
+    return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
+  }
+}
+
+} // namespace mlx::core::cu::detail
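Editor's note: the scaling branch above exists because expf(x) overflows float for x in roughly [88.7, 192] even when the final product exp(x) * cos(y) is representable. cexpf computes expf(x - k*ln2) and re-applies the factor 2^k afterwards; the kernel does this with bit manipulation and two power-of-two scale factors, while the host-side sketch below (illustration only, values are assumptions) shows the same idea with ldexp:

// Why the [88.7, 192] range needs rescaling, illustration only.
#include <cmath>
#include <cstdio>

int main() {
  const int k = 235;
  const float kln2 = 162.88958740f; // k * ln(2), as in the kernel
  float x = 89.0f;                  // expf(89.f) already overflows float
  float y = 1.56f;                  // cos(y) is small, so the product is finite
  float naive = std::exp(x) * std::cos(y);          // inf * small = inf
  float m = std::exp(x - kln2);                     // safely representable
  float scaled = std::ldexp(m * std::cos(y), k);    // re-apply 2^k at the end
  std::printf("naive=%g scaled=%g\n", naive, scaled);
}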
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index 18d769c2a..8716d3a8c 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include "mlx/backend/cuda/device/cexpf.cuh"
+#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
@@ -150,8 +152,7 @@ struct Exp {
   template <typename T>
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_same_v<T, cuComplex>) {
-      auto m = exp(cuCrealf(x));
-      return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
+      return detail::cexpf(x);
     } else {
       return exp(x);
     }
@@ -228,8 +229,25 @@ struct Log10 {
 
 struct Log1p {
   template <typename T>
-  __device__ T operator()(T x) {
-    return log1p(x);
+  __device__ T operator()(T z) {
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      float x = cuCrealf(z);
+      float y = cuCimagf(z);
+      float zabs = cuCrealf(Abs{}(z));
+      float theta = atan2f(y, x + 1);
+      if (zabs < 0.5f) {
+        float r = x * (2 + x) + y * y;
+        if (r == 0) { // handle underflow
+          return {x, theta};
+        }
+        return {0.5f * log1pf(r), theta};
+      } else {
+        float z0 = hypotf(x + 1, y);
+        return {logf(z0), theta};
+      }
+    } else {
+      return log1p(z);
+    }
   }
 };
 
@@ -387,19 +405,19 @@ struct Tanh {
   }
 };
 
-__device__ cuComplex ArcCos::operator()(cuComplex x) {
+inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
   auto i = cuComplex{0.0, 1.0};
   auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
   return {cuCimagf(y), -cuCrealf(y)};
 };
 
-__device__ cuComplex ArcSin::operator()(cuComplex x) {
+inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
   auto i = cuComplex{0.0f, 1.0f};
   auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
   return {cuCimagf(y), -cuCrealf(y)};
 };
 
-__device__ cuComplex ArcTan::operator()(cuComplex x) {
+inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
   auto i = cuComplex{0.0f, 1.0f};
   auto ix = i * x;
   return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 83e149165..af022c141 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
   }
 };
 
-inline __device__ cuComplex log1p(cuComplex in) {
-  float x = cuCrealf(in);
-  float y = cuCimagf(in);
-  float zabs = sqrt(x * x + y * y);
-  float theta = atan2f(y, x + 1);
-  if (zabs < 0.5f) {
-    float r = x * (2 + x) + y * y;
-    if (r == 0) { // handle underflow
-      return {x, theta};
-    }
-    return {0.5f * log1pf(r), theta};
-  } else {
-    auto z0 = sqrt((x + 1) * (x + 1) + y * y);
-    return {log(z0), theta};
-  }
-}
-
 } // namespace mlx::core::cu
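Editor's note: the complex log1p that moved from utils.cuh into the Log1p functor uses 0.5f * log1pf(x*(2 + x) + y*y) for the real part when |z| < 0.5 because the direct form log(hypot(1 + x, y)) loses all significant digits for tiny z. A host-side illustration (not part of the patch):

// Cancellation in complex log1p for small |z|, illustration only.
#include <cmath>
#include <cstdio>

int main() {
  float x = 1e-7f, y = 1e-7f;
  float naive = std::log(std::hypot(1.0f + x, y));          // rounds 1 + x to ~1
  float stable = 0.5f * std::log1p(x * (2.0f + x) + y * y);  // keeps the small term
  float theta = std::atan2(y, x + 1.0f);                     // imaginary part
  std::printf("naive=%e stable=%e imag=%e\n", naive, stable, theta);
}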
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index e6dbd35da..834e4a3d1 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
     INCLUDE_PREFIX "binary_ops.cuh",
     INCLUDE_PREFIX "cast_op.cuh",
+    INCLUDE_PREFIX "cexpf.cuh",
     INCLUDE_PREFIX "config.h",
     INCLUDE_PREFIX "cucomplex_math.cuh",
     INCLUDE_PREFIX "fp16_math.cuh",
@@ -177,6 +178,7 @@ constexpr const char* g_headers[] = {
     jit_source_atomic_ops,
     jit_source_binary_ops,
     jit_source_cast_op,
+    jit_source_cexpf,
     jit_source_config,
     jit_source_cucomplex_math,
     jit_source_fp16_math,
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index a8496b958..3a3f8ff54 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -82,7 +82,6 @@ NO_GPU(Load)
 NO_GPU_MULTI(LUF)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
-NO_GPU(Scan)
 NO_GPU(SegmentedMM)
 NO_GPU_MULTI(SVD)
 NO_GPU(Inverse)
diff --git a/mlx/backend/cuda/reduce/reduce_utils.cuh b/mlx/backend/cuda/reduce/reduce_utils.cuh
index ccd7ae48d..d993bacbb 100644
--- a/mlx/backend/cuda/reduce/reduce_utils.cuh
+++ b/mlx/backend/cuda/reduce/reduce_utils.cuh
@@ -4,6 +4,7 @@
 
 #include
 
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
 #include
diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu
new file mode 100644
index 000000000..7a26ee161
--- /dev/null
+++ b/mlx/backend/cuda/scan.cu
@@ -0,0 +1,467 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/device/binary_ops.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <cooperative_groups.h>
+#include <cooperative_groups/scan.h>
+#include <nvtx3/nvtx3.hpp>
+
+#include <cassert>
+
+namespace mlx::core {
+
+namespace cu {
+
+namespace cg = cooperative_groups;
+
+template <typename Op, typename T>
+struct ScanResult {
+  using type = T;
+};
+
+template <>
+struct ScanResult<Sum, bool> {
+  using type = int32_t;
+};
+
+template <typename T>
+struct ReduceInit<LogAddExp, T> {
+  static constexpr __host__ __device__ T value() {
+    return Limits<T>::min();
+  }
+};
+
+template <typename T, typename U, int N_READS, bool reverse>
+inline __device__ void
+load_values(int index, const T* in, U (&values)[N_READS], int size, U init) {
+  int remaining = size - index * N_READS;
+  if constexpr (reverse) {
+    in += remaining - N_READS;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        values[N_READS - i - 1] =
+            (N_READS - i - 1 < remaining) ? cast_to<U>(in[i]) : init;
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        values[N_READS - i - 1] = cast_to<U>(in[i]);
+      }
+    }
+  } else {
+    in += index * N_READS;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        values[i] = (i < remaining) ? cast_to<U>(in[i]) : init;
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        values[i] = cast_to<U>(in[i]);
+      }
+    }
+  }
+}
+
+template <typename T, int N_READS, int offset, bool reverse>
+inline __device__ void
+store_values(int index, T* out, T (&values)[N_READS], int size) {
+  int start = index * N_READS + offset;
+  int remaining = size - start;
+  if constexpr (reverse) {
+    out += remaining - N_READS;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        if (N_READS - i - 1 < remaining) {
+          out[i] = values[N_READS - i - 1];
+        }
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        out[i] = values[N_READS - i - 1];
+      }
+    }
+  } else {
+    out += start;
+    if (remaining < N_READS) {
+      for (int i = 0; i < N_READS; ++i) {
+        if (i < remaining) {
+          out[i] = values[i];
+        }
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        out[i] = values[i];
+      }
+    }
+  }
+}
+
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    bool inclusive,
+    bool reverse>
+__global__ void contiguous_scan(const T* in, U* out, int32_t axis_size) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  in += grid.block_rank() * axis_size;
+  out += grid.block_rank() * axis_size;
+
+  __shared__ U warp_sums[WARP_SIZE];
+
+  Op op;
+  U init = ReduceInit<Op, U>::value();
+  U prefix = init;
+
+  // Scan per block.
+  for (int r = 0; r < cuda::ceil_div(axis_size, block.size() * N_READS); ++r) {
+    int32_t index = r * block.size() + block.thread_rank();
+    U values[N_READS];
+    load_values<T, U, N_READS, reverse>(index, in, values, axis_size, init);
+
+    // Compute an inclusive scan per thread.
+    for (int i = 1; i < N_READS; ++i) {
+      values[i] = op(values[i], values[i - 1]);
+    }
+
+    // Compute exclusive scan of thread sums.
+    U prev_thread_sum = cg::exclusive_scan(warp, values[N_READS - 1], op);
+    if (warp.thread_rank() == 0) {
+      prev_thread_sum = init;
+    }
+
+    // Write warp's sum to shared memory.
+    if (warp.thread_rank() == WARP_SIZE - 1) {
+      warp_sums[warp.meta_group_rank()] =
+          op(prev_thread_sum, values[N_READS - 1]);
+    }
+    block.sync();
+
+    // Compute exclusive scan of warp sums.
+    if (warp.meta_group_rank() == 0) {
+      U prev_warp_sum =
+          cg::exclusive_scan(warp, warp_sums[warp.thread_rank()], op);
+      if (warp.thread_rank() == 0) {
+        prev_warp_sum = init;
+      }
+      warp_sums[warp.thread_rank()] = prev_warp_sum;
+    }
+    block.sync();
+
+    // Compute the output.
+    for (int i = 0; i < N_READS; ++i) {
+      values[i] = op(values[i], prefix);
+      values[i] = op(values[i], warp_sums[warp.meta_group_rank()]);
+      values[i] = op(values[i], prev_thread_sum);
+    }
+
+    // Write the values.
+    if (inclusive) {
+      store_values<U, N_READS, 0, reverse>(index, out, values, axis_size);
+    } else {
+      store_values<U, N_READS, 1, reverse>(index, out, values, axis_size);
+      if (reverse) {
+        if (block.thread_rank() == 0 && index == 0) {
+          out[axis_size - 1] = init;
+        }
+      } else {
+        if (block.thread_rank() == 0 && index == 0) {
+          out[0] = init;
+        }
+      }
+    }
+    block.sync();
+
+    // Share the prefix.
+    if ((warp.meta_group_rank() == warp.meta_group_size() - 1) &&
+        (warp.thread_rank() == WARP_SIZE - 1)) {
+      warp_sums[0] = values[N_READS - 1];
+    }
+    block.sync();
+    prefix = warp_sums[0];
+  }
+}
+
+template <
+    typename T,
+    typename U,
+    typename Op,
+    int N_READS,
+    int BM,
+    int BN,
+    bool inclusive,
+    bool reverse>
+__global__ void strided_scan(
+    const T* in,
+    U* out,
+    int32_t axis_size,
+    int64_t stride,
+    int64_t stride_blocks) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  constexpr int BN_pad = WARP_SIZE + 16 / sizeof(U);
+  constexpr int n_warps = BN / N_READS;
+  constexpr int n_scans = BN / n_warps;
+
+  __shared__ U read_buffer[BM * BN_pad];
+
+  Op op;
+  U init = ReduceInit<Op, U>::value();
+  U values[n_scans];
+  U prefix[n_scans];
+  for (int i = 0; i < n_scans; ++i) {
+    prefix[i] = init;
+  }
+
+  // Compute offsets.
+  int64_t offset = (grid.block_rank() / stride_blocks) * axis_size * stride;
+  int64_t global_index_x = (grid.block_rank() % stride_blocks) * BN;
+  uint read_offset_y = (block.thread_rank() * N_READS) / BN;
+  uint read_offset_x = (block.thread_rank() * N_READS) % BN;
+  uint scan_offset_y = warp.thread_rank();
+  uint scan_offset_x = warp.meta_group_rank() * n_scans;
+
+  uint stride_limit = stride - global_index_x;
+  in += offset + global_index_x + read_offset_x;
+  out += offset + global_index_x + read_offset_x;
+  U* read_into = read_buffer + read_offset_y * BN_pad + read_offset_x;
+  U* read_from = read_buffer + scan_offset_y * BN_pad + scan_offset_x;
+
+  for (uint j = 0; j < axis_size; j += BM) {
+    // Calculate the indices for the current thread.
+    uint index_y = j + read_offset_y;
+    uint check_index_y = index_y;
+    if (reverse) {
+      index_y = axis_size - 1 - index_y;
+    }
+
+    // Read in SM.
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
+      for (int i = 0; i < N_READS; ++i) {
+        read_into[i] = in[index_y * stride + i];
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          read_into[i] = in[index_y * stride + i];
+        } else {
+          read_into[i] = init;
+        }
+      }
+    }
+    block.sync();
+
+    // Read strided into registers.
+    for (int i = 0; i < n_scans; ++i) {
+      values[i] = read_from[i];
+    }
+
+    // Perform the scan.
+    for (int i = 0; i < n_scans; ++i) {
+      values[i] = cg::inclusive_scan(warp, values[i], op);
+      values[i] = op(values[i], prefix[i]);
+      prefix[i] = warp.shfl(values[i], WARP_SIZE - 1);
+    }
+
+    // Write to SM.
+    for (int i = 0; i < n_scans; ++i) {
+      read_from[i] = values[i];
+    }
+    block.sync();
+
+    // Write to device memory.
+    if (!inclusive) {
+      if (check_index_y == 0) {
+        if ((read_offset_x + N_READS) < stride_limit) {
+          for (int i = 0; i < N_READS; ++i) {
+            out[index_y * stride + i] = init;
+          }
+        } else {
+          for (int i = 0; i < N_READS; ++i) {
+            if ((read_offset_x + i) < stride_limit) {
+              out[index_y * stride + i] = init;
+            }
+          }
+        }
+      }
+      if (reverse) {
+        index_y -= 1;
+        check_index_y += 1;
+      } else {
+        index_y += 1;
+        check_index_y += 1;
+      }
+    }
+    if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) {
+      for (int i = 0; i < N_READS; ++i) {
+        out[index_y * stride + i] = read_into[i];
+      }
+    } else {
+      for (int i = 0; i < N_READS; ++i) {
+        if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) {
+          out[index_y * stride + i] = read_into[i];
+        }
+      }
+    }
+  }
+}
+
+} // namespace cu
+
+template <typename F>
+void dispatch_scan_ops(Scan::ReduceType scan_op, F&& f) {
+  if (scan_op == Scan::ReduceType::Max) {
+    f(type_identity<cu::Max>{});
+  } else if (scan_op == Scan::ReduceType::Min) {
+    f(type_identity<cu::Min>{});
+  } else if (scan_op == Scan::ReduceType::Sum) {
+    f(type_identity<cu::Sum>{});
+  } else if (scan_op == Scan::ReduceType::Prod) {
+    f(type_identity<cu::Prod>{});
+  } else if (scan_op == Scan::ReduceType::LogAddExp) {
+    f(type_identity<cu::LogAddExp>{});
+  } else {
+    throw std::invalid_argument("Unknown reduce type.");
+  }
+}
+
+template <typename Op>
+const char* op_to_string() {
+  if (cuda::std::is_same_v<Op, cu::Max>) {
+    return "Max";
+  } else if (cuda::std::is_same_v<Op, cu::Min>) {
+    return "Min";
+  } else if (cuda::std::is_same_v<Op, cu::Sum>) {
+    return "Sum";
+  } else if (cuda::std::is_same_v<Op, cu::Prod>) {
+    return "Prod";
+  } else if (cuda::std::is_same_v<Op, cu::LogAddExp>) {
+    return "LogAddExp";
+  } else {
+    throw std::invalid_argument("Unknown op.");
+  }
+}
+
+template <typename Op, typename T>
+constexpr bool supports_scan_op() {
+  if constexpr (cuda::std::is_same_v<Op, cu::LogAddExp>) {
+    return is_inexact_v<T>;
+  } else {
+    return true;
+  }
+}
+
+void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Scan::eval_gpu");
+  assert(inputs.size() == 1);
+  auto in = inputs[0];
+  auto& s = stream();
+
+  if (in.flags().contiguous && in.strides()[axis_] != 0) {
+    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
+      out.copy_shared_buffer(in);
+    } else {
+      out.set_data(
+          allocator::malloc(in.data_size() * out.itemsize()),
+          in.data_size(),
+          in.strides(),
+          in.flags());
+    }
+  } else {
+    array arr_copy(in.shape(), in.dtype(), nullptr, {});
+    copy_gpu(in, arr_copy, CopyType::General, s);
+    in = std::move(arr_copy);
+    out.copy_shared_buffer(in);
+  }
+
+  constexpr int N_READS = 4;
+  int32_t axis_size = in.shape(axis_);
+  bool contiguous = in.strides()[axis_] == 1;
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+
+  dispatch_all_types(in.dtype(), [&](auto type_tag) {
+    using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+    dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
+      using Op = MLX_GET_TYPE(scan_op_tag);
+      if constexpr (supports_scan_op<Op, T>()) {
+        using U = typename cu::ScanResult<Op, T>::type;
+        dispatch_bool(inclusive_, [&](auto inclusive) {
+          dispatch_bool(reverse_, [&](auto reverse) {
+            if (contiguous) {
+              auto kernel = cu::contiguous_scan<
+                  T,
+                  U,
+                  Op,
+                  N_READS,
+                  inclusive.value,
+                  reverse.value>;
+              int block_dim = cuda::ceil_div(axis_size, N_READS);
+              block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
+              block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
+              encoder.add_kernel_node(
+                  kernel,
+                  in.data_size() / axis_size,
+                  block_dim,
+                  in.data<T>(),
+                  out.data<U>(),
+                  axis_size);
+            } else {
+              constexpr int BM = WARP_SIZE;
+              constexpr int BN = WARP_SIZE;
+              auto kernel = cu::strided_scan<
+                  T,
+                  U,
+                  Op,
+                  N_READS,
+                  BM,
+                  BN,
+                  inclusive.value,
+                  reverse.value>;
+              int64_t stride = in.strides()[axis_];
+              int64_t stride_blocks = cuda::ceil_div(stride, BN);
+              dim3 num_blocks = get_2d_grid_dims(
+                  in.shape(), in.strides(), axis_size * stride);
+              if (num_blocks.x * stride_blocks <= UINT32_MAX) {
+                num_blocks.x *= stride_blocks;
+              } else {
+                num_blocks.y *= stride_blocks;
+              }
+              int block_dim = (BN / N_READS) * WARP_SIZE;
+              encoder.add_kernel_node(
+                  kernel,
+                  num_blocks,
+                  block_dim,
+                  in.data<T>(),
+                  out.data<U>(),
+                  axis_size,
+                  stride,
+                  stride_blocks);
+            }
+          });
+        });
+      } else {
+        throw std::runtime_error(fmt::format(
+            "Can not do scan op {} on inputs of {} with result of {}.",
+            op_to_string<Op>(),
+            dtype_to_string(in.dtype()),
+            dtype_to_string(out.dtype())));
+      }
+    });
+  });
+}
+
+} // namespace mlx::core
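Editor's note: with Scan::eval_gpu in place, the removed cuda_skip entries below exercise these kernels through the public ops. A hedged usage sketch (the signatures of arange, cumsum, and logcumsumexp are quoted from memory of the mlx C++ API and may need adjusting):

// Minimal sketch of driving the new CUDA scan through mlx::core ops.
#include "mlx/mlx.h"
#include <iostream>

using namespace mlx::core;

int main() {
  auto x = arange(10, float32);
  auto inclusive = cumsum(x, /* axis */ 0);                          // contiguous_scan path
  auto rev = cumsum(x, 0, /* reverse */ true, /* inclusive */ true); // reverse=true template instantiation
  auto lse = logcumsumexp(x, 0);                                     // LogAddExp op, inexact dtypes only
  std::cout << inclusive << "\n" << rev << "\n" << lse << "\n";
}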
diff --git a/mlx/backend/metal/kernels/cexpf.h b/mlx/backend/metal/kernels/cexpf.h
new file mode 100644
index 000000000..b45fe6a2f
--- /dev/null
+++ b/mlx/backend/metal/kernels/cexpf.h
@@ -0,0 +1,134 @@
+// Copyright © 2025 Apple Inc.
+// Copyright © 2008-2013 NVIDIA Corporation
+// Copyright © 2013 Filipe RNC Maia
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Forked from
+// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
+
+// TODO: We should use thrust::exp but the thrust header in old CUDA versions
+// can not be used in JIT.
+
+#pragma once
+
+#include
+
+using ieee_float_shape_type = union {
+  float value;
+  uint32_t word;
+};
+
+inline void get_float_word(thread uint32_t& i, float d) {
+  ieee_float_shape_type gf_u;
+  gf_u.value = (d);
+  (i) = gf_u.word;
+}
+
+inline void get_float_word(thread int32_t& i, float d) {
+  ieee_float_shape_type gf_u;
+  gf_u.value = (d);
+  (i) = gf_u.word;
+}
+
+inline void set_float_word(thread float& d, uint32_t i) {
+  ieee_float_shape_type sf_u;
+  sf_u.word = (i);
+  (d) = sf_u.value;
+}
+
+inline float frexp_expf(float x, thread int* expt) {
+  const uint32_t k = 235;
+  const float kln2 = 162.88958740F;
+
+  float exp_x;
+  uint32_t hx;
+
+  exp_x = metal::exp(x - kln2);
+  get_float_word(hx, exp_x);
+  *expt = (hx >> 23) - (0x7f + 127) + k;
+  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
+  return exp_x;
+}
+
+inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
+  float x, y, exp_x, scale1, scale2;
+  int ex_expt, half_expt;
+
+  x = z.real;
+  y = z.imag;
+  exp_x = frexp_expf(x, &ex_expt);
+  expt += ex_expt;
+
+  half_expt = expt / 2;
+  set_float_word(scale1, (0x7f + half_expt) << 23);
+  half_expt = expt - half_expt;
+  set_float_word(scale2, (0x7f + half_expt) << 23);
+
+  return complex64_t{
+      metal::cos(y) * exp_x * scale1 * scale2,
+      metal::sin(y) * exp_x * scale1 * scale2};
+}
+
+inline complex64_t cexpf(const thread complex64_t& z) {
+  float x, y, exp_x;
+  uint32_t hx, hy;
+
+  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
+
+  x = z.real;
+  y = z.imag;
+
+  get_float_word(hy, y);
+  hy &= 0x7fffffff;
+
+  /* cexp(x + I 0) = exp(x) + I 0 */
+  if (hy == 0) {
+    return complex64_t{metal::exp(x), y};
+  }
+  get_float_word(hx, x);
+  /* cexp(0 + I y) = cos(y) + I sin(y) */
+  if ((hx & 0x7fffffff) == 0) {
+    return complex64_t{metal::cos(y), metal::sin(y)};
+  }
+  if (hy >= 0x7f800000) {
+    if ((hx & 0x7fffffff) != 0x7f800000) {
+      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
+      return complex64_t{y - y, y - y};
+    } else if (hx & 0x80000000) {
+      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
+      return complex64_t{0.0, 0.0};
+    } else {
+      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
+      return complex64_t{x, y - y};
+    }
+  }
+
+  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
+    /*
+     * x is between 88.7 and 192, so we must scale to avoid
+     * overflow in expf(x).
+     */
+    return ldexp_cexpf(z, 0);
+  } else {
+    /*
+     * Cases covered here:
+     *  - x < exp_ovfl and exp(x) won't overflow (common case)
+     *  - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
+     *  - x = +-Inf (generated by exp())
+     *  - x = NaN (spurious inexact exception from y)
+     */
+    exp_x = metal::exp(x);
+    return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)};
+  }
+}
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index 09d9f6605..b34bc44ba 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -5,6 +5,7 @@
 #include
 #include
 
+#include "mlx/backend/metal/kernels/cexpf.h"
 #include "mlx/backend/metal/kernels/erf.h"
 #include "mlx/backend/metal/kernels/expm1f.h"
 
@@ -178,8 +179,7 @@ struct Exp {
     return metal::precise::exp(x);
   };
   complex64_t operator()(complex64_t x) {
-    auto m = metal::precise::exp(x.real);
-    return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
+    return cexpf(x);
   }
 };
 
diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py
index afd48bd03..005c612ff 100644
--- a/python/tests/cuda_skip.py
+++ b/python/tests/cuda_skip.py
@@ -13,11 +13,6 @@ cuda_skip = {
     "TestBlas.test_gather_mm_sorted",
     # Segmented matmul NYI
     "TestBlas.test_segmented_mm",
-    # Scan NYI
-    "TestArray.test_api",
-    "TestAutograd.test_cumprod_grad",
-    "TestOps.test_scans",
-    "TestOps.test_logcumsumexp",
     # Hadamard NYI
     "TestOps.test_hadamard",
    "TestOps.test_hadamard_grad_vmap",
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 1a9781c7c..969bc2ba7 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
   x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
   auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
   CHECK(allclose(exp(x), expected).item<bool>());
+
+  // Complex of -inf
+  constexpr float inf = std::numeric_limits<float>::infinity();
+  x = array(complex64_t{-inf, -inf});
+  CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
 }
 
 // Test expm1
@@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
   x = array(-inf);
   y = array(inf);
   CHECK_EQ(logaddexp(x, y).item<float>(), inf);
+
+  x = array(complex64_t{1, 1});
+  y = array(complex64_t{-inf, -inf});
+  CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
 }
 
 TEST_CASE("test broadcast") {