CUDA backend: indexing ops

Author: Cheng
Date: 2025-05-21 02:15:09 +00:00
parent 171679a176
commit e346a063ab
10 changed files with 699 additions and 7 deletions


@@ -13,12 +13,16 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cu
${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
${CMAKE_CURRENT_SOURCE_DIR}/gather.cu
${CMAKE_CURRENT_SOURCE_DIR}/gather_axis.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/scatter.cu
${CMAKE_CURRENT_SOURCE_DIR}/scatter_axis.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
@@ -31,11 +35,22 @@ target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
target_compile_options(mlx
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
# Ignore uncommon types in ops to speed up compilation.
if(MLX_FAST_COMPILE)
target_compile_definitions(mlx PUBLIC MLX_FAST_COMPILE)
endif()
# Compute capability 7 is required for synchronization between CPU/GPU with
# managed memory. TODO: Add more architectures for potential performance gain.
if(MLX_FAST_COMPILE)
set(MLX_CUDA_ARCHITECTURES
"70"
CACHE STRING "CUDA architectures")
else()
set(MLX_CUDA_ARCHITECTURES
"70;80"
CACHE STRING "CUDA architectures")
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")


@@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/gather.h>
namespace mlx::core {
void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Gather::eval_gpu");
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
const auto& src = inputs[0];
size_t nidx = inputs.size() - 1;
auto idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
auto idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_INDEX_TYPES_CHECKED(idx_dtype, "gather", CTYPE_IDX, {
using IndexType = cuda_type_t<CTYPE_IDX>;
MLX_SWITCH_NIDX(inputs.size() - 1, NIDX, {
MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
using DataType = cuda_type_t<CTYPE_DATA>;
auto idx_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
cu::IndicesOp<IndexType, NIDX, IDX_NDIM>(
src,
slice_sizes_,
axes_,
inputs.begin() + 1,
inputs.end()));
thrust::gather(
cu::thrust_policy(stream),
idx_begin,
idx_begin + out.size(),
src.data<DataType>(),
out.data<DataType>());
});
});
});
});
});
}
} // namespace mlx::core
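
The whole gather above is driven by Thrust: a transform_iterator lazily maps each flat output index to a source location via cu::IndicesOp, and thrust::gather performs the copy. The standalone sketch below shows the same transform_iterator + thrust::gather pattern outside MLX; the functor and values are illustrative only.

// Standalone sketch of the transform_iterator + thrust::gather pattern used
// above (illustrative, not MLX code). Compile with nvcc.
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <cstdio>

// Maps a flat output index to a source index; here: take every other element.
struct StrideMap {
  __host__ __device__ int operator()(int i) const {
    return 2 * i;
  }
};

int main() {
  thrust::device_vector<float> src(8);
  for (int i = 0; i < 8; ++i) src[i] = static_cast<float>(i);
  thrust::device_vector<float> out(4);
  auto map_begin = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0), StrideMap{});
  // out[i] = src[map_begin[i]], i.e. {0, 2, 4, 6}.
  thrust::gather(map_begin, map_begin + out.size(), src.begin(), out.begin());
  for (int i = 0; i < 4; ++i) std::printf("%g ", static_cast<float>(out[i]));
  std::printf("\n");
  return 0;
}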


@@ -0,0 +1,53 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/gather.h>
namespace mlx::core {
void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("GatherAxis::eval_gpu");
out.set_data(allocator::malloc(out.nbytes()));
if (out.size() == 0) {
return;
}
auto& src = inputs[0];
auto& idx = inputs[1];
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(src);
encoder.set_input_array(idx);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_INDEX_TYPES_CHECKED(idx.dtype(), "gather_axis", CTYPE_IDX, {
using IndexType = cuda_type_t<CTYPE_IDX>;
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
using DataType = cuda_type_t<CTYPE_DATA>;
MLX_SWITCH_BOOL(src.flags().row_contiguous, SRC_CONT, {
MLX_SWITCH_BOOL(idx.flags().row_contiguous, IDX_CONT, {
auto idx_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
cu::IndexOp<IndexType, IDX_CONT, SRC_CONT>(idx, src, axis_));
thrust::gather(
cu::thrust_policy(stream),
idx_begin,
idx_begin + idx.size(),
src.data<DataType>(),
out.data<DataType>());
});
});
});
});
});
}
} // namespace mlx::core
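
GatherAxis implements take-along-axis semantics: the output has the shape of |idx|, and each element reads |src| at the same coordinates except along |axis|, where the coordinate is replaced by the (possibly negative) index value stored in |idx|. The host-side reference below spells out that index math for the fully contiguous case, assuming a (pre, axis, post) flattening of the shapes; it is illustrative only, not MLX code.

// Host-side reference for take-along-axis over the middle dimension of a
// contiguous (pre, axis, post) layout. Illustrative only; not MLX code.
#include <cstdint>
#include <vector>

std::vector<float> gather_axis_ref(
    const std::vector<float>& src,   // shape: pre x src_axis x post
    const std::vector<int32_t>& idx, // shape: pre x idx_axis x post
    int64_t pre, int64_t src_axis, int64_t idx_axis, int64_t post) {
  std::vector<float> out(pre * idx_axis * post);
  for (int64_t x = 0; x < pre; ++x) {
    for (int64_t y = 0; y < idx_axis; ++y) {
      for (int64_t z = 0; z < post; ++z) {
        int64_t i = idx[(x * idx_axis + y) * post + z];
        if (i < 0) i += src_axis;  // absolute_index(): wrap negative indices
        out[(x * idx_axis + y) * post + z] =
            src[(x * src_axis + i) * post + z];
      }
    }
  }
  return out;
}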


@@ -0,0 +1,260 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include <thrust/for_each.h>
#include <cuda/std/utility>
namespace mlx::core::cu {
// Only allow int32 and uint32 as index types when MLX_FAST_COMPILE is defined.
#if defined(MLX_FAST_COMPILE)
#define MLX_SWITCH_INDEX_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
if (TYPE == ::mlx::core::int32) { \
using CTYPE_ALIAS = int32_t; \
__VA_ARGS__; \
} else if (TYPE == ::mlx::core::uint32) { \
using CTYPE_ALIAS = uint32_t; \
__VA_ARGS__; \
} else { \
throw std::invalid_argument(fmt::format( \
"Can not use dtype {} as index for {} when MLX_FAST_COMPILE is on.", \
dtype_to_string(TYPE), \
NAME)); \
}
#else
#define MLX_SWITCH_INDEX_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, __VA_ARGS__)
#endif
// Dispatch dynamic nidx to constexpr.
#if defined(MLX_FAST_COMPILE)
#define MLX_SWITCH_NIDX(nidx, NIDX, ...) \
{ \
assert(nidx <= 16); \
constexpr uint32_t NIDX = 16; \
__VA_ARGS__; \
}
#else
#define MLX_SWITCH_NIDX(nidx, NIDX, ...) \
if (nidx <= 2) { \
constexpr uint32_t NIDX = 2; \
__VA_ARGS__; \
} else if (nidx <= 16) { \
constexpr uint32_t NIDX = 16; \
__VA_ARGS__; \
} else { \
throw std::runtime_error(fmt::format( \
"Indexing with more than 16 index arrays is not supported (got {}).", \
nidx)); \
}
#endif
// Dispatch dynamic idx_ndim to constexpr.
#define MORE_THAN_ONE MAX_NDIM
#define MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, ...) \
if (idx_ndim == 0) { \
constexpr uint32_t IDX_NDIM = 0; \
__VA_ARGS__; \
} else if (idx_ndim == 1) { \
constexpr uint32_t IDX_NDIM = 1; \
__VA_ARGS__; \
} else { \
constexpr uint32_t IDX_NDIM = MORE_THAN_ONE; \
__VA_ARGS__; \
}
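// Illustrative usage (see gather.cu and scatter.cu): the macros above turn
// runtime values into constexpr ones so they can be used as template
// arguments, at the cost of instantiating one branch per case, e.g.
//   MLX_SWITCH_NIDX(nidx, NIDX, {
//     MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
//       // NIDX and IDX_NDIM are constexpr here.
//       auto op = IndicesOp<IndexType, NIDX, IDX_NDIM>(...);
//     });
//   });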
// Like thrust::scatter, but accepts a custom op for writing to the output.
template <typename Policy, typename It, typename Idx, typename Out, typename Op>
void scatter_n(Policy& policy, It begin, size_t size, Idx idx, Out out, Op op) {
thrust::for_each_n(
policy,
thrust::make_zip_iterator(begin, idx),
size,
[out, op] __device__(auto item) {
op(&out[thrust::get<1>(item)], thrust::get<0>(item));
});
}
// Convert a flat index to its position (x, y, z) in a 3-D grid, assuming the
// index was computed as:
// index = x * dim1 * dim2 + y * dim2 + z
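// For example, with dim1 = 3 and dim2 = 4, index 17 maps to (x, y, z) = (1, 1, 1).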
template <typename T>
inline __device__ cuda::std::tuple<T, T, T>
index_to_dims(T index, size_t dim1, size_t dim2) {
T x = index / (dim1 * dim2);
T y = (index % (dim1 * dim2)) / dim2;
T z = index % dim2;
return cuda::std::make_tuple(x, y, z);
}
// Get the absolute index from a possibly negative index.
template <typename IdxT>
inline __device__ auto absolute_index(IdxT idx, int32_t size) {
if constexpr (cuda::std::is_unsigned_v<IdxT>) {
return idx;
} else {
return static_cast<int32_t>(idx < 0 ? idx + size : idx);
}
}
// An op that maps a flat element index over |idx| to the corresponding
// location in |src|, using the index values stored in |idx| along |axis|.
template <typename IdxT, bool IdxC, bool SrcC = true, typename LocT = int64_t>
struct IndexOp {
const IdxT* idx;
int32_t ndim;
Shape shape;
Strides src_strides;
Strides idx_strides;
int32_t src_axis_size;
int32_t idx_axis_size;
int64_t src_axis_stride;
int64_t idx_axis_stride;
size_t size_post;
IndexOp(const array& idx, const array& src, int32_t axis)
: idx(idx.data<IdxT>()),
ndim(static_cast<int32_t>(src.ndim()) - 1),
shape(const_param(remove_index(idx.shape(), axis))),
src_strides(const_param(remove_index(src.strides(), axis))),
idx_strides(const_param(remove_index(idx.strides(), axis))),
src_axis_size(src.shape(axis)),
idx_axis_size(idx.shape(axis)),
src_axis_stride(src.strides(axis)),
idx_axis_stride(idx.strides(axis)) {
size_post = 1;
for (int i = axis + 1; i < idx.ndim(); ++i) {
size_post *= idx.shape(i);
}
}
__device__ LocT operator()(size_t index) {
auto [x, y, z] = index_to_dims(index, idx_axis_size, size_post);
LocT elem_idx = x * size_post;
LocT idx_loc = y * idx_axis_stride;
if constexpr (IdxC) {
idx_loc += elem_idx * idx_axis_size + z;
} else {
idx_loc +=
elem_to_loc(elem_idx + z, shape.data(), idx_strides.data(), ndim);
}
auto idx_val = absolute_index(idx[idx_loc], src_axis_size);
LocT src_idx = idx_val * src_axis_stride;
if constexpr (SrcC) {
src_idx += elem_idx * src_axis_size + z;
} else {
src_idx +=
elem_to_loc(elem_idx + z, shape.data(), src_strides.data(), ndim);
}
return src_idx;
}
};
// A device-side view over the concatenated |idx| arrays passed to gather/scatter.
template <typename T, size_t NIDX, size_t IDX_NDIM, typename LocT = int64_t>
struct Indices {
size_t size;
size_t ndim;
cuda::std::array<const T*, NIDX> buffers;
cuda::std::array<bool, NIDX> row_contiguous;
cuda::std::array<int32_t, NIDX * IDX_NDIM> shapes;
cuda::std::array<int64_t, NIDX * IDX_NDIM> strides;
template <typename Iter>
Indices(Iter begin, Iter end) {
size = end - begin;
ndim = size > 0 ? begin->ndim() : 0;
for (size_t i = 0; i < size; ++i) {
const array& arr = *(begin + i);
buffers[i] = arr.data<T>();
row_contiguous[i] = arr.flags().row_contiguous;
std::copy_n(arr.shape().begin(), ndim, shapes.begin() + i * ndim);
std::copy_n(arr.strides().begin(), ndim, strides.begin() + i * ndim);
}
}
__device__ auto operator()(size_t i, size_t x, size_t y) {
LocT idx_loc;
if constexpr (IDX_NDIM == 0) {
idx_loc = 0;
} else {
idx_loc = x * strides[ndim * i];
if constexpr (IDX_NDIM == MORE_THAN_ONE) {
if (row_contiguous[i]) {
idx_loc += y;
} else {
size_t offset = ndim * i + 1;
idx_loc += elem_to_loc(
y, shapes.data() + offset, strides.data() + offset, ndim - 1);
}
}
}
return buffers[i][idx_loc];
}
};
// An op that maps a flat element index of the gather output (or scatter
// update) to the corresponding location in |src|, using |indices| at |axes|.
template <typename IdxT, size_t NIDX, size_t IDX_NDIM, typename LocT = int64_t>
struct IndicesOp {
size_t ndim;
Shape shape;
Strides strides;
Shape slice_sizes;
Shape axes;
Indices<IdxT, NIDX, IDX_NDIM, LocT> indices;
size_t n_dim0;
size_t slice_size;
template <typename Iter>
IndicesOp(
const array& src,
const std::vector<int32_t>& slice_sizes,
const std::vector<int32_t>& axes,
Iter idx_begin,
Iter idx_end)
: ndim(src.ndim()),
shape(const_param(src.shape())),
strides(const_param(src.strides())),
slice_sizes(const_param(slice_sizes)),
axes(const_param(axes)),
indices(idx_begin, idx_end) {
n_dim0 = 1;
size_t dim0 = 1;
if (indices.ndim >= 1) {
dim0 = idx_begin->shape(0);
}
if (indices.ndim >= 2) {
n_dim0 = idx_begin->size() / dim0;
}
slice_size = 1;
for (size_t s : slice_sizes) {
slice_size *= s;
}
}
__device__ LocT operator()(size_t index) {
auto [x, y, z] = index_to_dims(index, n_dim0, slice_size);
LocT src_idx = 0;
for (size_t i = 0; i < indices.size; ++i) {
auto ax = axes[i];
auto idx_val = absolute_index(indices(i, x, y), shape[ax]);
src_idx += static_cast<LocT>(idx_val) * strides[ax];
}
LocT src_offset = elem_to_loc(z, slice_sizes.data(), strides.data(), ndim);
return src_idx + src_offset;
}
};
} // namespace mlx::core::cu
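
For reference, the source location that cu::IndicesOp computes for one gather output element (x, y, z) can be written out on the host. The sketch below assumes every index array is row-contiguous and 2-D with shape (dim0, n_dim0), matching the IDX_NDIM == MORE_THAN_ONE, row-contiguous branch above; it is illustrative only, not MLX code.

#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side reference for the source location of gather output element
// (x, y, z): sum the selected coordinates along |axes|, then add the offset of
// element z within the gathered slice (row-major over slice_sizes, mapped
// through the src strides). Illustrative only; not MLX code.
int64_t gather_src_location(
    size_t x, size_t y, size_t z,
    const std::vector<const int32_t*>& indices, // one pointer per index array
    const std::vector<int32_t>& axes,           // axis each index array selects
    const std::vector<int32_t>& shape,          // src shape
    const std::vector<int64_t>& strides,        // src strides
    const std::vector<int32_t>& slice_sizes,
    size_t n_dim0) {
  int64_t loc = 0;
  for (size_t i = 0; i < indices.size(); ++i) {
    int64_t v = indices[i][x * n_dim0 + y];     // row-contiguous 2-D indices
    if (v < 0) v += shape[axes[i]];             // absolute_index()
    loc += v * strides[axes[i]];
  }
  int64_t rem = static_cast<int64_t>(z);        // elem_to_loc() over the slice
  for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
    loc += (rem % slice_sizes[d]) * strides[d];
    rem /= slice_sizes[d];
  }
  return loc;
}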


@@ -15,6 +15,16 @@
namespace mlx::core {
// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
if (BOOL) { \
constexpr bool BOOL_ALIAS = true; \
__VA_ARGS__; \
} else { \
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
}
// Helper macros for dispatch macros (see below).
#define MLX_INTERNAL_IF_CASE(DIM, BLOCK_DIM, ...) \
} \


@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"
#include <cuda/atomic>
namespace mlx::core::cu {
template <typename T>
inline __device__ void atomic_add(T* out, T val) {
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
ref += val;
}
template <typename T>
inline __device__ void atomic_prod(T* out, T val) {
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
T old = ref.load();
while (!ref.compare_exchange_strong(old, old * val)) {
}
}
template <typename T>
inline __device__ void atomic_max(T* out, T val) {
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
ref.fetch_max(val);
}
template <typename T>
inline __device__ void atomic_min(T* out, T val) {
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
ref.fetch_min(val);
}
// cuda::atomic_ref does not provide atomic add for the following types.
template <typename T>
inline __device__ void atomic_add_general(T* out, T val) {
cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
T old = ref.load();
while (!ref.compare_exchange_strong(old, old + val)) {
}
}
inline __device__ void atomic_add(__half* out, __half val) {
atomicAdd(out, val);
}
inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
#if __CUDA_ARCH__ < 900
atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
}
inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800
atomic_add_general(out, val);
#else
atomicAdd(out, val);
#endif
}
} // namespace mlx::core::cu
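
atomic_prod and atomic_add_general above use the standard compare-and-swap loop: load the current value, compute the update, and retry until no other thread has written in between (compare_exchange_strong reloads the observed value on failure). Below is a standalone sketch of that pattern with the operation factored out; the names are illustrative and it is not part of this header.

// Standalone demonstration of the CAS-loop pattern used by atomic_prod and
// atomic_add_general above. Illustrative only; not MLX code.
#include <cuda/atomic>
#include <cstdio>

// Generic read-modify-write via a CAS loop.
template <typename T, typename Op>
__device__ void atomic_update(T* out, T val, Op op) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  // On failure, compare_exchange_strong reloads |old| with the latest value,
  // so the next iteration recomputes op(old, val) against fresh contents.
  while (!ref.compare_exchange_strong(old, op(old, val))) {
  }
}

__global__ void product_kernel(float* out) {
  // Lambdas defined in device code are implicitly device lambdas.
  atomic_update(out, 2.0f, [](float a, float b) { return a * b; });
}

int main() {
  float* out = nullptr;
  cudaMallocManaged(&out, sizeof(float));
  *out = 1.0f;
  product_kernel<<<1, 8>>>(out); // 8 threads each multiply by 2 -> 256
  cudaDeviceSynchronize();
  std::printf("%g\n", *out);
  cudaFree(out);
  return 0;
}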


@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/backend/cuda/kernels/atomic_ops.cuh"
namespace mlx::core::cu {
template <typename T>
struct ScatterAssign {
__device__ void operator()(T* out, T val) const {
*out = val;
}
};
template <typename T>
struct ScatterSum {
__device__ void operator()(T* out, T val) const {
atomic_add(out, val);
}
};
template <typename T>
struct ScatterProd {
__device__ void operator()(T* out, T val) const {
atomic_prod(out, val);
}
};
template <typename T>
struct ScatterMax {
__device__ void operator()(T* out, T val) const {
atomic_max(out, val);
}
};
template <typename T>
struct ScatterMin {
__device__ void operator()(T* out, T val) const {
atomic_min(out, val);
}
};
} // namespace mlx::core::cu


@@ -116,8 +116,6 @@ NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
@@ -128,8 +126,6 @@ NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU(Select)
NO_GPU(SliceUpdate)
NO_GPU(Softmax)

mlx/backend/cuda/scatter.cu (new file, 104 lines)

@@ -0,0 +1,104 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/scatter_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
namespace mlx::core {
// Like MLX_SWITCH_ALL_TYPES but for reductions.
#define MLX_SWITCH_SCATTER_OP(REDUCE, REDUCE_ALIAS, ...) \
if (REDUCE == Scatter::Sum) { \
using REDUCE_ALIAS = mlx::core::cu::ScatterSum<DataType>; \
__VA_ARGS__; \
} else if (REDUCE == Scatter::Prod) { \
using REDUCE_ALIAS = mlx::core::cu::ScatterProd<DataType>; \
__VA_ARGS__; \
} else if (REDUCE == Scatter::Max) { \
using REDUCE_ALIAS = mlx::core::cu::ScatterMax<DataType>; \
__VA_ARGS__; \
} else if (REDUCE == Scatter::Min) { \
using REDUCE_ALIAS = mlx::core::cu::ScatterMin<DataType>; \
__VA_ARGS__; \
} else { \
using REDUCE_ALIAS = mlx::core::cu::ScatterAssign<DataType>; \
__VA_ARGS__; \
}
void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Gather::eval_gpu");
auto& upd = inputs.back();
// Copy src into out.
CopyType copy_type;
if (inputs[0].data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (inputs[0].flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(inputs[0], out, copy_type);
// Empty update.
if (upd.size() == 0) {
return;
}
size_t nidx = axes_.size();
auto idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
auto idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;
Shape upd_shape_post(upd.shape().begin() + idx_ndim, upd.shape().end());
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
for (const auto& in : inputs) {
encoder.set_input_array(in);
}
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_INDEX_TYPES_CHECKED(idx_dtype, "scatter", CTYPE_IDX, {
using IndexType = cuda_type_t<CTYPE_IDX>;
MLX_SWITCH_NIDX(nidx, NIDX, {
MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
using DataType = cuda_type_t<CTYPE_DATA>;
MLX_SWITCH_SCATTER_OP(reduce_type_, SCATTER_OP, {
auto policy = cu::thrust_policy(stream);
auto upd_ptr = thrust::device_pointer_cast(upd.data<DataType>());
auto size = upd.size();
auto idx_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
cu::IndicesOp<IndexType, NIDX, IDX_NDIM>(
out,
upd_shape_post,
axes_,
inputs.begin() + 1,
inputs.begin() + 1 + nidx));
auto out_ptr = out.data<DataType>();
SCATTER_OP op;
if (upd.flags().row_contiguous) {
cu::scatter_n(policy, upd_ptr, size, idx_begin, out_ptr, op);
} else {
auto upd_begin = cu::make_general_iterator<int64_t>(
upd_ptr, upd.shape(), upd.strides());
cu::scatter_n(policy, upd_begin, size, idx_begin, out_ptr, op);
}
});
});
});
});
});
});
}
} // namespace mlx::core
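
scatter_n() is the entire write path: thrust::for_each_n walks a zip of (update value, destination index) pairs and applies the scatter op at the destination, so Sum/Prod/Max/Min reduce atomically while assignment simply stores. The standalone sketch below does a scatter-add with the same shape of code; it is illustrative only, not MLX code, and like the code above it needs nvcc's --extended-lambda flag (which the CMake changes above already enable for mlx).

// Standalone illustration of the scatter_n() pattern: thrust::for_each_n over
// a zip of (value, destination index), applying a reduce op at the
// destination. Illustrative only; not MLX code.
#include <thrust/device_vector.h>
#include <thrust/for_each.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>
#include <cuda/atomic>
#include <cstdio>
#include <vector>

int main() {
  thrust::device_vector<float> out(4, 0.0f);
  thrust::device_vector<float> upd(std::vector<float>{1, 2, 3, 4, 5, 6});
  thrust::device_vector<int> idx(std::vector<int>{0, 1, 1, 2, 2, 2});
  float* out_ptr = thrust::raw_pointer_cast(out.data());
  thrust::for_each_n(
      thrust::make_zip_iterator(upd.begin(), idx.begin()),
      upd.size(),
      [out_ptr] __device__(auto item) {
        float val = thrust::get<0>(item);
        int dst = thrust::get<1>(item);
        // The ScatterSum equivalent: atomically accumulate at the destination.
        cuda::atomic_ref<float, cuda::thread_scope_device> ref(out_ptr[dst]);
        ref += val;
      });
  for (int i = 0; i < 4; ++i) std::printf("%g ", static_cast<float>(out[i]));
  std::printf("\n"); // prints: 1 5 15 0
  return 0;
}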


@@ -0,0 +1,83 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/scatter_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>
namespace mlx::core {
// Like MLX_SWITCH_ALL_TYPES but for reductions.
#define MLX_SWITCH_SCATTER_AXIS_OP(REDUCE, REDUCE_ALIAS, ...) \
if (REDUCE == ScatterAxis::Sum) { \
using REDUCE_ALIAS = mlx::core::cu::ScatterSum<DataType>; \
__VA_ARGS__; \
} else { \
using REDUCE_ALIAS = mlx::core::cu::ScatterAssign<DataType>; \
__VA_ARGS__; \
}
void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("ScatterAxis::eval_gpu");
auto& src = inputs[0];
auto& idx = inputs[1];
auto& upd = inputs[2];
// Copy src into out.
CopyType copy_type;
if (src.data_size() == 1) {
copy_type = CopyType::Scalar;
} else if (src.flags().row_contiguous) {
copy_type = CopyType::Vector;
} else {
copy_type = CopyType::General;
}
copy_gpu(src, out, copy_type);
// Empty update.
if (upd.size() == 0) {
return;
}
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(idx);
encoder.set_input_array(upd);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_INDEX_TYPES_CHECKED(idx.dtype(), "scatter_axis", CTYPE_IDX, {
using IndexType = cuda_type_t<CTYPE_IDX>;
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
using DataType = cuda_type_t<CTYPE_DATA>;
MLX_SWITCH_BOOL(idx.flags().row_contiguous, IDX_CONT, {
MLX_SWITCH_SCATTER_AXIS_OP(reduce_type_, SCATTER_OP, {
auto policy = cu::thrust_policy(stream);
auto upd_ptr = thrust::device_pointer_cast(upd.data<DataType>());
auto size = upd.size();
auto idx_begin = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
cu::IndexOp<IndexType, IDX_CONT>(idx, out, axis_));
auto out_ptr = out.data<DataType>();
SCATTER_OP op;
if (upd.flags().row_contiguous) {
cu::scatter_n(policy, upd_ptr, size, idx_begin, out_ptr, op);
} else {
auto upd_begin = cu::make_general_iterator<int64_t>(
upd_ptr, upd.shape(), upd.strides());
cu::scatter_n(policy, upd_begin, size, idx_begin, out_ptr, op);
}
});
});
});
});
});
}
} // namespace mlx::core