From c8f79d38ec51090be35c2e5d414b961093398f62 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 21 May 2025 02:15:09 +0000 Subject: [PATCH] CUDA backend: indexing ops --- mlx/backend/cuda/CMakeLists.txt | 5 +- mlx/backend/cuda/device/atomic_ops.cuh | 72 ++++ mlx/backend/cuda/device/gather.cuh | 53 +++ mlx/backend/cuda/device/gather_axis.cuh | 65 ++++ mlx/backend/cuda/device/indexing.cuh | 30 ++ mlx/backend/cuda/device/scatter.cuh | 68 ++++ mlx/backend/cuda/device/scatter_axis.cuh | 67 ++++ mlx/backend/cuda/device/scatter_ops.cuh | 44 +++ mlx/backend/cuda/indexing.cpp | 420 +++++++++++++++++++++++ mlx/backend/cuda/jit_module.cpp | 8 + mlx/backend/cuda/primitives.cu | 4 - 11 files changed, 830 insertions(+), 6 deletions(-) create mode 100644 mlx/backend/cuda/device/atomic_ops.cuh create mode 100644 mlx/backend/cuda/device/gather.cuh create mode 100644 mlx/backend/cuda/device/gather_axis.cuh create mode 100644 mlx/backend/cuda/device/indexing.cuh create mode 100644 mlx/backend/cuda/device/scatter.cuh create mode 100644 mlx/backend/cuda/device/scatter_axis.cuh create mode 100644 mlx/backend/cuda/device/scatter_ops.cuh create mode 100644 mlx/backend/cuda/indexing.cpp diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1567feafd..7cc74353a 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -1,8 +1,8 @@ # Filename rules in cuda backend: # # * Use .cu/.cuh if code contains device code, and .cpp/.h if not. -# * Device-only kernel code should be put in kernels/ subdir. -# * Files in kernels/ subdir should not include files outside. +# * Device-only code should be put in device/ subdir. +# * Files in device/ subdir should not include files outside. target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp @@ -20,6 +20,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/event.cu ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu diff --git a/mlx/backend/cuda/device/atomic_ops.cuh b/mlx/backend/cuda/device/atomic_ops.cuh new file mode 100644 index 000000000..b6915606e --- /dev/null +++ b/mlx/backend/cuda/device/atomic_ops.cuh @@ -0,0 +1,72 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/cucomplex_math.cuh" +#include "mlx/backend/cuda/device/fp16_math.cuh" + +#include + +namespace mlx::core::cu { + +template +inline __device__ void atomic_add(T* out, T val) { + cuda::atomic_ref ref(*out); + ref += val; +} + +template +inline __device__ void atomic_prod(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old * val)) { + } +} + +template +inline __device__ void atomic_max(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_max(val); +} + +template +inline __device__ void atomic_min(T* out, T val) { + cuda::atomic_ref ref(*out); + ref.fetch_min(val); +} + +// Somehow cuda::atomic_ref does not provide atomic add for following types. 
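Note on the fallback below: like atomic_prod above, atomic_add_general uses the standard compare-exchange retry loop from libcu++. A minimal standalone sketch of that pattern, assuming only <cuda/atomic>; the atomic_update name and the generic Op parameter are illustrative and not part of this patch:

#include <cuda/atomic>

// Retry the read-modify-write until no other thread raced in between.
// On failure, compare_exchange_strong reloads `old` with the current
// value, so each iteration recomputes the update from fresh data.
template <typename T, typename Op>
inline __device__ void atomic_update(T* out, T val, Op op) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  while (!ref.compare_exchange_strong(old, op(old, val))) {
  }
}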
+template +inline __device__ void atomic_add_general(T* out, T val) { + cuda::atomic_ref ref(*out); + T old = ref.load(); + while (!ref.compare_exchange_strong(old, old + val)) { + } +} + +inline __device__ void atomic_add(__half* out, __half val) { + atomicAdd(out, val); +} + +inline __device__ void atomic_add(cuComplex* out, cuComplex val) { +#if __CUDA_ARCH__ < 900 + atomic_add_general(out, val); +#else + atomicAdd(out, val); +#endif +} + +inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) { +#if __CUDA_ARCH__ < 800 +#if CCCL_VERSION >= 2008000 + atomic_add_general(out, val); +#else + bool cccl_version_too_old_for_bfloat16_atomic_add = false; + assert(cccl_version_too_old_for_bfloat16_atomic_add); +#endif +#else + atomicAdd(out, val); +#endif +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather.cuh b/mlx/backend/cuda/device/gather.cuh new file mode 100644 index 000000000..7dbd84ac3 --- /dev/null +++ b/mlx/backend/cuda/device/gather.cuh @@ -0,0 +1,53 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template +__global__ void gather( + const T* src, + T* out, + LocT size, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim, + const __grid_constant__ Shape slice_sizes, + uint32_t slice_size, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT out_idx = cg::this_grid().thread_rank(); + if (out_idx >= size) { + return; + } + + LocT src_elem = out_idx % slice_size; + LocT idx_elem = out_idx / slice_size; + + LocT src_loc = + elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]); + src_loc += idx_val * src_strides[axis]; + } + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/gather_axis.cuh b/mlx/backend/cuda/device/gather_axis.cuh new file mode 100644 index 000000000..f863b2d95 --- /dev/null +++ b/mlx/backend/cuda/device/gather_axis.cuh @@ -0,0 +1,65 @@ +// Copyright © 2025 Apple Inc. 
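In the gather kernel above, each output element is split into a slice element and an index element: src_elem = out_idx % slice_size locates the position inside the gathered slice, and idx_elem = out_idx / slice_size selects which entry of the index arrays to use. For example, with src of shape (4, 5), axes = {0}, slice_sizes = {1, 5} and a single index array of shape (3,), slice_size is 5, so out_idx = 12 decomposes into idx_elem = 2 and src_elem = 2, and the thread writes out[2][0][2] = src[indices[2]][2].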
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + int NDIM, + bool SrcC, + bool IdxC, + typename LocT> +__global__ void gather_axis( + const T* src, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array src_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t src_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT src_loc = idx_val * src_stride_axis; + if constexpr (SrcC) { + src_loc += elem_idx * axis_size + x; + } else { + src_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), src_strides.data()); + } + + LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x; + + out[out_idx] = src[src_loc]; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/indexing.cuh b/mlx/backend/cuda/device/indexing.cuh new file mode 100644 index 000000000..31cba1a90 --- /dev/null +++ b/mlx/backend/cuda/device/indexing.cuh @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +namespace mlx::core::cu { + +// Convert an absolute index to positions in a 3d grid, assuming the index is +// calculated with: +// index = x * dim1 * dim2 + y * dim2 + z +template +inline __host__ __device__ cuda::std::tuple +index_to_dims(T index, T dim1, T dim2) { + T x = index / (dim1 * dim2); + T y = (index % (dim1 * dim2)) / dim2; + T z = index % dim2; + return cuda::std::make_tuple(x, y, z); +} + +// Get absolute index from possible negative index. +template +inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) { + if constexpr (cuda::std::is_unsigned_v) { + return idx; + } else { + return static_cast(idx < 0 ? idx + size : idx); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh new file mode 100644 index 000000000..b2f640350 --- /dev/null +++ b/mlx/backend/cuda/device/scatter.cuh @@ -0,0 +1,68 @@ +// Copyright © 2025 Apple Inc. 
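The two helpers in indexing.cuh above are shared by all of these kernels. index_to_dims inverts index = x*dim1*dim2 + y*dim2 + z; for example, index_to_dims(23, 3, 4) yields (x, y, z) = (1, 2, 3) since 1*12 + 2*4 + 3 = 23. absolute_index maps Python-style negative indices into range, e.g. absolute_index(-1, 5) == 4, while unsigned index types are passed through unchanged.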
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + int IDX_NDIM, + typename LocT> +__global__ void scatter( + const T* upd, + T* out, + LocT size, + const __grid_constant__ Shape upd_shape, + const __grid_constant__ Strides upd_strides, + int32_t upd_ndim, + LocT upd_post_idx_size, + const __grid_constant__ Shape out_shape, + const __grid_constant__ Strides out_strides, + int32_t out_ndim, + const __grid_constant__ cuda::std::array axes, + const __grid_constant__ cuda::std::array indices, + const __grid_constant__ cuda::std::array + indices_shape, + const __grid_constant__ cuda::std::array + indices_strides) { + LocT upd_idx = cg::this_grid().thread_rank(); + if (upd_idx >= size) { + return; + } + + LocT out_elem = upd_idx % upd_post_idx_size; + LocT idx_elem = upd_idx / upd_post_idx_size; + + LocT out_idx = elem_to_loc( + out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim); + +#pragma unroll + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc = elem_to_loc_nd( + idx_elem, + indices_shape.data() + i * IDX_NDIM, + indices_strides.data() + i * IDX_NDIM); + int32_t axis = axes[i]; + LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]); + out_idx += idx_val * out_strides[axis]; + } + + LocT upd_loc = elem_to_loc( + out_elem + idx_elem * upd_post_idx_size, + upd_shape.data(), + upd_strides.data(), + upd_ndim); + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_axis.cuh b/mlx/backend/cuda/device/scatter_axis.cuh new file mode 100644 index 000000000..1f30f2ebd --- /dev/null +++ b/mlx/backend/cuda/device/scatter_axis.cuh @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
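In the scatter kernel above, several threads can compute the same out_idx whenever the index arrays contain duplicates, so every reducing Op must update the output atomically; for example, scattering updates {1, 1} at indices {0, 0} with ScatterSum must add 2 to out[0], which a plain read-modify-write could lose under concurrency. The Op functors are defined in scatter_ops.cuh further down and forward to the atomic helpers from atomic_ops.cuh; only ScatterAssign writes directly.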
+ +#include "mlx/backend/cuda/device/indexing.cuh" +#include "mlx/backend/cuda/device/scatter_ops.cuh" +#include "mlx/backend/cuda/device/utils.cuh" + +#include + +namespace mlx::core::cu { + +namespace cg = cooperative_groups; + +template < + typename T, + typename IdxT, + typename Op, + int NDIM, + bool UpdC, + bool IdxC, + typename LocT> +__global__ void scatter_axis( + const T* upd, + const IdxT* indices, + T* out, + LocT idx_size_pre, + LocT idx_size_axis, + LocT idx_size_post, + const __grid_constant__ cuda::std::array shape, + const __grid_constant__ cuda::std::array upd_strides, + const __grid_constant__ cuda::std::array idx_strides, + int32_t axis, + int32_t axis_size, + int64_t upd_stride_axis, + int64_t idx_stride_axis) { + LocT index = cg::this_grid().thread_rank(); + if (index >= idx_size_pre * idx_size_axis * idx_size_post) { + return; + } + + auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre); + + LocT elem_idx = z * idx_size_post; + + LocT idx_loc = y * idx_stride_axis; + if constexpr (IdxC) { + idx_loc += elem_idx * idx_size_axis + x; + } else { + idx_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), idx_strides.data()); + } + + auto idx_val = absolute_index(indices[idx_loc], axis_size); + + LocT upd_loc = y * upd_stride_axis; + if constexpr (UpdC) { + upd_loc += elem_idx * idx_size_axis + x; + } else { + upd_loc += + elem_to_loc_nd(elem_idx + x, shape.data(), upd_strides.data()); + } + + LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x; + + Op{}(out + out_idx, upd[upd_loc]); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/scatter_ops.cuh b/mlx/backend/cuda/device/scatter_ops.cuh new file mode 100644 index 000000000..d88f896ad --- /dev/null +++ b/mlx/backend/cuda/device/scatter_ops.cuh @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/atomic_ops.cuh" + +namespace mlx::core::cu { + +struct ScatterAssign { + template + __device__ void operator()(T* out, T val) const { + *out = val; + } +}; + +struct ScatterSum { + template + __device__ void operator()(T* out, T val) const { + atomic_add(out, val); + } +}; + +struct ScatterProd { + template + __device__ void operator()(T* out, T val) const { + atomic_prod(out, val); + } +}; + +struct ScatterMax { + template + __device__ void operator()(T* out, T val) const { + atomic_max(out, val); + } +}; + +struct ScatterMin { + template + __device__ void operator()(T* out, T val) const { + atomic_min(out, val); + } +}; + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp new file mode 100644 index 000000000..3603605c4 --- /dev/null +++ b/mlx/backend/cuda/indexing.cpp @@ -0,0 +1,420 @@ +// Copyright © 2025 Apple Inc. 
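The Op functors above are baked into the kernels as template arguments, so the host code in indexing.cpp below selects the reduce operation at JIT-compilation time (through the instantiated kernel name, e.g. mlx::core::cu::ScatterSum) rather than branching per element on the device. A simplified sketch of that pattern, with hypothetical names, not taken from the patch:

// Op is a compile-time parameter: the generated kernel contains no
// runtime switch on the reduce type.
template <typename T, typename Op>
__global__ void apply_op(T* out, const T* upd, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    Op{}(out + i, upd[i]); // e.g. Op = ScatterSum -> atomic_add
  }
}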
+ +#include "mlx/backend/common/compiled.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include "cuda_jit_sources.h" + +#include +#include + +#include +#include + +namespace mlx::core { + +namespace { + +constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"}; + +void append_indices_arg( + cu::JitModule& mod, + const std::vector& inputs, + int nidx, + int idx_ndim) { + std::vector indices(nidx); + for (int i = 0; i < nidx; ++i) { + indices[i] = inputs[i + 1].data(); + } + mod.append_arg(std::move(indices)); + std::vector indices_shape(nidx * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy_n( + inputs[i + 1].shape().begin(), + idx_ndim, + indices_shape.data() + i * idx_ndim); + } + mod.append_arg(std::move(indices_shape)); + std::vector indices_strides(nidx * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy_n( + inputs[i + 1].strides().begin(), + idx_ndim, + indices_strides.data() + i * idx_ndim); + } + mod.append_arg(std::move(indices_strides)); +} + +} // namespace + +void Gather::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Gather::eval_gpu"); + assert(inputs.size() > 0); + const auto& src = inputs[0]; + + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + int nidx = inputs.size() - 1; + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || + (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + + uint32_t slice_size = std::accumulate( + slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies()); + + std::string module_name = fmt::format( + "gather_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx_dtype), + nidx); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::gather<{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + nidx, + ndim, + large ? "int64_t" : "uint32_t")); + } + } + return std::make_pair(jit_source_gather, std::move(kernel_names)); + }); + + mod.append_arg(src); + mod.append_arg(out); + if (large) { + mod.append_arg(out.size()); + } else { + mod.append_arg(out.size()); + } + mod.append_ndim_arg(src.shape()); + mod.append_ndim_arg(src.strides()); + mod.append_arg(src.ndim()); + mod.append_ndim_arg(slice_sizes_); + mod.append_arg(slice_size); + mod.append_arg(axes_); + append_indices_arg(mod, inputs, nidx, idx_ndim); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather<{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + nidx, + idx_ndim, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, out, large); + }); +} + +void Scatter::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Gather::eval_gpu"); + assert(inputs.size() > 1); + auto& upd = inputs.back(); + + // Copy src into out. 
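(Scatter builds its result on top of a copy of the source: the code below first materializes inputs[0] into out, picking the cheapest copy for its layout, Scalar for a single broadcast element, Vector for row-contiguous data, General otherwise, and the JIT-compiled kernel then applies the updates in place. ScatterAxis::eval_gpu later in this file follows the same copy-then-update pattern.)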
+ CopyType copy_type; + if (inputs[0].data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (inputs[0].flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(inputs[0], out, copy_type); + + // Empty update. + if (upd.size() == 0) { + return; + } + + int nidx = axes_.size(); + Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32; + int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0; + + bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) || + (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX); + + uint32_t upd_post_idx_size = std::accumulate( + upd.shape().begin() + idx_ndim, + upd.shape().end(), + 1, + std::multiplies()); + + const char* op = g_scatter_ops[reduce_type_]; + std::string module_name = fmt::format( + "scatter_{}_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx_dtype), + op, + nidx); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + op, + nidx, + ndim, + large ? "int64_t" : "uint32_t")); + } + } + return std::make_pair(jit_source_scatter, std::move(kernel_names)); + }); + + mod.append_arg(upd); + mod.append_arg(out); + if (large) { + mod.append_arg(upd.size()); + } else { + mod.append_arg(upd.size()); + } + mod.append_ndim_arg(upd.shape()); + mod.append_ndim_arg(upd.strides()); + mod.append_arg(upd.ndim()); + if (large) { + mod.append_arg(upd_post_idx_size); + } else { + mod.append_arg(upd_post_idx_size); + } + mod.append_ndim_arg(out.shape()); + mod.append_ndim_arg(out.strides()); + mod.append_arg(out.ndim()); + mod.append_arg(axes_); + append_indices_arg(mod, inputs, nidx, idx_ndim); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx_dtype), + op, + nidx, + idx_ndim, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, upd, large); + }); +} + +void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("GatherAxis::eval_gpu"); + assert(inputs.size() > 1); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + + out.set_data(allocator::malloc(out.nbytes())); + if (out.size() == 0) { + return; + } + + bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + + std::string module_name = fmt::format( + "gather_axis_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype())); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? 
"int64_t" : "uint32_t")); + } + } + } + return std::make_pair(jit_source_gather_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + mod.append_arg(src); + mod.append_arg(idx); + mod.append_arg(out); + if (large) { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } else { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } + mod.append_arg(remove_index(idx.shape(), axis_)); + mod.append_arg(remove_index(src.strides(), axis_)); + mod.append_arg(remove_index(idx.strides(), axis_)); + mod.append_arg(axis_); + mod.append_arg(src.shape(axis_)); + mod.append_arg(src.strides(axis_)); + mod.append_arg(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + src.ndim() - 1, + src.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, idx, large); + }); +} + +void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("ScatterAxis::eval_gpu"); + assert(inputs.size() > 2); + const auto& src = inputs[0]; + const auto& idx = inputs[1]; + const auto& upd = inputs[2]; + + // Copy src into out. + CopyType copy_type; + if (src.data_size() == 1) { + copy_type = CopyType::Scalar; + } else if (src.flags().row_contiguous) { + copy_type = CopyType::Vector; + } else { + copy_type = CopyType::General; + } + copy_gpu(src, out, copy_type); + + // Empty update. + if (upd.size() == 0) { + return; + } + + bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX; + + const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign"; + std::string module_name = fmt::format( + "scatter_axis_{}_{}_{}", + dtype_to_string(out.dtype()), + dtype_to_string(idx.dtype()), + op); + + auto& s = stream(); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { + for (int contiguous = 0; contiguous < 4; ++contiguous) { + for (int large = 0; large <= 1; ++large) { + kernel_names.push_back(fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + ndim, + contiguous & 1 ? true : false, + contiguous & 2 ? true : false, + large ? 
"int64_t" : "uint32_t")); + } + } + } + return std::make_pair(jit_source_scatter_axis, std::move(kernel_names)); + }); + + size_t idx_size_pre = 1; + size_t idx_size_post = 1; + for (int i = 0; i < axis_; ++i) { + idx_size_pre *= idx.shape(i); + } + for (int i = axis_ + 1; i < idx.ndim(); ++i) { + idx_size_post *= idx.shape(i); + } + size_t idx_size_axis = idx.shape(axis_); + + mod.append_arg(upd); + mod.append_arg(idx); + mod.append_arg(out); + if (large) { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } else { + mod.append_arg(idx_size_pre); + mod.append_arg(idx_size_axis); + mod.append_arg(idx_size_post); + } + mod.append_arg(remove_index(idx.shape(), axis_)); + mod.append_arg(remove_index(upd.strides(), axis_)); + mod.append_arg(remove_index(idx.strides(), axis_)); + mod.append_arg(axis_); + mod.append_arg(out.shape(axis_)); + mod.append_arg(upd.strides(axis_)); + mod.append_arg(idx.strides(axis_)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + dtype_to_cuda_type(idx.dtype()), + op, + idx.ndim() - 1, + upd.flags().row_contiguous, + idx.flags().row_contiguous, + large ? "int64_t" : "uint32_t"); + + auto& encoder = cu::get_command_encoder(s); + for (const auto& in : inputs) { + encoder.set_input_array(in); + } + encoder.set_output_array(out); + encoder.launch_kernel([&](cudaStream_t stream) { + mod.launch_kernel(stream, kernel_name, idx, large); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 3c00dd7f0..b8be103cc 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -148,24 +148,32 @@ bool compiler_supports_device_sass(Device& device) { #define INCLUDE_PREFIX "mlx/backend/cuda/kernels/" constexpr const char* g_include_names[] = { + INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "cucomplex_math.cuh", INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "indexing.cuh", + INCLUDE_PREFIX "scatter_ops.cuh", INCLUDE_PREFIX "unary_ops.cuh", + INCLUDE_PREFIX "ternary_ops.cuh", INCLUDE_PREFIX "utils.cuh", }; #undef INCLUDE_PREFIX constexpr const char* g_headers[] = { + jit_source_atomic_ops, jit_source_binary_ops, jit_source_cast_op, jit_source_config, jit_source_cucomplex_math, jit_source_fp16_math, + jit_source_indexing, + jit_source_scatter_ops, jit_source_unary_ops, + jit_source_ternary_ops, jit_source_utils, }; diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index eb451f49d..0c4d3a8aa 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -78,8 +78,6 @@ NO_GPU_MULTI(DivMod) NO_GPU(DynamicSlice) NO_GPU(DynamicSliceUpdate) NO_GPU(FFT) -NO_GPU(Gather) -NO_GPU(GatherAxis) NO_GPU(GatherMM) NO_GPU(GatherQMM) NO_GPU(Hadamard) @@ -89,8 +87,6 @@ NO_GPU(Partition) NO_GPU_MULTI(QRF) NO_GPU(QuantizedMatmul) NO_GPU(Scan) -NO_GPU(Scatter) -NO_GPU(ScatterAxis) NO_GPU_MULTI(SVD) NO_GPU(Inverse) NO_GPU(Cholesky)