From a0ae49d397252a05d8a6881bf950d9c80614d2d9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 30 Jul 2025 13:05:51 +0900 Subject: [PATCH] Move arange to its own file (#2438) --- mlx/backend/cuda/CMakeLists.txt | 3 +- mlx/backend/cuda/arange.cu | 55 +++++++++++++++++++ mlx/backend/cuda/device/arange.cuh | 15 ----- .../cuda/{primitives.cu => primitives.cpp} | 36 ------------ 4 files changed, 57 insertions(+), 52 deletions(-) create mode 100644 mlx/backend/cuda/arange.cu delete mode 100644 mlx/backend/cuda/device/arange.cuh rename mlx/backend/cuda/{primitives.cu => primitives.cpp} (56%) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index fc37b61b2d..5bc75e2e0a 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu @@ -29,7 +30,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu - ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu diff --git a/mlx/backend/cuda/arange.cu b/mlx/backend/cuda/arange.cu new file mode 100644 index 0000000000..6190ac272d --- /dev/null +++ b/mlx/backend/cuda/arange.cu @@ -0,0 +1,55 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include +#include + +namespace mlx::core { + +namespace cu { + +template +struct Arange { + const T start; + const T step; + + __device__ T operator()(uint32_t i) const { + return start + i * step; + } +}; + +} // namespace cu + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Arange::eval_gpu"); + if (out.size() == 0) { + return; + } + out.set_data(allocator::malloc(out.nbytes())); + + auto& encoder = cu::get_command_encoder(stream()); + encoder.set_output_array(out); + + auto capture = encoder.capture_context(); + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using OutType = cuda_type_t; + CTYPE step = + static_cast(start_ + step_) - static_cast(start_); + thrust::transform( + cu::thrust_policy(encoder.stream()), + thrust::counting_iterator(0), + thrust::counting_iterator(out.data_size()), + thrust::device_pointer_cast(out.data()), + cu::Arange{ + static_cast(start_), static_cast(step)}); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device/arange.cuh b/mlx/backend/cuda/device/arange.cuh deleted file mode 100644 index 53c261e345..0000000000 --- a/mlx/backend/cuda/device/arange.cuh +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright © 2025 Apple Inc. - -namespace mlx::core::cu { - -template -struct Arange { - const T start; - const T step; - - __device__ T operator()(uint32_t i) const { - return start + i * step; - } -}; - -} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cpp similarity index 56% rename from mlx/backend/cuda/primitives.cu rename to mlx/backend/cuda/primitives.cpp index 0451c9e546..553b2e54a5 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cpp @@ -1,47 +1,11 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/arange.cuh" -#include "mlx/backend/cuda/device/fp16_math.cuh" -#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/distributed/primitives.h" -#include "mlx/dtype_utils.h" #include "mlx/fast_primitives.h" #include "mlx/primitives.h" -#include -#include -#include - -#include - namespace mlx::core { -void Arange::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("Arange::eval_gpu"); - assert(inputs.size() == 0); - out.set_data(allocator::malloc(out.nbytes())); - if (out.size() == 0) { - return; - } - auto& encoder = cu::get_command_encoder(stream()); - encoder.set_output_array(out); - auto capture = encoder.capture_context(); - dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - using OutType = cuda_type_t; - CTYPE step = - static_cast(start_ + step_) - static_cast(start_); - thrust::transform( - cu::thrust_policy(encoder.stream()), - thrust::counting_iterator(0), - thrust::counting_iterator(out.data_size()), - thrust::device_pointer_cast(out.data()), - cu::Arange{ - static_cast(start_), static_cast(step)}); - }); -} - bool fast::ScaledDotProductAttention::use_fallback( const array& q, const array& k,