diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 93f24dcd4..07a6c4f63 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -13,12 +13,16 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gather.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/gather_axis.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scatter.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/scatter_axis.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
@@ -31,11 +35,22 @@ target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Ignore uncommon types in ops to speed up compilation.
+if(MLX_FAST_COMPILE)
+  target_compile_definitions(mlx PUBLIC MLX_FAST_COMPILE)
+endif()
+
 # Compute capability 7 is required for synchronization between CPU/GPU with
 # managed memory. TODO: Add more architectures for potential performance gain.
-set(MLX_CUDA_ARCHITECTURES
-    "70;80"
-    CACHE STRING "CUDA architectures")
+if(MLX_FAST_COMPILE)
+  set(MLX_CUDA_ARCHITECTURES
+      "70"
+      CACHE STRING "CUDA architectures")
+else()
+  set(MLX_CUDA_ARCHITECTURES
+      "70;80"
+      CACHE STRING "CUDA architectures")
+endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                      "${MLX_CUDA_ARCHITECTURES}")
diff --git a/mlx/backend/cuda/gather.cu b/mlx/backend/cuda/gather.cu
new file mode 100644
index 000000000..05502ed92
--- /dev/null
+++ b/mlx/backend/cuda/gather.cu
@@ -0,0 +1,60 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/indexing.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/gather.h>
+
+namespace mlx::core {
+
+void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Gather::eval_gpu");
+  out.set_data(allocator::malloc(out.nbytes()));
+  if (out.size() == 0) {
+    return;
+  }
+
+  const auto& src = inputs[0];
+  size_t nidx = inputs.size() - 1;
+  auto idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
+  auto idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  for (const auto& in : inputs) {
+    encoder.set_input_array(in);
+  }
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_INDEX_TYPES_CHECKED(idx_dtype, "gather", CTYPE_IDX, {
+      using IndexType = cuda_type_t<CTYPE_IDX>;
+      MLX_SWITCH_NIDX(inputs.size() - 1, NIDX, {
+        MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
+          MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
+            using DataType = cuda_type_t<CTYPE_DATA>;
+            auto idx_begin = thrust::make_transform_iterator(
+                thrust::make_counting_iterator(0),
+                cu::IndicesOp<IndexType, NIDX, IDX_NDIM>(
+                    src,
+                    slice_sizes_,
+                    axes_,
+                    inputs.begin() + 1,
+                    inputs.end()));
+            thrust::gather(
+                cu::thrust_policy(stream),
+                idx_begin,
+                idx_begin + out.size(),
+                src.data<DataType>(),
+                out.data<DataType>());
+          });
+        });
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gather_axis.cu b/mlx/backend/cuda/gather_axis.cu
new file mode 100644
index 000000000..098b0a0f7
--- /dev/null
+++ b/mlx/backend/cuda/gather_axis.cu
@@ -0,0 +1,53 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/indexing.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/gather.h>
+
+namespace mlx::core {
+
+void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("GatherAxis::eval_gpu");
+  out.set_data(allocator::malloc(out.nbytes()));
+  if (out.size() == 0) {
+    return;
+  }
+
+  auto& src = inputs[0];
+  auto& idx = inputs[1];
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(src);
+  encoder.set_input_array(idx);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_INDEX_TYPES_CHECKED(idx.dtype(), "gather_axis", CTYPE_IDX, {
+      using IndexType = cuda_type_t<CTYPE_IDX>;
+      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
+        using DataType = cuda_type_t<CTYPE_DATA>;
+        MLX_SWITCH_BOOL(src.flags().row_contiguous, SRC_CONT, {
+          MLX_SWITCH_BOOL(idx.flags().row_contiguous, IDX_CONT, {
+            auto idx_begin = thrust::make_transform_iterator(
+                thrust::make_counting_iterator(0),
+                cu::IndexOp<IndexType, IDX_CONT, SRC_CONT>(idx, src, axis_));
+            thrust::gather(
+                cu::thrust_policy(stream),
+                idx_begin,
+                idx_begin + idx.size(),
+                src.data<DataType>(),
+                out.data<DataType>());
+          });
+        });
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/indexing.cuh b/mlx/backend/cuda/indexing.cuh
new file mode 100644
index 000000000..260e86b51
--- /dev/null
+++ b/mlx/backend/cuda/indexing.cuh
@@ -0,0 +1,260 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/fp16_math.cuh"
+
+#include <fmt/format.h>
+#include <thrust/iterator/zip_iterator.h>
+
+namespace mlx::core::cu {
+
+// Only allow int32/uint32 as index types when MLX_FAST_COMPILE is defined.
+#if defined(MLX_FAST_COMPILE)
+#define MLX_SWITCH_INDEX_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...)        \
+  if (TYPE == ::mlx::core::int32) {                                         \
+    using CTYPE_ALIAS = int32_t;                                            \
+    __VA_ARGS__;                                                            \
+  } else if (TYPE == ::mlx::core::uint32) {                                 \
+    using CTYPE_ALIAS = uint32_t;                                           \
+    __VA_ARGS__;                                                            \
+  } else {                                                                  \
+    throw std::invalid_argument(fmt::format(                                \
+        "Cannot use dtype {} as index for {} when MLX_FAST_COMPILE is on.", \
+        dtype_to_string(TYPE),                                              \
+        NAME));                                                             \
+  }
+#else
+#define MLX_SWITCH_INDEX_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
+  MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, __VA_ARGS__)
+#endif
+
+// Dispatch dynamic nidx to constexpr.
+#if defined(MLX_FAST_COMPILE)
+#define MLX_SWITCH_NIDX(nidx, NIDX, ...) \
+  {                                      \
+    assert(nidx <= 16);                  \
+    constexpr uint32_t NIDX = 16;        \
+    __VA_ARGS__;                         \
+  }
+#else
+#define MLX_SWITCH_NIDX(nidx, NIDX, ...)                                     \
+  if (nidx <= 2) {                                                           \
+    constexpr uint32_t NIDX = 2;                                             \
+    __VA_ARGS__;                                                             \
+  } else if (nidx <= 16) {                                                   \
+    constexpr uint32_t NIDX = 16;                                            \
+    __VA_ARGS__;                                                             \
+  } else {                                                                   \
+    throw std::runtime_error(                                                \
+        fmt::format("Cannot use more than 16 index arrays, got {}.", nidx)); \
+  }
+#endif
+
+// Dispatch dynamic idx_ndim to constexpr.
+#define MORE_THAN_ONE MAX_NDIM
+#define MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, ...) \
+  if (idx_ndim == 0) {                               \
+    constexpr uint32_t IDX_NDIM = 0;                 \
+    __VA_ARGS__;                                     \
+  } else if (idx_ndim == 1) {                        \
+    constexpr uint32_t IDX_NDIM = 1;                 \
+    __VA_ARGS__;                                     \
+  } else {                                           \
+    constexpr uint32_t IDX_NDIM = MORE_THAN_ONE;     \
+    __VA_ARGS__;                                     \
+  }
+
+// Like thrust::scatter but accepts a custom assignment op.
+template <typename Policy, typename It, typename Idx, typename Out, typename Op>
+void scatter_n(Policy& policy, It begin, size_t size, Idx idx, Out out, Op op) {
+  thrust::for_each_n(
+      policy,
+      thrust::make_zip_iterator(begin, idx),
+      size,
+      [out, op] __device__(auto item) {
+        op(&out[thrust::get<1>(item)], thrust::get<0>(item));
+      });
+}
+
+// Convert an absolute index to positions in a 3d grid, assuming the index is
+// calculated with:
+//   index = x * dim1 * dim2 + y * dim2 + z
+template <typename T>
+inline __device__ cuda::std::tuple<T, T, T>
+index_to_dims(T index, size_t dim1, size_t dim2) {
+  T x = index / (dim1 * dim2);
+  T y = (index % (dim1 * dim2)) / dim2;
+  T z = index % dim2;
+  return cuda::std::make_tuple(x, y, z);
+}
+
+// Get absolute index from possible negative index.
+template <typename IdxT>
+inline __device__ auto absolute_index(IdxT idx, int32_t size) {
+  if constexpr (cuda::std::is_unsigned_v<IdxT>) {
+    return idx;
+  } else {
+    return static_cast<int32_t>(idx < 0 ? idx + size : idx);
+  }
+}
+
+// An op that maps a linear element index to the corresponding location in
+// |src|, using the indices stored in |idx| along |axis|.
+template <typename IdxT, bool IdxC, bool SrcC, typename LocT = int64_t>
+struct IndexOp {
+  const IdxT* idx;
+  int32_t ndim;
+  Shape shape;
+  Strides src_strides;
+  Strides idx_strides;
+  int32_t src_axis_size;
+  int32_t idx_axis_size;
+  int64_t src_axis_stride;
+  int64_t idx_axis_stride;
+  size_t size_post;
+
+  IndexOp(const array& idx, const array& src, int32_t axis)
+      : idx(idx.data<IdxT>()),
+        ndim(static_cast<int32_t>(src.ndim()) - 1),
+        shape(const_param(remove_index(idx.shape(), axis))),
+        src_strides(const_param(remove_index(src.strides(), axis))),
+        idx_strides(const_param(remove_index(idx.strides(), axis))),
+        src_axis_size(src.shape(axis)),
+        idx_axis_size(idx.shape(axis)),
+        src_axis_stride(src.strides(axis)),
+        idx_axis_stride(idx.strides(axis)) {
+    size_post = 1;
+    for (int i = axis + 1; i < idx.ndim(); ++i) {
+      size_post *= idx.shape(i);
+    }
+  }
+
+  __device__ LocT operator()(size_t index) {
+    auto [x, y, z] = index_to_dims(index, idx_axis_size, size_post);
+
+    LocT elem_idx = x * size_post;
+
+    LocT idx_loc = y * idx_axis_stride;
+    if constexpr (IdxC) {
+      idx_loc += elem_idx * idx_axis_size + z;
+    } else {
+      idx_loc +=
+          elem_to_loc(elem_idx + z, shape.data(), idx_strides.data(), ndim);
+    }
+
+    auto idx_val = absolute_index(idx[idx_loc], src_axis_size);
+
+    LocT src_idx = idx_val * src_axis_stride;
+    if constexpr (SrcC) {
+      src_idx += elem_idx * src_axis_size + z;
+    } else {
+      src_idx +=
+          elem_to_loc(elem_idx + z, shape.data(), src_strides.data(), ndim);
+    }
+    return src_idx;
+  }
+};
+
+// Concatenated |idx| arrays.
+template <typename IdxT, int NIDX, int IDX_NDIM, typename LocT = int64_t>
+struct Indices {
+  size_t size;
+  size_t ndim;
+  cuda::std::array<const IdxT*, NIDX> buffers;
+  cuda::std::array<bool, NIDX> row_contiguous;
+  cuda::std::array<int32_t, NIDX * MAX_NDIM> shapes;
+  cuda::std::array<int64_t, NIDX * MAX_NDIM> strides;
+
+  template <typename Iter>
+  Indices(Iter begin, Iter end) {
+    size = end - begin;
+    ndim = size > 0 ? begin->ndim() : 0;
+    for (size_t i = 0; i < size; ++i) {
+      const array& arr = *(begin + i);
+      buffers[i] = arr.data<IdxT>();
+      row_contiguous[i] = arr.flags().row_contiguous;
+      std::copy_n(arr.shape().begin(), ndim, shapes.begin() + i * ndim);
+      std::copy_n(arr.strides().begin(), ndim, strides.begin() + i * ndim);
+    }
+  }
+
+  __device__ auto operator()(size_t i, size_t x, size_t y) {
+    LocT idx_loc;
+    if constexpr (IDX_NDIM == 0) {
+      idx_loc = 0;
+    } else {
+      idx_loc = x * strides[ndim * i];
+      if constexpr (IDX_NDIM == MORE_THAN_ONE) {
+        if (row_contiguous[i]) {
+          idx_loc += y;
+        } else {
+          size_t offset = ndim * i + 1;
+          idx_loc += elem_to_loc(
+              y, shapes.data() + offset, strides.data() + offset, ndim - 1);
+        }
+      }
+    }
+    return buffers[i][idx_loc];
+  }
+};
+
+// An op that maps a linear element index to the corresponding location in
+// |src|, computed from the index arrays gathered at |axes|.
+template <typename IdxT, int NIDX, int IDX_NDIM, typename LocT = int64_t>
+struct IndicesOp {
+  size_t ndim;
+  Shape shape;
+  Strides strides;
+  Shape slice_sizes;
+  Shape axes;
+  Indices<IdxT, NIDX, IDX_NDIM, LocT> indices;
+  size_t n_dim0;
+  size_t slice_size;
+
+  template <typename Iter>
+  IndicesOp(
+      const array& src,
+      const std::vector<int32_t>& slice_sizes,
+      const std::vector<int32_t>& axes,
+      Iter idx_begin,
+      Iter idx_end)
+      : ndim(src.ndim()),
+        shape(const_param(src.shape())),
+        strides(const_param(src.strides())),
+        slice_sizes(const_param(slice_sizes)),
+        axes(const_param(axes)),
+        indices(idx_begin, idx_end) {
+    n_dim0 = 1;
+    size_t dim0 = 1;
+    if (indices.ndim >= 1) {
+      dim0 = idx_begin->shape(0);
+    }
+    if (indices.ndim >= 2) {
+      n_dim0 = idx_begin->size() / dim0;
+    }
+
+    slice_size = 1;
+    for (size_t s : slice_sizes) {
+      slice_size *= s;
+    }
+  }
+
+  __device__ LocT operator()(size_t index) {
+    auto [x, y, z] = index_to_dims(index, n_dim0, slice_size);
+
+    LocT src_idx = 0;
+    for (size_t i = 0; i < indices.size; ++i) {
+      auto ax = axes[i];
+      auto idx_val = absolute_index(indices(i, x, y), shape[ax]);
+      src_idx += static_cast<LocT>(idx_val) * strides[ax];
+    }
+
+    LocT src_offset = elem_to_loc(z, slice_sizes.data(), strides.data(), ndim);
+    return src_idx + src_offset;
+  }
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 6eb863ee1..d10d505a5 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -15,6 +15,16 @@
 
 namespace mlx::core {
 
+// Like MLX_SWITCH_ALL_TYPES but for booleans.
+#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
+  if (BOOL) {                                  \
+    constexpr bool BOOL_ALIAS = true;          \
+    __VA_ARGS__;                               \
+  } else {                                     \
+    constexpr bool BOOL_ALIAS = false;         \
+    __VA_ARGS__;                               \
+  }
+
 // Helper macros for dispatch macros (see below).
 #define MLX_INTERNAL_IF_CASE(DIM, BLOCK_DIM, ...) \
   }                                               \
diff --git a/mlx/backend/cuda/kernels/atomic_ops.cuh b/mlx/backend/cuda/kernels/atomic_ops.cuh
new file mode 100644
index 000000000..f0815fd29
--- /dev/null
+++ b/mlx/backend/cuda/kernels/atomic_ops.cuh
@@ -0,0 +1,67 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
+#include "mlx/backend/cuda/kernels/fp16_math.cuh"
+
+#include <cuda/atomic>
+
+namespace mlx::core::cu {
+
+template <typename T>
+inline __device__ void atomic_add(T* out, T val) {
+  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
+  ref += val;
+}
+
+template <typename T>
+inline __device__ void atomic_prod(T* out, T val) {
+  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
+  T old = ref.load();
+  while (!ref.compare_exchange_strong(old, old * val)) {
+  }
+}
+
+template <typename T>
+inline __device__ void atomic_max(T* out, T val) {
+  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
+  ref.fetch_max(val);
+}
+
+template <typename T>
+inline __device__ void atomic_min(T* out, T val) {
+  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
+  ref.fetch_min(val);
+}
+
+// cuda::atomic_ref does not provide atomic add for the following types.
+template <typename T>
+inline __device__ void atomic_add_general(T* out, T val) {
+  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
+  T old = ref.load();
+  while (!ref.compare_exchange_strong(old, old + val)) {
+  }
+}
+
+inline __device__ void atomic_add(__half* out, __half val) {
+  atomicAdd(out, val);
+}
+
+inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
+#if __CUDA_ARCH__ < 900
+  atomic_add_general(out, val);
+#else
+  atomicAdd(out, val);
+#endif
+}
+
+inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
+#if __CUDA_ARCH__ < 800
+  atomic_add_general(out, val);
+#else
+  atomicAdd(out, val);
+#endif
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/kernels/scatter_ops.cuh b/mlx/backend/cuda/kernels/scatter_ops.cuh
new file mode 100644
index 000000000..3b13f3a0c
--- /dev/null
+++ b/mlx/backend/cuda/kernels/scatter_ops.cuh
@@ -0,0 +1,44 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/cuda/kernels/atomic_ops.cuh"
+
+namespace mlx::core::cu {
+
+template <typename T>
+struct ScatterAssign {
+  __device__ void operator()(T* out, T val) const {
+    *out = val;
+  }
+};
+
+template <typename T>
+struct ScatterSum {
+  __device__ void operator()(T* out, T val) const {
+    atomic_add(out, val);
+  }
+};
+
+template <typename T>
+struct ScatterProd {
+  __device__ void operator()(T* out, T val) const {
+    atomic_prod(out, val);
+  }
+};
+
+template <typename T>
+struct ScatterMax {
+  __device__ void operator()(T* out, T val) const {
+    atomic_max(out, val);
+  }
+};
+
+template <typename T>
+struct ScatterMin {
+  __device__ void operator()(T* out, T val) const {
+    atomic_min(out, val);
+  }
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu
index 4e8c171d5..472841229 100644
--- a/mlx/backend/cuda/primitives.cu
+++ b/mlx/backend/cuda/primitives.cu
@@ -116,8 +116,6 @@ NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)
-NO_GPU(Gather)
-NO_GPU(GatherAxis)
 NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
@@ -128,8 +126,6 @@ NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(Scan)
-NO_GPU(Scatter)
-NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)
diff --git a/mlx/backend/cuda/scatter.cu b/mlx/backend/cuda/scatter.cu
new file mode 100644
index 000000000..634e76805
--- /dev/null
+++ b/mlx/backend/cuda/scatter.cu
@@ -0,0 +1,104 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/indexing.cuh"
+#include "mlx/backend/cuda/iterators/general_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/scatter_ops.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/device_ptr.h>
+
+namespace mlx::core {
+
+// Like MLX_SWITCH_ALL_TYPES but for reductions.
+#define MLX_SWITCH_SCATTER_OP(REDUCE, REDUCE_ALIAS, ...)         \
+  if (REDUCE == Scatter::Sum) {                                  \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterSum<DataType>;    \
+    __VA_ARGS__;                                                 \
+  } else if (REDUCE == Scatter::Prod) {                          \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterProd<DataType>;   \
+    __VA_ARGS__;                                                 \
+  } else if (REDUCE == Scatter::Max) {                           \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterMax<DataType>;    \
+    __VA_ARGS__;                                                 \
+  } else if (REDUCE == Scatter::Min) {                           \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterMin<DataType>;    \
+    __VA_ARGS__;                                                 \
+  } else {                                                       \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterAssign<DataType>; \
+    __VA_ARGS__;                                                 \
+  }
+
+void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Scatter::eval_gpu");
+  auto& upd = inputs.back();
+
+  // Copy src into out.
+  CopyType copy_type;
+  if (inputs[0].data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (inputs[0].flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
+  copy_gpu(inputs[0], out, copy_type);
+
+  // Empty update.
+  if (upd.size() == 0) {
+    return;
+  }
+
+  size_t nidx = axes_.size();
+  auto idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
+  auto idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
+
+  Shape upd_shape_post(upd.shape().begin() + idx_ndim, upd.shape().end());
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  for (const auto& in : inputs) {
+    encoder.set_input_array(in);
+  }
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_INDEX_TYPES_CHECKED(idx_dtype, "scatter", CTYPE_IDX, {
+      using IndexType = cuda_type_t<CTYPE_IDX>;
+      MLX_SWITCH_NIDX(nidx, NIDX, {
+        MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
+          MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
+            using DataType = cuda_type_t<CTYPE_DATA>;
+            MLX_SWITCH_SCATTER_OP(reduce_type_, SCATTER_OP, {
+              auto policy = cu::thrust_policy(stream);
+              auto upd_ptr = thrust::device_pointer_cast(upd.data<DataType>());
+              auto size = upd.size();
+              auto idx_begin = thrust::make_transform_iterator(
+                  thrust::make_counting_iterator(0),
+                  cu::IndicesOp<IndexType, NIDX, IDX_NDIM>(
+                      out,
+                      upd_shape_post,
+                      axes_,
+                      inputs.begin() + 1,
+                      inputs.begin() + 1 + nidx));
+              auto out_ptr = out.data<DataType>();
+              SCATTER_OP op;
+              if (upd.flags().row_contiguous) {
+                cu::scatter_n(policy, upd_ptr, size, idx_begin, out_ptr, op);
+              } else {
+                auto upd_begin = cu::make_general_iterator(
+                    upd_ptr, upd.shape(), upd.strides());
+                cu::scatter_n(policy, upd_begin, size, idx_begin, out_ptr, op);
+              }
+            });
+          });
+        });
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/scatter_axis.cu b/mlx/backend/cuda/scatter_axis.cu
new file mode 100644
index 000000000..cb5dad58b
--- /dev/null
+++ b/mlx/backend/cuda/scatter_axis.cu
@@ -0,0 +1,83 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/indexing.cuh"
+#include "mlx/backend/cuda/iterators/general_iterator.cuh"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+#include "mlx/backend/cuda/kernels/scatter_ops.cuh"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <thrust/device_ptr.h>
+
+namespace mlx::core {
+
+// Like MLX_SWITCH_ALL_TYPES but for reductions.
+#define MLX_SWITCH_SCATTER_AXIS_OP(REDUCE, REDUCE_ALIAS, ...)    \
+  if (REDUCE == ScatterAxis::Sum) {                              \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterSum<DataType>;    \
+    __VA_ARGS__;                                                 \
+  } else {                                                       \
+    using REDUCE_ALIAS = mlx::core::cu::ScatterAssign<DataType>; \
+    __VA_ARGS__;                                                 \
+  }
+
+void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("ScatterAxis::eval_gpu");
+  auto& src = inputs[0];
+  auto& idx = inputs[1];
+  auto& upd = inputs[2];
+
+  // Copy src into out.
+  CopyType copy_type;
+  if (src.data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (src.flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
+  copy_gpu(src, out, copy_type);
+
+  // Empty update.
+  if (upd.size() == 0) {
+    return;
+  }
+
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(idx);
+  encoder.set_input_array(upd);
+  encoder.set_output_array(out);
+  encoder.launch_kernel([&](cudaStream_t stream) {
+    MLX_SWITCH_INDEX_TYPES_CHECKED(idx.dtype(), "scatter_axis", CTYPE_IDX, {
+      using IndexType = cuda_type_t<CTYPE_IDX>;
+      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
+        using DataType = cuda_type_t<CTYPE_DATA>;
+        MLX_SWITCH_BOOL(idx.flags().row_contiguous, IDX_CONT, {
+          MLX_SWITCH_SCATTER_AXIS_OP(reduce_type_, SCATTER_OP, {
+            auto policy = cu::thrust_policy(stream);
+            auto upd_ptr = thrust::device_pointer_cast(upd.data<DataType>());
+            auto size = upd.size();
+            auto idx_begin = thrust::make_transform_iterator(
+                thrust::make_counting_iterator(0),
+                cu::IndexOp<IndexType, IDX_CONT, true>(idx, out, axis_));
+            auto out_ptr = out.data<DataType>();
+            SCATTER_OP op;
+            if (upd.flags().row_contiguous) {
+              cu::scatter_n(policy, upd_ptr, size, idx_begin, out_ptr, op);
+            } else {
+              auto upd_begin = cu::make_general_iterator(
+                  upd_ptr, upd.shape(), upd.strides());
+              cu::scatter_n(policy, upd_begin, size, idx_begin, out_ptr, op);
+            }
+          });
+        });
+      });
+    });
+  });
+}
+
+} // namespace mlx::core
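
Note (not part of the patch): every new kernel above relies on the same trick — a `thrust::counting_iterator` wrapped in a `thrust::transform_iterator` whose functor computes a source offset for each output element, so no index buffer is ever materialized before calling `thrust::gather` or `scatter_n`. The sketch below shows that pattern in isolation, reduced to gathering whole rows of a small row-major matrix; `RowIndexOp`, the toy shapes, and `main` are illustrative stand-ins, not code from this diff.

```cpp
// Minimal, standalone sketch of the transform-iterator gather pattern.
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/sequence.h>

#include <cstdio>
#include <vector>

// Maps a linear output element index to a linear source offset: output row r
// reads source row idx[r]. This plays the (much simpler) role of IndicesOp /
// IndexOp in the patch.
struct RowIndexOp {
  const int* idx; // which source row each output row reads from
  int row_size;   // number of columns per row

  __host__ __device__ long long operator()(long long i) const {
    long long row = i / row_size;
    long long col = i % row_size;
    return static_cast<long long>(idx[row]) * row_size + col;
  }
};

int main() {
  const int rows = 4, cols = 3;
  thrust::device_vector<float> src(rows * cols);
  thrust::sequence(src.begin(), src.end()); // 0, 1, ..., 11 (4x3 row-major)

  std::vector<int> rows_to_take = {2, 0};
  thrust::device_vector<int> idx(rows_to_take.begin(), rows_to_take.end());
  thrust::device_vector<float> out(rows_to_take.size() * cols);

  // Source offsets are computed on the fly, one per output element.
  auto loc_begin = thrust::make_transform_iterator(
      thrust::make_counting_iterator<long long>(0),
      RowIndexOp{thrust::raw_pointer_cast(idx.data()), cols});

  // Same call shape as Gather::eval_gpu: map iterator, source data, output.
  thrust::gather(loc_begin, loc_begin + out.size(), src.data(), out.begin());

  thrust::host_vector<float> h = out;
  for (float v : h) {
    std::printf("%g ", v); // prints: 6 7 8 0 1 2
  }
  std::printf("\n");
  return 0;
}
```

In the patch itself, `IndicesOp` and `IndexOp` play the role of `RowIndexOp` but additionally handle negative indices, multiple index arrays, non-contiguous layouts, and the compile-time dispatch macros defined in `indexing.cuh`.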