Mirror of https://github.com/ml-explore/mlx.git

CUDA backend: indexing ops (#2277)
mlx/backend/cuda/device/atomic_ops.cuh (new file, 72 lines)

// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"

#include <cuda/atomic>

namespace mlx::core::cu {

template <typename T>
inline __device__ void atomic_add(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref += val;
}

template <typename T>
inline __device__ void atomic_prod(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  while (!ref.compare_exchange_strong(old, old * val)) {
  }
}

template <typename T>
inline __device__ void atomic_max(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref.fetch_max(val);
}

template <typename T>
inline __device__ void atomic_min(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref.fetch_min(val);
}

// Somehow cuda::atomic_ref does not provide atomic add for the following
// types.
template <typename T>
inline __device__ void atomic_add_general(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  while (!ref.compare_exchange_strong(old, old + val)) {
  }
}

inline __device__ void atomic_add(__half* out, __half val) {
  atomicAdd(out, val);
}

inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
#if __CUDA_ARCH__ < 900
  atomic_add_general(out, val);
#else
  atomicAdd(out, val);
#endif
}

inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800
#if CCCL_VERSION >= 2008000
  atomic_add_general(out, val);
#else
  bool cccl_version_too_old_for_bfloat16_atomic_add = false;
  assert(cccl_version_too_old_for_bfloat16_atomic_add);
#endif
#else
  atomicAdd(out, val);
#endif
}

} // namespace mlx::core::cu

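For reference, the compare-exchange loop used by atomic_prod and atomic_add_general is the standard fallback for read-modify-write operations that cuda::atomic_ref does not expose directly. A minimal standalone sketch of the same pattern (assuming a toolchain whose libcu++ provides cuda::atomic_ref; the kernel and variable names here are illustrative, not part of this commit):

// Standalone sketch of the compare-exchange fallback used above.
// Assumes nvcc with a libcu++ that provides <cuda/atomic>; not MLX code.
#include <cuda/atomic>
#include <cstdio>

template <typename T>
__device__ void atomic_prod_sketch(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  // compare_exchange_strong refreshes `old` on failure, so the loop retries
  // with the latest value until the multiply lands atomically.
  while (!ref.compare_exchange_strong(old, old * val)) {
  }
}

__global__ void prod_kernel(float* acc) {
  atomic_prod_sketch(acc, 2.0f); // each thread multiplies the accumulator by 2
}

int main() {
  float* acc;
  cudaMallocManaged(&acc, sizeof(float));
  *acc = 1.0f;
  prod_kernel<<<1, 8>>>(acc);
  cudaDeviceSynchronize();
  std::printf("product = %f\n", *acc); // expect 2^8 = 256
  cudaFree(acc);
  return 0;
}
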
mlx/backend/cuda/device/gather.cuh (new file, 53 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
__global__ void gather(
    const T* src,
    T* out,
    LocT size,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int32_t src_ndim,
    const __grid_constant__ Shape slice_sizes,
    uint32_t slice_size,
    const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
    const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
    const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
        indices_shape,
    const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
        indices_strides) {
  LocT out_idx = cg::this_grid().thread_rank();
  if (out_idx >= size) {
    return;
  }

  LocT src_elem = out_idx % slice_size;
  LocT idx_elem = out_idx / slice_size;

  LocT src_loc =
      elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);

#pragma unroll
  for (int i = 0; i < NIDX; ++i) {
    LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
        idx_elem,
        indices_shape.data() + i * IDX_NDIM,
        indices_strides.data() + i * IDX_NDIM);
    int32_t axis = axes[i];
    LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
    src_loc += idx_val * src_strides[axis];
  }

  out[out_idx] = src[src_loc];
}

} // namespace mlx::core::cu

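The gather kernel assigns one thread per output element and splits out_idx into an index-array position (idx_elem) and an offset within the gathered slice (src_elem). A host-side reference loop for the simplest case, a single index array selecting rows of a contiguous matrix (hypothetical shapes, not launch code from this commit):

// Host-side reference for the per-thread index math above:
// out_idx = idx_elem * slice_size + src_elem.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int rows = 4, cols = 3;
  std::vector<float> src(rows * cols);
  for (int i = 0; i < rows * cols; ++i) src[i] = static_cast<float>(i);

  std::vector<int32_t> indices = {3, 0, 3};   // gather rows 3, 0, 3
  const uint32_t slice_size = cols;           // one full row per index entry
  std::vector<float> out(indices.size() * slice_size);

  for (uint64_t out_idx = 0; out_idx < out.size(); ++out_idx) {
    uint64_t src_elem = out_idx % slice_size; // position inside the slice
    uint64_t idx_elem = out_idx / slice_size; // which index entry
    int32_t idx_val = indices[idx_elem];
    if (idx_val < 0) idx_val += rows;         // absolute_index()
    out[out_idx] = src[idx_val * cols + src_elem];
  }

  for (float v : out) std::printf("%g ", v);  // 9 10 11 0 1 2 9 10 11
  std::printf("\n");
  return 0;
}
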
mlx/backend/cuda/device/gather_axis.cuh (new file, 65 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <
    typename T,
    typename IdxT,
    int NDIM,
    bool SrcC,
    bool IdxC,
    typename LocT>
__global__ void gather_axis(
    const T* src,
    const IdxT* indices,
    T* out,
    LocT idx_size_pre,
    LocT idx_size_axis,
    LocT idx_size_post,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
    int32_t axis,
    int32_t axis_size,
    int64_t src_stride_axis,
    int64_t idx_stride_axis) {
  LocT index = cg::this_grid().thread_rank();
  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
    return;
  }

  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);

  LocT elem_idx = z * idx_size_post;

  LocT idx_loc = y * idx_stride_axis;
  if constexpr (IdxC) {
    idx_loc += elem_idx * idx_size_axis + x;
  } else {
    idx_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
  }

  auto idx_val = absolute_index(indices[idx_loc], axis_size);

  LocT src_loc = idx_val * src_stride_axis;
  if constexpr (SrcC) {
    src_loc += elem_idx * axis_size + x;
  } else {
    src_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());
  }

  LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;

  out[out_idx] = src[src_loc];
}

} // namespace mlx::core::cu

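gather_axis looks like the per-axis counterpart: each output element picks one entry along `axis` according to a matching index array, in the style of take_along_axis. A rough host-side analogue for a contiguous 2-D case, gathering along the last axis (hypothetical example; the kernel above additionally handles non-contiguous src and indices via elem_to_loc_nd):

// Reference semantics: out[i][k] = src[i][ idx[i][k] ] along axis 1.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 4, K = 2;
  std::vector<float> src = {0, 1, 2, 3, 10, 11, 12, 13}; // 2 x 4
  std::vector<int32_t> idx = {3, 0, 1, 1};               // 2 x 2
  std::vector<float> out(M * K);

  for (int i = 0; i < M; ++i) {
    for (int k = 0; k < K; ++k) {
      int32_t j = idx[i * K + k];
      if (j < 0) j += N;                                 // absolute_index()
      out[i * K + k] = src[i * N + j];
    }
  }
  for (float v : out) std::printf("%g ", v);             // 3 0 11 11
  std::printf("\n");
  return 0;
}
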
mlx/backend/cuda/device/indexing.cuh (new file, 30 lines)

// Copyright © 2025 Apple Inc.

#include <cuda/std/tuple>
#include <cuda/std/type_traits>

namespace mlx::core::cu {

// Convert an absolute index to positions in a 3d grid, assuming the index is
// calculated with:
// index = x * dim1 * dim2 + y * dim2 + z
template <typename T>
inline __host__ __device__ cuda::std::tuple<T, T, T>
index_to_dims(T index, T dim1, T dim2) {
  T x = index / (dim1 * dim2);
  T y = (index % (dim1 * dim2)) / dim2;
  T z = index % dim2;
  return cuda::std::make_tuple(x, y, z);
}

// Get absolute index from possible negative index.
template <typename IdxT>
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
  if constexpr (cuda::std::is_unsigned_v<IdxT>) {
    return idx;
  } else {
    return static_cast<int32_t>(idx < 0 ? idx + size : idx);
  }
}

} // namespace mlx::core::cu

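Both helpers are __host__ __device__, so the decomposition can be sanity-checked on the host. A small round-trip check that re-states the index_to_dims formula locally (standalone sketch, does not include the MLX header):

// For any (x, y, z) with y < dim1 and z < dim2, encoding then decoding
// recovers the triple.
#include <cassert>
#include <cstdio>
#include <tuple>

template <typename T>
std::tuple<T, T, T> index_to_dims_ref(T index, T dim1, T dim2) {
  return {index / (dim1 * dim2), (index % (dim1 * dim2)) / dim2, index % dim2};
}

int main() {
  const long dim1 = 5, dim2 = 7;
  for (long x = 0; x < 3; ++x)
    for (long y = 0; y < dim1; ++y)
      for (long z = 0; z < dim2; ++z) {
        long index = x * dim1 * dim2 + y * dim2 + z;
        auto [xx, yy, zz] = index_to_dims_ref(index, dim1, dim2);
        assert(xx == x && yy == y && zz == z);
      }
  std::printf("index_to_dims round trip OK\n");
  return 0;
}
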
mlx/backend/cuda/device/scatter.cuh (new file, 68 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/scatter_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <
    typename T,
    typename IdxT,
    typename Op,
    int NIDX,
    int IDX_NDIM,
    typename LocT>
__global__ void scatter(
    const T* upd,
    T* out,
    LocT size,
    const __grid_constant__ Shape upd_shape,
    const __grid_constant__ Strides upd_strides,
    int32_t upd_ndim,
    LocT upd_post_idx_size,
    const __grid_constant__ Shape out_shape,
    const __grid_constant__ Strides out_strides,
    int32_t out_ndim,
    const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
    const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
    const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
        indices_shape,
    const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
        indices_strides) {
  LocT upd_idx = cg::this_grid().thread_rank();
  if (upd_idx >= size) {
    return;
  }

  LocT out_elem = upd_idx % upd_post_idx_size;
  LocT idx_elem = upd_idx / upd_post_idx_size;

  LocT out_idx = elem_to_loc(
      out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);

#pragma unroll
  for (int i = 0; i < NIDX; ++i) {
    LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
        idx_elem,
        indices_shape.data() + i * IDX_NDIM,
        indices_strides.data() + i * IDX_NDIM);
    int32_t axis = axes[i];
    LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);
    out_idx += idx_val * out_strides[axis];
  }

  LocT upd_loc = elem_to_loc(
      out_elem + idx_elem * upd_post_idx_size,
      upd_shape.data(),
      upd_strides.data(),
      upd_ndim);

  Op{}(out + out_idx, upd[upd_loc]);
}

} // namespace mlx::core::cu

mlx/backend/cuda/device/scatter_axis.cuh (new file, 67 lines)

// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/scatter_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <
    typename T,
    typename IdxT,
    typename Op,
    int NDIM,
    bool UpdC,
    bool IdxC,
    typename LocT>
__global__ void scatter_axis(
    const T* upd,
    const IdxT* indices,
    T* out,
    LocT idx_size_pre,
    LocT idx_size_axis,
    LocT idx_size_post,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
    int32_t axis,
    int32_t axis_size,
    int64_t upd_stride_axis,
    int64_t idx_stride_axis) {
  LocT index = cg::this_grid().thread_rank();
  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
    return;
  }

  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);

  LocT elem_idx = z * idx_size_post;

  LocT idx_loc = y * idx_stride_axis;
  if constexpr (IdxC) {
    idx_loc += elem_idx * idx_size_axis + x;
  } else {
    idx_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
  }

  auto idx_val = absolute_index(indices[idx_loc], axis_size);

  LocT upd_loc = y * upd_stride_axis;
  if constexpr (UpdC) {
    upd_loc += elem_idx * idx_size_axis + x;
  } else {
    upd_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());
  }

  LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;

  Op{}(out + out_idx, upd[upd_loc]);
}

} // namespace mlx::core::cu

mlx/backend/cuda/device/scatter_ops.cuh (new file, 44 lines)

// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/atomic_ops.cuh"

namespace mlx::core::cu {

struct ScatterAssign {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    *out = val;
  }
};

struct ScatterSum {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_add(out, val);
  }
};

struct ScatterProd {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_prod(out, val);
  }
};

struct ScatterMax {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_max(out, val);
  }
};

struct ScatterMin {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_min(out, val);
  }
};

} // namespace mlx::core::cu

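The scatter kernels take the reduction as a functor template parameter, so the same kernel body compiles to a plain store for ScatterAssign and to an atomic read-modify-write for the reductions. A standalone sketch of that dispatch pattern with locally defined stand-ins for the functors (illustrative only, independent of the MLX headers):

// Functor dispatch: Op{}(out + i, val) is a store for assign, atomicAdd for sum.
#include <cstdio>

struct AssignOp {
  template <typename T>
  __device__ void operator()(T* out, T val) const { *out = val; }
};

struct SumOp {
  template <typename T>
  __device__ void operator()(T* out, T val) const { atomicAdd(out, val); }
};

template <typename Op>
__global__ void apply_op(float* out) {
  Op{}(out, 1.0f); // every thread applies the op to the same element
}

int main() {
  float* out;
  cudaMallocManaged(&out, sizeof(float));

  *out = 0.0f;
  apply_op<SumOp><<<1, 32>>>(out);
  cudaDeviceSynchronize();
  std::printf("sum:    %g\n", *out); // 32: every thread's +1 is accumulated

  *out = 0.0f;
  apply_op<AssignOp><<<1, 32>>>(out);
  cudaDeviceSynchronize();
  std::printf("assign: %g\n", *out); // 1: last writer wins

  cudaFree(out);
  return 0;
}
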