Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
CUDA backend: indexing ops
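
This commit adds CUDA implementations of the Gather, GatherAxis, Scatter, and ScatterAxis primitives, built on Thrust: a counting iterator over linear output positions is wrapped in a transform iterator whose functor (IndexOp or IndicesOp from indexing.cuh) maps each position to a location in the source array, and the mapped indices feed thrust::gather or the scatter_n helper. Below is a rough, self-contained sketch of that pattern; it is not MLX code, and the EveryOther functor and the sizes are made up for illustration.

    // gather_sketch.cu -- compile with nvcc; prints "0 2 4 6"
    #include <thrust/device_vector.h>
    #include <thrust/host_vector.h>
    #include <thrust/gather.h>
    #include <thrust/sequence.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/iterator/transform_iterator.h>
    #include <cstdio>

    struct EveryOther {
      __host__ __device__ int operator()(int i) const {
        return 2 * i; // hypothetical index map: output i reads src[2 * i]
      }
    };

    int main() {
      thrust::device_vector<float> src(8);
      thrust::sequence(src.begin(), src.end()); // src = 0, 1, ..., 7
      thrust::device_vector<float> out(4);
      // Map each output position through the functor, like the ops below do
      // with IndexOp/IndicesOp.
      auto idx_begin = thrust::make_transform_iterator(
          thrust::make_counting_iterator(0), EveryOther{});
      // out[i] = src[idx_begin[i]] for every output position i.
      thrust::gather(idx_begin, idx_begin + out.size(), src.begin(), out.begin());
      thrust::host_vector<float> h = out;
      for (float v : h) {
        std::printf("%g ", v);
      }
      std::printf("\n");
      return 0;
    }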

mlx/backend/cuda/CMakeLists.txt
@@ -13,12 +13,16 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gather.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gather_axis.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/scatter.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/scatter_axis.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
@@ -31,11 +35,22 @@ target_compile_definitions(mlx PUBLIC MLX_USE_CUDA)
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
 
+# Ignore uncommon types in ops to speed up compilation.
+if(MLX_FAST_COMPILE)
+  target_compile_definitions(mlx PUBLIC MLX_FAST_COMPILE)
+endif()
+
 # Compute capability 7 is required for synchronization between CPU/GPU with
 # managed memory. TODO: Add more architectures for potential performance gain.
-set(MLX_CUDA_ARCHITECTURES
-    "70;80"
-    CACHE STRING "CUDA architectures")
+if(MLX_FAST_COMPILE)
+  set(MLX_CUDA_ARCHITECTURES
+      "70"
+      CACHE STRING "CUDA architectures")
+else()
+  set(MLX_CUDA_ARCHITECTURES
+      "70;80"
+      CACHE STRING "CUDA architectures")
+endif()
 message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
 set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                      "${MLX_CUDA_ARCHITECTURES}")

mlx/backend/cuda/gather.cu (new file, 60 lines)
@@ -0,0 +1,60 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/gather.h>

namespace mlx::core {

void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Gather::eval_gpu");
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  const auto& src = inputs[0];
  size_t nidx = inputs.size() - 1;
  auto idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
  auto idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_INDEX_TYPES_CHECKED(idx_dtype, "gather", CTYPE_IDX, {
      using IndexType = cuda_type_t<CTYPE_IDX>;
      MLX_SWITCH_NIDX(inputs.size() - 1, NIDX, {
        MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
          MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
            using DataType = cuda_type_t<CTYPE_DATA>;
            auto idx_begin = thrust::make_transform_iterator(
                thrust::make_counting_iterator(0),
                cu::IndicesOp<IndexType, NIDX, IDX_NDIM>(
                    src,
                    slice_sizes_,
                    axes_,
                    inputs.begin() + 1,
                    inputs.end()));
            thrust::gather(
                cu::thrust_policy(stream),
                idx_begin,
                idx_begin + out.size(),
                src.data<DataType>(),
                out.data<DataType>());
          });
        });
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/gather_axis.cu (new file, 53 lines)
@@ -0,0 +1,53 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/gather.h>

namespace mlx::core {

void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("GatherAxis::eval_gpu");
  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  auto& src = inputs[0];
  auto& idx = inputs[1];

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(src);
  encoder.set_input_array(idx);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_INDEX_TYPES_CHECKED(idx.dtype(), "gather_axis", CTYPE_IDX, {
      using IndexType = cuda_type_t<CTYPE_IDX>;
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
        using DataType = cuda_type_t<CTYPE_DATA>;
        MLX_SWITCH_BOOL(src.flags().row_contiguous, SRC_CONT, {
          MLX_SWITCH_BOOL(idx.flags().row_contiguous, IDX_CONT, {
            auto idx_begin = thrust::make_transform_iterator(
                thrust::make_counting_iterator(0),
                cu::IndexOp<IndexType, IDX_CONT, SRC_CONT>(idx, src, axis_));
            thrust::gather(
                cu::thrust_policy(stream),
                idx_begin,
                idx_begin + idx.size(),
                src.data<DataType>(),
                out.data<DataType>());
          });
        });
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/indexing.cuh (new file, 260 lines)
@@ -0,0 +1,260 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"

#include <thrust/for_each.h>
#include <cuda/std/utility>

namespace mlx::core::cu {

// Only allow int32/uint32 as index types if MLX_FAST_COMPILE is defined.
#if defined(MLX_FAST_COMPILE)
#define MLX_SWITCH_INDEX_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...)         \
  if (TYPE == ::mlx::core::int32) {                                          \
    using CTYPE_ALIAS = int32_t;                                             \
    __VA_ARGS__;                                                             \
  } else if (TYPE == ::mlx::core::uint32) {                                  \
    using CTYPE_ALIAS = uint32_t;                                            \
    __VA_ARGS__;                                                             \
  } else {                                                                   \
    throw std::invalid_argument(fmt::format(                                 \
        "Can not use dtype {} as index for {} when MLX_FAST_COMPILE is on.", \
        dtype_to_string(TYPE),                                               \
        NAME));                                                              \
  }
#else
#define MLX_SWITCH_INDEX_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
  MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, __VA_ARGS__)
#endif

// Dispatch dynamic nidx to constexpr.
#if defined(MLX_FAST_COMPILE)
#define MLX_SWITCH_NIDX(nidx, NIDX, ...) \
  {                                      \
    assert(nidx <= 16);                  \
    constexpr uint32_t NIDX = 16;        \
    __VA_ARGS__;                         \
  }
#else
#define MLX_SWITCH_NIDX(nidx, NIDX, ...)                                      \
  if (nidx <= 2) {                                                            \
    constexpr uint32_t NIDX = 2;                                              \
    __VA_ARGS__;                                                              \
  } else if (nidx <= 16) {                                                    \
    constexpr uint32_t NIDX = 16;                                             \
    __VA_ARGS__;                                                              \
  } else {                                                                    \
    throw std::runtime_error(                                                 \
        fmt::format("Indices array can not have more than {} items", nidx)); \
  }
#endif

// Dispatch dynamic idx_ndim to constexpr.
#define MORE_THAN_ONE MAX_NDIM
#define MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, ...) \
  if (idx_ndim == 0) {                               \
    constexpr uint32_t IDX_NDIM = 0;                 \
    __VA_ARGS__;                                     \
  } else if (idx_ndim == 1) {                        \
    constexpr uint32_t IDX_NDIM = 1;                 \
    __VA_ARGS__;                                     \
  } else {                                           \
    constexpr uint32_t IDX_NDIM = MORE_THAN_ONE;     \
    __VA_ARGS__;                                     \
  }
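
// Illustrative use of the dispatch macros above (a sketch only; do_gather is
// a hypothetical function, the real call sites are in gather.cu/scatter.cu):
//   MLX_SWITCH_NIDX(nidx, NIDX, {
//     MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
//       do_gather<NIDX, IDX_NDIM>(...); // NIDX/IDX_NDIM are now constexpr
//     });
//   });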

// Like thrust::scatter but accept custom op.
template <typename Policy, typename It, typename Idx, typename Out, typename Op>
void scatter_n(Policy& policy, It begin, size_t size, Idx idx, Out out, Op op) {
  thrust::for_each_n(
      policy,
      thrust::make_zip_iterator(begin, idx),
      size,
      [out, op] __device__(auto item) {
        op(&out[thrust::get<1>(item)], thrust::get<0>(item));
      });
}
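
// In effect, scatter_n performs the parallel equivalent of:
//   for (size_t i = 0; i < size; ++i) {
//     op(&out[idx[i]], begin[i]);
//   }
// with |op| typically one of the functors in kernels/scatter_ops.cuh.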

// Convert an absolute index to positions in a 3d grid, assuming the index is
// calculated with:
//   index = x * dim1 * dim2 + y * dim2 + z
template <typename T>
inline __device__ cuda::std::tuple<T, T, T>
index_to_dims(T index, size_t dim1, size_t dim2) {
  T x = index / (dim1 * dim2);
  T y = (index % (dim1 * dim2)) / dim2;
  T z = index % dim2;
  return cuda::std::make_tuple(x, y, z);
}
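
// For example, with dim1 = 2, dim2 = 3 and index = 7 this returns
// (x, y, z) = (1, 0, 1), and indeed 1 * 2 * 3 + 0 * 3 + 1 == 7.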

// Get absolute index from possible negative index.
template <typename IdxT>
inline __device__ auto absolute_index(IdxT idx, int32_t size) {
  if constexpr (cuda::std::is_unsigned_v<IdxT>) {
    return idx;
  } else {
    return static_cast<int32_t>(idx < 0 ? idx + size : idx);
  }
}

// An op that takes an index of |src|, and returns the corresponding index value
// from |idx| which is the indices at |axis|.
template <typename IdxT, bool IdxC, bool SrcC = true, typename LocT = int64_t>
struct IndexOp {
  const IdxT* idx;
  int32_t ndim;
  Shape shape;
  Strides src_strides;
  Strides idx_strides;
  int32_t src_axis_size;
  int32_t idx_axis_size;
  int64_t src_axis_stride;
  int64_t idx_axis_stride;
  size_t size_post;

  IndexOp(const array& idx, const array& src, int32_t axis)
      : idx(idx.data<IdxT>()),
        ndim(static_cast<int32_t>(src.ndim()) - 1),
        shape(const_param(remove_index(idx.shape(), axis))),
        src_strides(const_param(remove_index(src.strides(), axis))),
        idx_strides(const_param(remove_index(idx.strides(), axis))),
        src_axis_size(src.shape(axis)),
        idx_axis_size(idx.shape(axis)),
        src_axis_stride(src.strides(axis)),
        idx_axis_stride(idx.strides(axis)) {
    size_post = 1;
    for (int i = axis + 1; i < idx.ndim(); ++i) {
      size_post *= idx.shape(i);
    }
  }

  __device__ LocT operator()(size_t index) {
    auto [x, y, z] = index_to_dims(index, idx_axis_size, size_post);

    LocT elem_idx = x * size_post;

    LocT idx_loc = y * idx_axis_stride;
    if constexpr (IdxC) {
      idx_loc += elem_idx * idx_axis_size + z;
    } else {
      idx_loc +=
          elem_to_loc(elem_idx + z, shape.data(), idx_strides.data(), ndim);
    }

    auto idx_val = absolute_index(idx[idx_loc], src_axis_size);

    LocT src_idx = idx_val * src_axis_stride;
    if constexpr (SrcC) {
      src_idx += elem_idx * src_axis_size + z;
    } else {
      src_idx +=
          elem_to_loc(elem_idx + z, shape.data(), src_strides.data(), ndim);
    }
    return src_idx;
  }
};

// Concatenated |idx| arrays.
template <typename T, size_t NIDX, size_t IDX_NDIM, typename LocT = int64_t>
struct Indices {
  size_t size;
  size_t ndim;
  cuda::std::array<const T*, NIDX> buffers;
  cuda::std::array<bool, NIDX> row_contiguous;
  cuda::std::array<int32_t, NIDX * IDX_NDIM> shapes;
  cuda::std::array<int64_t, NIDX * IDX_NDIM> strides;

  template <typename Iter>
  Indices(Iter begin, Iter end) {
    size = end - begin;
    ndim = size > 0 ? begin->ndim() : 0;
    for (size_t i = 0; i < size; ++i) {
      const array& arr = *(begin + i);
      buffers[i] = arr.data<T>();
      row_contiguous[i] = arr.flags().row_contiguous;
      std::copy_n(arr.shape().begin(), ndim, shapes.begin() + i * ndim);
      std::copy_n(arr.strides().begin(), ndim, strides.begin() + i * ndim);
    }
  }

  __device__ auto operator()(size_t i, size_t x, size_t y) {
    LocT idx_loc;
    if constexpr (IDX_NDIM == 0) {
      idx_loc = 0;
    } else {
      idx_loc = x * strides[ndim * i];
      if constexpr (IDX_NDIM == MORE_THAN_ONE) {
        if (row_contiguous[i]) {
          idx_loc += y;
        } else {
          size_t offset = ndim * i + 1;
          idx_loc += elem_to_loc(
              y, shapes.data() + offset, strides.data() + offset, ndim - 1);
        }
      }
    }
    return buffers[i][idx_loc];
  }
};

// An op that takes an index of |src|, and returns the corresponding index value
// from |indices| located at |axes|.
template <typename IdxT, size_t NIDX, size_t IDX_NDIM, typename LocT = int64_t>
struct IndicesOp {
  size_t ndim;
  Shape shape;
  Strides strides;
  Shape slice_sizes;
  Shape axes;
  Indices<IdxT, NIDX, IDX_NDIM, LocT> indices;
  size_t n_dim0;
  size_t slice_size;

  template <typename Iter>
  IndicesOp(
      const array& src,
      const std::vector<int32_t>& slice_sizes,
      const std::vector<int32_t>& axes,
      Iter idx_begin,
      Iter idx_end)
      : ndim(src.ndim()),
        shape(const_param(src.shape())),
        strides(const_param(src.strides())),
        slice_sizes(const_param(slice_sizes)),
        axes(const_param(axes)),
        indices(idx_begin, idx_end) {
    n_dim0 = 1;
    size_t dim0 = 1;
    if (indices.ndim >= 1) {
      dim0 = idx_begin->shape(0);
    }
    if (indices.ndim >= 2) {
      n_dim0 = idx_begin->size() / dim0;
    }

    slice_size = 1;
    for (size_t s : slice_sizes) {
      slice_size *= s;
    }
  }

  __device__ LocT operator()(size_t index) {
    auto [x, y, z] = index_to_dims(index, n_dim0, slice_size);
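    // x: position along the leading dimension of the index arrays,
    // y: flattened position over their remaining dimensions,
    // z: position inside the gathered slice.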

    LocT src_idx = 0;
    for (size_t i = 0; i < indices.size; ++i) {
      auto ax = axes[i];
      auto idx_val = absolute_index(indices(i, x, y), shape[ax]);
      src_idx += static_cast<LocT>(idx_val) * strides[ax];
    }

    LocT src_offset = elem_to_loc(z, slice_sizes.data(), strides.data(), ndim);
    return src_idx + src_offset;
  }
};

} // namespace mlx::core::cu

mlx/backend/cuda/kernel_utils.cuh
@@ -15,6 +15,16 @@
 
 namespace mlx::core {
 
+// Like MLX_SWITCH_ALL_TYPES but for booleans.
+#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
+  if (BOOL) {                                  \
+    constexpr bool BOOL_ALIAS = true;          \
+    __VA_ARGS__;                               \
+  } else {                                     \
+    constexpr bool BOOL_ALIAS = false;         \
+    __VA_ARGS__;                               \
+  }
+
 // Helper macros for dispatch macros (see below).
 #define MLX_INTERNAL_IF_CASE(DIM, BLOCK_DIM, ...) \
 }                                                 \

mlx/backend/cuda/kernels/atomic_ops.cuh (new file, 67 lines)
@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/kernels/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernels/fp16_math.cuh"

#include <cuda/atomic>

namespace mlx::core::cu {

template <typename T>
inline __device__ void atomic_add(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref += val;
}

template <typename T>
inline __device__ void atomic_prod(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
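  // On failure, compare_exchange_strong reloads the current value into |old|,
  // so the loop retries the multiply against the freshly observed value.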
  while (!ref.compare_exchange_strong(old, old * val)) {
  }
}

template <typename T>
inline __device__ void atomic_max(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref.fetch_max(val);
}

template <typename T>
inline __device__ void atomic_min(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref.fetch_min(val);
}

// Somehow cuda::atomic_ref does not provide atomic add for following types.
template <typename T>
inline __device__ void atomic_add_general(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  while (!ref.compare_exchange_strong(old, old + val)) {
  }
}

inline __device__ void atomic_add(__half* out, __half val) {
  atomicAdd(out, val);
}

inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
#if __CUDA_ARCH__ < 900
  atomic_add_general(out, val);
#else
  atomicAdd(out, val);
#endif
}

inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800
  atomic_add_general(out, val);
#else
  atomicAdd(out, val);
#endif
}

} // namespace mlx::core::cu

mlx/backend/cuda/kernels/scatter_ops.cuh (new file, 44 lines)
@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/kernels/atomic_ops.cuh"

namespace mlx::core::cu {

template <typename T>
struct ScatterAssign {
  __device__ void operator()(T* out, T val) const {
    *out = val;
  }
};

template <typename T>
struct ScatterSum {
  __device__ void operator()(T* out, T val) const {
    atomic_add(out, val);
  }
};

template <typename T>
struct ScatterProd {
  __device__ void operator()(T* out, T val) const {
    atomic_prod(out, val);
  }
};

template <typename T>
struct ScatterMax {
  __device__ void operator()(T* out, T val) const {
    atomic_max(out, val);
  }
};

template <typename T>
struct ScatterMin {
  __device__ void operator()(T* out, T val) const {
    atomic_min(out, val);
  }
};

} // namespace mlx::core::cu

mlx/backend/cuda/primitives.cu
@@ -116,8 +116,6 @@ NO_GPU_MULTI(DivMod)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)
-NO_GPU(Gather)
-NO_GPU(GatherAxis)
 NO_GPU(GatherMM)
 NO_GPU(GatherQMM)
 NO_GPU(Hadamard)
@@ -128,8 +126,6 @@ NO_GPU(Partition)
 NO_GPU_MULTI(QRF)
 NO_GPU(QuantizedMatmul)
 NO_GPU(Scan)
-NO_GPU(Scatter)
-NO_GPU(ScatterAxis)
 NO_GPU(Select)
 NO_GPU(SliceUpdate)
 NO_GPU(Softmax)

mlx/backend/cuda/scatter.cu (new file, 104 lines)
@@ -0,0 +1,104 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/scatter_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>

namespace mlx::core {

// Like MLX_SWITCH_ALL_TYPES but for reductions.
#define MLX_SWITCH_SCATTER_OP(REDUCE, REDUCE_ALIAS, ...)         \
  if (REDUCE == Scatter::Sum) {                                  \
    using REDUCE_ALIAS = mlx::core::cu::ScatterSum<DataType>;    \
    __VA_ARGS__;                                                 \
  } else if (REDUCE == Scatter::Prod) {                          \
    using REDUCE_ALIAS = mlx::core::cu::ScatterProd<DataType>;   \
    __VA_ARGS__;                                                 \
  } else if (REDUCE == Scatter::Max) {                           \
    using REDUCE_ALIAS = mlx::core::cu::ScatterMax<DataType>;    \
    __VA_ARGS__;                                                 \
  } else if (REDUCE == Scatter::Min) {                           \
    using REDUCE_ALIAS = mlx::core::cu::ScatterMin<DataType>;    \
    __VA_ARGS__;                                                 \
  } else {                                                       \
    using REDUCE_ALIAS = mlx::core::cu::ScatterAssign<DataType>; \
    __VA_ARGS__;                                                 \
  }

void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Scatter::eval_gpu");
  auto& upd = inputs.back();

  // Copy src into out.
  CopyType copy_type;
  if (inputs[0].data_size() == 1) {
    copy_type = CopyType::Scalar;
  } else if (inputs[0].flags().row_contiguous) {
    copy_type = CopyType::Vector;
  } else {
    copy_type = CopyType::General;
  }
  copy_gpu(inputs[0], out, copy_type);

  // Empty update.
  if (upd.size() == 0) {
    return;
  }

  size_t nidx = axes_.size();
  auto idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
  auto idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;

  Shape upd_shape_post(upd.shape().begin() + idx_ndim, upd.shape().end());

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_INDEX_TYPES_CHECKED(idx_dtype, "scatter", CTYPE_IDX, {
      using IndexType = cuda_type_t<CTYPE_IDX>;
      MLX_SWITCH_NIDX(nidx, NIDX, {
        MLX_SWITCH_IDX_NDIM(idx_ndim, IDX_NDIM, {
          MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
            using DataType = cuda_type_t<CTYPE_DATA>;
            MLX_SWITCH_SCATTER_OP(reduce_type_, SCATTER_OP, {
              auto policy = cu::thrust_policy(stream);
              auto upd_ptr = thrust::device_pointer_cast(upd.data<DataType>());
              auto size = upd.size();
              auto idx_begin = thrust::make_transform_iterator(
                  thrust::make_counting_iterator(0),
                  cu::IndicesOp<IndexType, NIDX, IDX_NDIM>(
                      out,
                      upd_shape_post,
                      axes_,
                      inputs.begin() + 1,
                      inputs.begin() + 1 + nidx));
              auto out_ptr = out.data<DataType>();
              SCATTER_OP op;
              if (upd.flags().row_contiguous) {
                cu::scatter_n(policy, upd_ptr, size, idx_begin, out_ptr, op);
              } else {
                auto upd_begin = cu::make_general_iterator<int64_t>(
                    upd_ptr, upd.shape(), upd.strides());
                cu::scatter_n(policy, upd_begin, size, idx_begin, out_ptr, op);
              }
            });
          });
        });
      });
    });
  });
}

} // namespace mlx::core

mlx/backend/cuda/scatter_axis.cu (new file, 83 lines)
@@ -0,0 +1,83 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/indexing.cuh"
#include "mlx/backend/cuda/iterators/general_iterator.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/kernels/scatter_ops.cuh"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include <nvtx3/nvtx3.hpp>
#include <thrust/device_ptr.h>

namespace mlx::core {

// Like MLX_SWITCH_ALL_TYPES but for reductions.
#define MLX_SWITCH_SCATTER_AXIS_OP(REDUCE, REDUCE_ALIAS, ...)    \
  if (REDUCE == ScatterAxis::Sum) {                              \
    using REDUCE_ALIAS = mlx::core::cu::ScatterSum<DataType>;    \
    __VA_ARGS__;                                                 \
  } else {                                                       \
    using REDUCE_ALIAS = mlx::core::cu::ScatterAssign<DataType>; \
    __VA_ARGS__;                                                 \
  }

void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ScatterAxis::eval_gpu");
  auto& src = inputs[0];
  auto& idx = inputs[1];
  auto& upd = inputs[2];

  // Copy src into out.
  CopyType copy_type;
  if (src.data_size() == 1) {
    copy_type = CopyType::Scalar;
  } else if (src.flags().row_contiguous) {
    copy_type = CopyType::Vector;
  } else {
    copy_type = CopyType::General;
  }
  copy_gpu(src, out, copy_type);

  // Empty update.
  if (upd.size() == 0) {
    return;
  }

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  encoder.set_input_array(idx);
  encoder.set_input_array(upd);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_INDEX_TYPES_CHECKED(idx.dtype(), "scatter_axis", CTYPE_IDX, {
      using IndexType = cuda_type_t<CTYPE_IDX>;
      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_DATA, {
        using DataType = cuda_type_t<CTYPE_DATA>;
        MLX_SWITCH_BOOL(idx.flags().row_contiguous, IDX_CONT, {
          MLX_SWITCH_SCATTER_AXIS_OP(reduce_type_, SCATTER_OP, {
            auto policy = cu::thrust_policy(stream);
            auto upd_ptr = thrust::device_pointer_cast(upd.data<DataType>());
            auto size = upd.size();
            auto idx_begin = thrust::make_transform_iterator(
                thrust::make_counting_iterator(0),
                cu::IndexOp<IndexType, IDX_CONT>(idx, out, axis_));
            auto out_ptr = out.data<DataType>();
            SCATTER_OP op;
            if (upd.flags().row_contiguous) {
              cu::scatter_n(policy, upd_ptr, size, idx_begin, out_ptr, op);
            } else {
              auto upd_begin = cu::make_general_iterator<int64_t>(
                  upd_ptr, upd.shape(), upd.strides());
              cu::scatter_n(policy, upd_begin, size, idx_begin, out_ptr, op);
            }
          });
        });
      });
    });
  });
}

} // namespace mlx::core