Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 01:17:26 +08:00)

CUDA backend: indexing ops (#2277)

Commit c8b4787e4e, parent 2188199ff8
@@ -1,8 +1,8 @@
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only kernel code should be put in kernels/ subdir.
# * Files in kernels/ subdir should not include files outside.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
@@ -20,6 +20,7 @@ target_sources(
  ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
mlx/backend/cuda/device/atomic_ops.cuh (new file, 72 lines)
@@ -0,0 +1,72 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"

#include <cuda/atomic>

namespace mlx::core::cu {

template <typename T>
inline __device__ void atomic_add(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref += val;
}

template <typename T>
inline __device__ void atomic_prod(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  while (!ref.compare_exchange_strong(old, old * val)) {
  }
}

template <typename T>
inline __device__ void atomic_max(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref.fetch_max(val);
}

template <typename T>
inline __device__ void atomic_min(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  ref.fetch_min(val);
}

// Somehow cuda::atomic_ref does not provide atomic add for the following types.
template <typename T>
inline __device__ void atomic_add_general(T* out, T val) {
  cuda::atomic_ref<T, cuda::thread_scope_device> ref(*out);
  T old = ref.load();
  while (!ref.compare_exchange_strong(old, old + val)) {
  }
}

inline __device__ void atomic_add(__half* out, __half val) {
  atomicAdd(out, val);
}

inline __device__ void atomic_add(cuComplex* out, cuComplex val) {
#if __CUDA_ARCH__ < 900
  atomic_add_general(out, val);
#else
  atomicAdd(out, val);
#endif
}

inline __device__ void atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
#if __CUDA_ARCH__ < 800
#if CCCL_VERSION >= 2008000
  atomic_add_general(out, val);
#else
  bool cccl_version_too_old_for_bfloat16_atomic_add = false;
  assert(cccl_version_too_old_for_bfloat16_atomic_add);
#endif
#else
  atomicAdd(out, val);
#endif
}

} // namespace mlx::core::cu
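
The atomic_prod / atomic_add_general fallbacks use a compare-exchange retry loop where no hardware atomic exists for the type. A minimal standalone sketch of the same idiom, assuming a libcu++ recent enough to provide cuda::atomic_ref; the kernel and file names below are illustrative only, not part of this commit:

// product_reduce.cu -- illustrative only; build with: nvcc -arch=sm_70 product_reduce.cu
#include <cuda/atomic>
#include <cstdio>

__global__ void product_reduce(float* out, const float* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  cuda::atomic_ref<float, cuda::thread_scope_device> ref(*out);
  float old = ref.load();
  // On failure, compare_exchange_strong reloads the current value into `old`,
  // so the loop retries until no other thread intervened between load and swap.
  while (!ref.compare_exchange_strong(old, old * in[i])) {
  }
}

int main() {
  const int n = 256;
  float h_in[n], h_out = 1.0f;
  for (int i = 0; i < n; ++i) h_in[i] = 1.01f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &h_out, sizeof(float), cudaMemcpyHostToDevice);
  product_reduce<<<1, n>>>(d_out, d_in, n);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("product = %f\n", h_out);  // ~12.8 for 1.01^256
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}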
mlx/backend/cuda/device/gather.cuh (new file, 53 lines)
@@ -0,0 +1,53 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
__global__ void gather(
    const T* src,
    T* out,
    LocT size,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int32_t src_ndim,
    const __grid_constant__ Shape slice_sizes,
    uint32_t slice_size,
    const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
    const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
    const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
        indices_shape,
    const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
        indices_strides) {
  LocT out_idx = cg::this_grid().thread_rank();
  if (out_idx >= size) {
    return;
  }

  LocT src_elem = out_idx % slice_size;
  LocT idx_elem = out_idx / slice_size;

  LocT src_loc =
      elem_to_loc(src_elem, slice_sizes.data(), src_strides.data(), src_ndim);

#pragma unroll
  for (int i = 0; i < NIDX; ++i) {
    LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
        idx_elem,
        indices_shape.data() + i * IDX_NDIM,
        indices_strides.data() + i * IDX_NDIM);
    int32_t axis = axes[i];
    LocT idx_val = absolute_index(indices[i][idx_loc], src_shape[axis]);
    src_loc += idx_val * src_strides[axis];
  }

  out[out_idx] = src[src_loc];
}

} // namespace mlx::core::cu
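
Each thread handles one output element: out_idx splits into the position inside the gathered slice (src_elem) and the position in the index arrays (idx_elem), and every index array then adds an offset along its axis. A plain CPU sketch of the single-axis case, with made-up data, showing the same decomposition:

// CPU reference of a 1-axis gather (rough sketch, contiguous row-major src,
// one index array on axis 0; purely illustrative, not MLX code).
#include <cstdio>
#include <vector>

int main() {
  // src has shape (4, 3); gather whole rows given by idx, so slice_size = 3.
  std::vector<float> src = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  std::vector<int> idx = {2, 0, 3};
  int slice_size = 3;  // product of slice_sizes
  std::vector<float> out(idx.size() * slice_size);
  for (size_t out_idx = 0; out_idx < out.size(); ++out_idx) {
    size_t src_elem = out_idx % slice_size;  // position inside the slice
    size_t idx_elem = out_idx / slice_size;  // which index entry
    int row = idx[idx_elem] < 0 ? idx[idx_elem] + 4 : idx[idx_elem];  // absolute_index
    out[out_idx] = src[row * 3 + src_elem];  // src_loc = idx_val * stride + slice offset
  }
  for (float v : out) std::printf("%g ", v);  // 6 7 8 0 1 2 9 10 11
  std::printf("\n");
  return 0;
}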
mlx/backend/cuda/device/gather_axis.cuh (new file, 65 lines)
@@ -0,0 +1,65 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <
    typename T,
    typename IdxT,
    int NDIM,
    bool SrcC,
    bool IdxC,
    typename LocT>
__global__ void gather_axis(
    const T* src,
    const IdxT* indices,
    T* out,
    LocT idx_size_pre,
    LocT idx_size_axis,
    LocT idx_size_post,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> src_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
    int32_t axis,
    int32_t axis_size,
    int64_t src_stride_axis,
    int64_t idx_stride_axis) {
  LocT index = cg::this_grid().thread_rank();
  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
    return;
  }

  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);

  LocT elem_idx = z * idx_size_post;

  LocT idx_loc = y * idx_stride_axis;
  if constexpr (IdxC) {
    idx_loc += elem_idx * idx_size_axis + x;
  } else {
    idx_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
  }

  auto idx_val = absolute_index(indices[idx_loc], axis_size);

  LocT src_loc = idx_val * src_stride_axis;
  if constexpr (SrcC) {
    src_loc += elem_idx * axis_size + x;
  } else {
    src_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), src_strides.data());
  }

  LocT out_idx = y * idx_size_post + elem_idx * idx_size_axis + x;

  out[out_idx] = src[src_loc];
}

} // namespace mlx::core::cu
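
gather_axis backs take-along-axis style indexing: each output element reads src at the same coordinates as the corresponding index element, except along the chosen axis, where the index value is substituted. A small CPU reference for the contiguous 2-d, axis = 1 case (made-up data, illustrative only):

// CPU reference for take-along-axis on axis 1 (illustrative, not MLX code).
#include <cstdio>

int main() {
  const int rows = 2, cols = 3, k = 2;
  float src[rows][cols] = {{10, 11, 12}, {20, 21, 22}};
  int idx[rows][k] = {{2, 0}, {1, 1}};
  float out[rows][k];
  for (int r = 0; r < rows; ++r) {
    for (int j = 0; j < k; ++j) {
      int i = idx[r][j] < 0 ? idx[r][j] + cols : idx[r][j];  // absolute_index
      out[r][j] = src[r][i];
    }
  }
  for (int r = 0; r < rows; ++r)
    std::printf("%g %g\n", out[r][0], out[r][1]);  // 12 10 / 21 21
  return 0;
}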
mlx/backend/cuda/device/indexing.cuh (new file, 30 lines)
@@ -0,0 +1,30 @@
// Copyright © 2025 Apple Inc.

#include <cuda/std/tuple>
#include <cuda/std/type_traits>

namespace mlx::core::cu {

// Convert an absolute index to positions in a 3d grid, assuming the index is
// calculated with:
// index = x * dim1 * dim2 + y * dim2 + z
template <typename T>
inline __host__ __device__ cuda::std::tuple<T, T, T>
index_to_dims(T index, T dim1, T dim2) {
  T x = index / (dim1 * dim2);
  T y = (index % (dim1 * dim2)) / dim2;
  T z = index % dim2;
  return cuda::std::make_tuple(x, y, z);
}

// Get absolute index from possible negative index.
template <typename IdxT>
inline __host__ __device__ auto absolute_index(IdxT idx, int32_t size) {
  if constexpr (cuda::std::is_unsigned_v<IdxT>) {
    return idx;
  } else {
    return static_cast<int32_t>(idx < 0 ? idx + size : idx);
  }
}

} // namespace mlx::core::cu
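
A quick host-side check of the index_to_dims arithmetic under the stated linearization index = x * dim1 * dim2 + y * dim2 + z; a standalone sketch with made-up numbers, using a plain std::tuple so it compiles without CUDA headers:

// Worked example of the index_to_dims arithmetic (illustrative only).
#include <cstdio>
#include <tuple>

std::tuple<int, int, int> index_to_dims_host(int index, int dim1, int dim2) {
  int x = index / (dim1 * dim2);
  int y = (index % (dim1 * dim2)) / dim2;
  int z = index % dim2;
  return {x, y, z};
}

int main() {
  // With dim1 = 3, dim2 = 4: 2 * 12 + 1 * 4 + 3 = 31 should map back to (2, 1, 3).
  auto [x, y, z] = index_to_dims_host(31, 3, 4);
  std::printf("(%d, %d, %d)\n", x, y, z);  // prints (2, 1, 3)
  return 0;
}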
mlx/backend/cuda/device/scatter.cuh (new file, 68 lines)
@@ -0,0 +1,68 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/scatter_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <
    typename T,
    typename IdxT,
    typename Op,
    int NIDX,
    int IDX_NDIM,
    typename LocT>
__global__ void scatter(
    const T* upd,
    T* out,
    LocT size,
    const __grid_constant__ Shape upd_shape,
    const __grid_constant__ Strides upd_strides,
    int32_t upd_ndim,
    LocT upd_post_idx_size,
    const __grid_constant__ Shape out_shape,
    const __grid_constant__ Strides out_strides,
    int32_t out_ndim,
    const __grid_constant__ cuda::std::array<int32_t, NIDX> axes,
    const __grid_constant__ cuda::std::array<IdxT*, NIDX> indices,
    const __grid_constant__ cuda::std::array<int32_t, NIDX * IDX_NDIM>
        indices_shape,
    const __grid_constant__ cuda::std::array<int64_t, NIDX * IDX_NDIM>
        indices_strides) {
  LocT upd_idx = cg::this_grid().thread_rank();
  if (upd_idx >= size) {
    return;
  }

  LocT out_elem = upd_idx % upd_post_idx_size;
  LocT idx_elem = upd_idx / upd_post_idx_size;

  LocT out_idx = elem_to_loc(
      out_elem, upd_shape.data() + IDX_NDIM, out_strides.data(), out_ndim);

#pragma unroll
  for (int i = 0; i < NIDX; ++i) {
    LocT idx_loc = elem_to_loc_nd<IDX_NDIM>(
        idx_elem,
        indices_shape.data() + i * IDX_NDIM,
        indices_strides.data() + i * IDX_NDIM);
    int32_t axis = axes[i];
    LocT idx_val = absolute_index(indices[i][idx_loc], out_shape[axis]);
    out_idx += idx_val * out_strides[axis];
  }

  LocT upd_loc = elem_to_loc(
      out_elem + idx_elem * upd_post_idx_size,
      upd_shape.data(),
      upd_strides.data(),
      upd_ndim);

  Op{}(out + out_idx, upd[upd_loc]);
}

} // namespace mlx::core::cu
mlx/backend/cuda/device/scatter_axis.cuh (new file, 67 lines)
@@ -0,0 +1,67 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device/indexing.cuh"
#include "mlx/backend/cuda/device/scatter_ops.cuh"
#include "mlx/backend/cuda/device/utils.cuh"

#include <cooperative_groups.h>

namespace mlx::core::cu {

namespace cg = cooperative_groups;

template <
    typename T,
    typename IdxT,
    typename Op,
    int NDIM,
    bool UpdC,
    bool IdxC,
    typename LocT>
__global__ void scatter_axis(
    const T* upd,
    const IdxT* indices,
    T* out,
    LocT idx_size_pre,
    LocT idx_size_axis,
    LocT idx_size_post,
    const __grid_constant__ cuda::std::array<int32_t, NDIM> shape,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> upd_strides,
    const __grid_constant__ cuda::std::array<int64_t, NDIM> idx_strides,
    int32_t axis,
    int32_t axis_size,
    int64_t upd_stride_axis,
    int64_t idx_stride_axis) {
  LocT index = cg::this_grid().thread_rank();
  if (index >= idx_size_pre * idx_size_axis * idx_size_post) {
    return;
  }

  auto [x, y, z] = index_to_dims(index, idx_size_axis, idx_size_pre);

  LocT elem_idx = z * idx_size_post;

  LocT idx_loc = y * idx_stride_axis;
  if constexpr (IdxC) {
    idx_loc += elem_idx * idx_size_axis + x;
  } else {
    idx_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), idx_strides.data());
  }

  auto idx_val = absolute_index(indices[idx_loc], axis_size);

  LocT upd_loc = y * upd_stride_axis;
  if constexpr (UpdC) {
    upd_loc += elem_idx * idx_size_axis + x;
  } else {
    upd_loc +=
        elem_to_loc_nd<NDIM>(elem_idx + x, shape.data(), upd_strides.data());
  }

  LocT out_idx = idx_val * idx_size_post + elem_idx * axis_size + x;

  Op{}(out + out_idx, upd[upd_loc]);
}

} // namespace mlx::core::cu
mlx/backend/cuda/device/scatter_ops.cuh (new file, 44 lines)
@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/atomic_ops.cuh"

namespace mlx::core::cu {

struct ScatterAssign {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    *out = val;
  }
};

struct ScatterSum {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_add(out, val);
  }
};

struct ScatterProd {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_prod(out, val);
  }
};

struct ScatterMax {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_max(out, val);
  }
};

struct ScatterMin {
  template <typename T>
  __device__ void operator()(T* out, T val) const {
    atomic_min(out, val);
  }
};

} // namespace mlx::core::cu
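
The scatter and scatter_axis kernels above are parameterized on one of these functors, so a single kernel body handles assignment and every reduction; the behavior is chosen at compile time through the Op template argument, and the accumulating ops go through atomics because several updates may target the same output location. A host-side analogue of that dispatch (CpuAssign/CpuSum are made-up stand-ins, illustrative only):

// Host-side analogue of Op-based dispatch (names are made up for illustration).
#include <cstdio>

struct CpuAssign {
  template <typename T>
  void operator()(T* out, T val) const { *out = val; }
};

struct CpuSum {
  template <typename T>
  void operator()(T* out, T val) const { *out += val; }
};

// One "kernel" body, behavior chosen at compile time -- mirroring
// scatter<..., Op, ...> picking ScatterAssign / ScatterSum / etc.
template <typename Op, typename T>
void apply_updates(T* out, const int* idx, const T* upd, int n) {
  for (int i = 0; i < n; ++i) {
    Op{}(out + idx[i], upd[i]);
  }
}

int main() {
  float a[3] = {0, 0, 0}, b[3] = {0, 0, 0};
  int idx[4] = {0, 2, 0, 1};  // note the repeated index 0
  float upd[4] = {1, 2, 3, 4};
  apply_updates<CpuAssign>(a, idx, upd, 4);  // a = {3, 4, 2}
  apply_updates<CpuSum>(b, idx, upd, 4);     // b = {4, 4, 2}
  std::printf("%g %g %g | %g %g %g\n", a[0], a[1], a[2], b[0], b[1], b[2]);
  return 0;
}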
mlx/backend/cuda/indexing.cpp (new file, 420 lines)
@@ -0,0 +1,420 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"

#include "cuda_jit_sources.h"

#include <fmt/format.h>
#include <nvtx3/nvtx3.hpp>

#include <cassert>
#include <numeric>

namespace mlx::core {

namespace {

constexpr const char* g_scatter_ops[] = {"Max", "Min", "Sum", "Prod", "Assign"};

void append_indices_arg(
    cu::JitModule& mod,
    const std::vector<array>& inputs,
    int nidx,
    int idx_ndim) {
  std::vector<const void*> indices(nidx);
  for (int i = 0; i < nidx; ++i) {
    indices[i] = inputs[i + 1].data<void>();
  }
  mod.append_arg(std::move(indices));
  std::vector<int32_t> indices_shape(nidx * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy_n(
        inputs[i + 1].shape().begin(),
        idx_ndim,
        indices_shape.data() + i * idx_ndim);
  }
  mod.append_arg(std::move(indices_shape));
  std::vector<int64_t> indices_strides(nidx * idx_ndim);
  for (int i = 0; i < nidx; ++i) {
    std::copy_n(
        inputs[i + 1].strides().begin(),
        idx_ndim,
        indices_strides.data() + i * idx_ndim);
  }
  mod.append_arg(std::move(indices_strides));
}

} // namespace

void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Gather::eval_gpu");
  assert(inputs.size() > 0);
  const auto& src = inputs[0];

  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  int nidx = inputs.size() - 1;
  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;

  bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
      (src.size() > UINT32_MAX) || (out.size() > UINT32_MAX);

  uint32_t slice_size = std::accumulate(
      slice_sizes_.begin(), slice_sizes_.end(), 1, std::multiplies<uint32_t>());

  std::string module_name = fmt::format(
      "gather_{}_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx_dtype),
      nidx);

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int large = 0; large <= 1; ++large) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
            dtype_to_cuda_type(out.dtype()),
            dtype_to_cuda_type(idx_dtype),
            nidx,
            ndim,
            large ? "int64_t" : "uint32_t"));
      }
    }
    return std::make_pair(jit_source_gather, std::move(kernel_names));
  });

  mod.append_arg(src);
  mod.append_arg(out);
  if (large) {
    mod.append_arg<int64_t>(out.size());
  } else {
    mod.append_arg<uint32_t>(out.size());
  }
  mod.append_ndim_arg(src.shape());
  mod.append_ndim_arg(src.strides());
  mod.append_arg<int32_t>(src.ndim());
  mod.append_ndim_arg(slice_sizes_);
  mod.append_arg(slice_size);
  mod.append_arg(axes_);
  append_indices_arg(mod, inputs, nidx, idx_ndim);

  std::string kernel_name = fmt::format(
      "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx_dtype),
      nidx,
      idx_ndim,
      large ? "int64_t" : "uint32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    mod.launch_kernel(stream, kernel_name, out, large);
  });
}

void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Scatter::eval_gpu");
  assert(inputs.size() > 1);
  auto& upd = inputs.back();

  // Copy src into out.
  CopyType copy_type;
  if (inputs[0].data_size() == 1) {
    copy_type = CopyType::Scalar;
  } else if (inputs[0].flags().row_contiguous) {
    copy_type = CopyType::Vector;
  } else {
    copy_type = CopyType::General;
  }
  copy_gpu(inputs[0], out, copy_type);

  // Empty update.
  if (upd.size() == 0) {
    return;
  }

  int nidx = axes_.size();
  Dtype idx_dtype = nidx > 0 ? inputs[1].dtype() : int32;
  int32_t idx_ndim = nidx > 0 ? inputs[1].ndim() : 0;

  bool large = (nidx > 0 && inputs[1].size() > UINT32_MAX) ||
      (upd.size() > UINT32_MAX) || (out.size() > UINT32_MAX);

  uint32_t upd_post_idx_size = std::accumulate(
      upd.shape().begin() + idx_ndim,
      upd.shape().end(),
      1,
      std::multiplies<uint32_t>());

  const char* op = g_scatter_ops[reduce_type_];
  std::string module_name = fmt::format(
      "scatter_{}_{}_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx_dtype),
      op,
      nidx);

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int large = 0; large <= 1; ++large) {
        kernel_names.push_back(fmt::format(
            "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
            dtype_to_cuda_type(out.dtype()),
            dtype_to_cuda_type(idx_dtype),
            op,
            nidx,
            ndim,
            large ? "int64_t" : "uint32_t"));
      }
    }
    return std::make_pair(jit_source_scatter, std::move(kernel_names));
  });

  mod.append_arg(upd);
  mod.append_arg(out);
  if (large) {
    mod.append_arg<int64_t>(upd.size());
  } else {
    mod.append_arg<uint32_t>(upd.size());
  }
  mod.append_ndim_arg(upd.shape());
  mod.append_ndim_arg(upd.strides());
  mod.append_arg<int32_t>(upd.ndim());
  if (large) {
    mod.append_arg<int64_t>(upd_post_idx_size);
  } else {
    mod.append_arg<uint32_t>(upd_post_idx_size);
  }
  mod.append_ndim_arg(out.shape());
  mod.append_ndim_arg(out.strides());
  mod.append_arg<int32_t>(out.ndim());
  mod.append_arg(axes_);
  append_indices_arg(mod, inputs, nidx, idx_ndim);

  std::string kernel_name = fmt::format(
      "mlx::core::cu::scatter<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx_dtype),
      op,
      nidx,
      idx_ndim,
      large ? "int64_t" : "uint32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    mod.launch_kernel(stream, kernel_name, upd, large);
  });
}

void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("GatherAxis::eval_gpu");
  assert(inputs.size() > 1);
  const auto& src = inputs[0];
  const auto& idx = inputs[1];

  out.set_data(allocator::malloc(out.nbytes()));
  if (out.size() == 0) {
    return;
  }

  bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;

  std::string module_name = fmt::format(
      "gather_axis_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx.dtype()));

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int contiguous = 0; contiguous < 4; ++contiguous) {
        for (int large = 0; large <= 1; ++large) {
          kernel_names.push_back(fmt::format(
              "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
              dtype_to_cuda_type(out.dtype()),
              dtype_to_cuda_type(idx.dtype()),
              ndim,
              contiguous & 1 ? true : false,
              contiguous & 2 ? true : false,
              large ? "int64_t" : "uint32_t"));
        }
      }
    }
    return std::make_pair(jit_source_gather_axis, std::move(kernel_names));
  });

  size_t idx_size_pre = 1;
  size_t idx_size_post = 1;
  for (int i = 0; i < axis_; ++i) {
    idx_size_pre *= idx.shape(i);
  }
  for (int i = axis_ + 1; i < idx.ndim(); ++i) {
    idx_size_post *= idx.shape(i);
  }
  size_t idx_size_axis = idx.shape(axis_);

  mod.append_arg(src);
  mod.append_arg(idx);
  mod.append_arg(out);
  if (large) {
    mod.append_arg<int64_t>(idx_size_pre);
    mod.append_arg<int64_t>(idx_size_axis);
    mod.append_arg<int64_t>(idx_size_post);
  } else {
    mod.append_arg<uint32_t>(idx_size_pre);
    mod.append_arg<uint32_t>(idx_size_axis);
    mod.append_arg<uint32_t>(idx_size_post);
  }
  mod.append_arg(remove_index(idx.shape(), axis_));
  mod.append_arg(remove_index(src.strides(), axis_));
  mod.append_arg(remove_index(idx.strides(), axis_));
  mod.append_arg<int32_t>(axis_);
  mod.append_arg(src.shape(axis_));
  mod.append_arg(src.strides(axis_));
  mod.append_arg(idx.strides(axis_));

  std::string kernel_name = fmt::format(
      "mlx::core::cu::gather_axis<{}, {}, {}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx.dtype()),
      src.ndim() - 1,
      src.flags().row_contiguous,
      idx.flags().row_contiguous,
      large ? "int64_t" : "uint32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    mod.launch_kernel(stream, kernel_name, idx, large);
  });
}

void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("ScatterAxis::eval_gpu");
  assert(inputs.size() > 2);
  const auto& src = inputs[0];
  const auto& idx = inputs[1];
  const auto& upd = inputs[2];

  // Copy src into out.
  CopyType copy_type;
  if (src.data_size() == 1) {
    copy_type = CopyType::Scalar;
  } else if (src.flags().row_contiguous) {
    copy_type = CopyType::Vector;
  } else {
    copy_type = CopyType::General;
  }
  copy_gpu(src, out, copy_type);

  // Empty update.
  if (upd.size() == 0) {
    return;
  }

  bool large = idx.size() > UINT32_MAX || src.size() > UINT32_MAX;

  const char* op = reduce_type_ == ScatterAxis::Sum ? "Sum" : "Assign";
  std::string module_name = fmt::format(
      "scatter_axis_{}_{}_{}",
      dtype_to_string(out.dtype()),
      dtype_to_string(idx.dtype()),
      op);

  auto& s = stream();
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) {
      for (int contiguous = 0; contiguous < 4; ++contiguous) {
        for (int large = 0; large <= 1; ++large) {
          kernel_names.push_back(fmt::format(
              "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
              dtype_to_cuda_type(out.dtype()),
              dtype_to_cuda_type(idx.dtype()),
              op,
              ndim,
              contiguous & 1 ? true : false,
              contiguous & 2 ? true : false,
              large ? "int64_t" : "uint32_t"));
        }
      }
    }
    return std::make_pair(jit_source_scatter_axis, std::move(kernel_names));
  });

  size_t idx_size_pre = 1;
  size_t idx_size_post = 1;
  for (int i = 0; i < axis_; ++i) {
    idx_size_pre *= idx.shape(i);
  }
  for (int i = axis_ + 1; i < idx.ndim(); ++i) {
    idx_size_post *= idx.shape(i);
  }
  size_t idx_size_axis = idx.shape(axis_);

  mod.append_arg(upd);
  mod.append_arg(idx);
  mod.append_arg(out);
  if (large) {
    mod.append_arg<int64_t>(idx_size_pre);
    mod.append_arg<int64_t>(idx_size_axis);
    mod.append_arg<int64_t>(idx_size_post);
  } else {
    mod.append_arg<uint32_t>(idx_size_pre);
    mod.append_arg<uint32_t>(idx_size_axis);
    mod.append_arg<uint32_t>(idx_size_post);
  }
  mod.append_arg(remove_index(idx.shape(), axis_));
  mod.append_arg(remove_index(upd.strides(), axis_));
  mod.append_arg(remove_index(idx.strides(), axis_));
  mod.append_arg<int32_t>(axis_);
  mod.append_arg(out.shape(axis_));
  mod.append_arg(upd.strides(axis_));
  mod.append_arg(idx.strides(axis_));

  std::string kernel_name = fmt::format(
      "mlx::core::cu::scatter_axis<{}, {}, mlx::core::cu::Scatter{}, {}, {}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      dtype_to_cuda_type(idx.dtype()),
      op,
      idx.ndim() - 1,
      upd.flags().row_contiguous,
      idx.flags().row_contiguous,
      large ? "int64_t" : "uint32_t");

  auto& encoder = cu::get_command_encoder(s);
  for (const auto& in : inputs) {
    encoder.set_input_array(in);
  }
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    mod.launch_kernel(stream, kernel_name, idx, large);
  });
}

} // namespace mlx::core
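
Each eval_gpu builds a module name, registers the full set of template specializations up front, and then formats one concrete kernel name at launch time to pick the right specialization out of the JIT module. For a concrete case the formatted name looks like this; a standalone fmt sketch, where the dtype strings are assumptions rather than the actual return values of dtype_to_cuda_type:

// Illustrative only: how a gather kernel specialization name is assembled.
#include <fmt/format.h>
#include <cstdio>

int main() {
  // float32 output, int32 indices, 2 index arrays, 1-d indices, 32-bit offsets.
  std::string kernel_name = fmt::format(
      "mlx::core::cu::gather<{}, {}, {}, {}, {}>",
      "float", "int32_t", 2, 1, "uint32_t");
  std::printf("%s\n", kernel_name.c_str());
  // -> mlx::core::cu::gather<float, int32_t, 2, 1, uint32_t>
  return 0;
}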
@@ -148,24 +148,32 @@ bool compiler_supports_device_sass(Device& device) {
#define INCLUDE_PREFIX "mlx/backend/cuda/kernels/"

constexpr const char* g_include_names[] = {
    INCLUDE_PREFIX "atomic_ops.cuh",
    INCLUDE_PREFIX "binary_ops.cuh",
    INCLUDE_PREFIX "cast_op.cuh",
    INCLUDE_PREFIX "config.h",
    INCLUDE_PREFIX "cucomplex_math.cuh",
    INCLUDE_PREFIX "fp16_math.cuh",
    INCLUDE_PREFIX "indexing.cuh",
    INCLUDE_PREFIX "scatter_ops.cuh",
    INCLUDE_PREFIX "unary_ops.cuh",
    INCLUDE_PREFIX "ternary_ops.cuh",
    INCLUDE_PREFIX "utils.cuh",
};

#undef INCLUDE_PREFIX

constexpr const char* g_headers[] = {
    jit_source_atomic_ops,
    jit_source_binary_ops,
    jit_source_cast_op,
    jit_source_config,
    jit_source_cucomplex_math,
    jit_source_fp16_math,
    jit_source_indexing,
    jit_source_scatter_ops,
    jit_source_unary_ops,
    jit_source_ternary_ops,
    jit_source_utils,
};
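
The two arrays pair each include name with its in-memory source, in matching order, which is the layout NVRTC-style runtime compilation consumes. A generic sketch of that pattern, using the plain NVRTC API rather than the MLX JitModule code (link with -lnvrtc; the header contents are made up):

// Illustrative only: include names and header sources are passed side by side.
#include <nvrtc.h>
#include <cstdio>

int main() {
  const char* include_names[] = {"ops.cuh"};
  const char* headers[] = {
      "__device__ inline float twice(float x) { return 2.f * x; }\n"};
  const char* source =
      "#include \"ops.cuh\"\n"
      "__global__ void k(float* x) { x[0] = twice(x[0]); }\n";

  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, source, "example.cu", 1, headers, include_names);
  const char* opts[] = {"--std=c++17"};
  nvrtcResult res = nvrtcCompileProgram(prog, 1, opts);
  std::printf("compile %s\n", res == NVRTC_SUCCESS ? "ok" : "failed");
  nvrtcDestroyProgram(&prog);
  return 0;
}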
@@ -78,8 +78,6 @@ NO_GPU_MULTI(DivMod)
NO_GPU(DynamicSlice)
NO_GPU(DynamicSliceUpdate)
NO_GPU(FFT)
NO_GPU(Gather)
NO_GPU(GatherAxis)
NO_GPU(GatherMM)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
@@ -89,8 +87,6 @@ NO_GPU(Partition)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(Scan)
NO_GPU(Scatter)
NO_GPU(ScatterAxis)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)