Mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-18 18:28:12 +08:00
Up to 10x faster scatter. (#709)
* Faster scatter. Add specialization for 1-d index tensors.

* Address review comments:
  - Check for row contiguity of index and update tensors instead of checking strides.
  - Add support for the 1-d specialization with a col-contiguous update tensor, along with a test.

* Nit1
* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
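For orientation, the diff below gates the new fast path on three conditions: a single-dimensional index array, scatter axes that are the outermost dimensions in order, and row-contiguous index/update arrays. A condensed restatement of that test in plain C++ (a sketch; the names are illustrative, the real code operates on mlx::core::array inputs):

```cpp
#include <cstddef>
#include <vector>

// Illustrative predicate mirroring the eligibility checks added in this
// commit; not MLX code.
bool use_1d_index_fast_path(
    size_t idx_ndim,
    const std::vector<int>& scatter_axes,
    const std::vector<bool>& row_contiguous) { // contiguity of index/update arrays
  bool ok = (idx_ndim == 1);
  for (size_t i = 0; i < scatter_axes.size() && ok; ++i) {
    ok &= (scatter_axes[i] == static_cast<int>(i)); // axes must be outermost, in order
  }
  for (size_t i = 0; i < row_contiguous.size() && ok; ++i) {
    ok &= row_contiguous[i]; // no broadcasting, so linear indexing stays valid
  }
  return ok;
}
```

When this predicate fails, dispatch falls back to the general scatter kernel, as in the else branch of the diff.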
@@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
// Get kernel name
std::ostringstream kname;
std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
kname << "scatter" << type_to_name(out) << idx_type_name;

int idx_ndim = nidx ? inputs[1].ndim() : 0;
bool index_nd1_specialization = (idx_ndim == 1);

// Bail from fast path (1d index specialization) if scatter dims aren't
// the outermost dims and contiguous since update access won't be raster
// order.
for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= (axes_[i] == i);
}

// Bail from fast path (1d index specialization) if any of the dims are
// broadcasted, since we can't rely on linear indexing in that case.
for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
index_nd1_specialization &= inputs[i].flags().row_contiguous;
}

if (index_nd1_specialization) {
kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
} else {
kname << "scatter" << type_to_name(out) << idx_type_name;
}
switch (reduce_type_) {
case Scatter::None:
kname << "_none";
@@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setComputePipelineState(kernel);

// Collect all idx shapes and strides into one place
int idx_ndim = nidx ? inputs[1].ndim() : 0;
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;

for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());

idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}

// Set all the buffers
set_array_buffer(compute_encoder, upd, 1);
set_array_buffer(compute_encoder, out, 2);

// Set update info
size_t upd_ndim = upd.ndim();
uint upd_ndim = upd.ndim();
size_t upd_size = 1;
for (int i = idx_ndim; i < upd.ndim(); ++i) {
upd_size *= upd.shape(i);
}
if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);

if (index_nd1_specialization) {
bool upd_col_contiguous = upd.flags().col_contiguous;
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
out.shape().data(), out.shape().size() * sizeof(int), 3);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
out.strides().data(), out.strides().size() * sizeof(size_t), 4);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);

// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}

// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}
// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);

// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
} else {
// Collect all idx shapes and strides into one place
std::vector<int> idx_shapes;
std::vector<size_t> idx_strides;

for (int i = 0; i < nidx; ++i) {
idx_shapes.insert(
idx_shapes.end(),
inputs[i + 1].shape().begin(),
inputs[i + 1].shape().end());

idx_strides.insert(
idx_strides.end(),
inputs[i + 1].strides().begin(),
inputs[i + 1].strides().end());
}

if (upd_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 3);
compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
} else {
compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
compute_encoder->setBytes(
upd.strides().data(), upd_ndim * sizeof(size_t), 4);
}
compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);

// Set output info
size_t out_ndim = out.ndim();
if (out_ndim == 0) {
// Need placeholders so Metal doesn't complain
int shape_ = 0;
size_t stride_ = 0;
compute_encoder->setBytes(&shape_, sizeof(int), 7);
compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
} else {
compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
compute_encoder->setBytes(
out.strides().data(), out_ndim * sizeof(size_t), 8);
}
compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);

// Set index info
if (idx_ndim == 0) {
// Add a 0 in idx_shapes and strides to avoid the missing buffer binding
// error in the metal API.
idx_shapes.push_back(0);
idx_strides.push_back(0);
}
compute_encoder->setBytes(
idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
compute_encoder->setBytes(
idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);

// Set index buffers
for (int i = 1; i < nidx + 1; ++i) {
set_array_buffer(compute_encoder, inputs[i], 20 + i);
}

// Launch grid
MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
compute_encoder->dispatchThreads(grid_dims, group_dims);
}
}
} // namespace mlx::core
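The fast path launches a 2-D grid in which gid.x walks one update slice and gid.y selects an index entry. A small standalone example of the grid arithmetic above, assuming nthreads is the total number of update elements as in the unchanged part of this function:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sizes: 10 rows scattered into an output, 32 elements per row.
  size_t n_indices = 10;
  size_t upd_size = 32;                 // product of the update dims after the index dims
  size_t nthreads = n_indices * upd_size;

  size_t grid_x = upd_size;             // first argument to MTL::Size above
  size_t grid_y = nthreads / upd_size;  // == n_indices
  std::printf("grid = (%zu, %zu, 1)\n", grid_x, grid_y); // prints: grid = (32, 10, 1)
  return 0;
}
```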
@@ -13,6 +13,58 @@ using namespace metal;
// Scatter kernel
/////////////////////////////////////////////////////////////////////

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
const device T *updates [[buffer(1)]],
device mlx_atomic<T> *out [[buffer(2)]],
const constant int* out_shape [[buffer(3)]],
const constant size_t* out_strides [[buffer(4)]],
const constant size_t& upd_size [[buffer(5)]],
const constant bool& upd_col_contiguous [[buffer(6)]],
const thread array<const device IdxT*, NIDX>& idx_buffers,
uint2 gid [[thread_position_in_grid]]) {

Op op;

uint out_idx = 0;
for (int i = 0; i < NIDX; i++) {
auto idx_val = offset_neg_idx(
idx_buffers[i][gid.y], out_shape[i]);
out_idx += idx_val * out_strides[i];
}

if (!upd_col_contiguous) {
op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
} else {
op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
}
}
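The loop above folds each index value into out_idx via offset_neg_idx, which is defined elsewhere in the kernel headers. A minimal sketch of the conventional behavior it provides (wrapping negative indices into range), written as plain C++ for illustration only:

```cpp
#include <cstdint>

// Illustrative only; not the MLX definition of offset_neg_idx.
inline int64_t wrap_negative_index(int64_t idx, int64_t axis_size) {
  return idx < 0 ? idx + axis_size : idx;
}
```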
#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
template <typename T, typename IdxT, typename Op, int NIDX> \
[[kernel]] void scatter_1d_index( \
const device T *updates [[buffer(1)]], \
device mlx_atomic<T> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(IdxT) \
uint2 gid [[thread_position_in_grid]]) { \
\
const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
\
return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
updates, \
out, \
out_shape, \
out_strides, \
upd_size, \
upd_col_contiguous, \
idx_buffers, \
gid); \
\
}

template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
@@ -46,10 +98,14 @@ METAL_FUNC void scatter_impl(
out_idx += idx_val * out_strides[ax];
}

auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
if (upd_size > 1) {
auto out_offset = elem_to_loc(
ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
out_idx += out_offset;
}

auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
op.atomic_update(out, updates[upd_idx], out_idx);
}
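The change above computes the output offset with elem_to_loc only when upd_size > 1, since a single-element update contributes no extra offset. For context, elem_to_loc maps a linear element number to a strided memory offset; a generic sketch of that kind of mapping (illustrative, not the MLX definition):

```cpp
#include <cstddef>

// Illustrative linear-index -> strided-offset conversion; MLX's elem_to_loc
// plays this role for the update and output arrays above.
inline size_t linear_to_strided_offset(
    size_t elem, const int* shape, const size_t* strides, int ndim) {
  size_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    size_t dim = static_cast<size_t>(shape[i]);
    loc += (elem % dim) * strides[i];
    elem /= dim;
  }
  return loc;
}
```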

#define make_scatter_impl(IDX_ARG, IDX_ARR) \
@@ -90,9 +146,11 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
axes, \
idxs, \
gid); \
}
}

#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
#define make_scatter(n) \
make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)

make_scatter(0)
make_scatter(1)
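After this change, make_scatter(n) emits both the general and the 1-d-index kernels for each index-array count by ## token pasting of the per-arity IDX_ARG_/IDX_ARR_ helpers. A tiny standalone demo of the same pasting pattern (all names below are made up for illustration):

```cpp
#include <cstdio>

#define GREETING_1() std::puts("one index buffer")
#define GREETING_2() std::puts("two index buffers")

// One macro argument selects the per-arity helper, as make_scatter(n) does
// with IDX_ARG_ ##n and IDX_ARR_ ##n.
#define MAKE_GREETER(n) \
  void greeter_##n() { GREETING_##n(); }

MAKE_GREETER(1)
MAKE_GREETER(2)

int main() {
  greeter_1(); // prints "one index buffer"
  greeter_2(); // prints "two index buffers"
  return 0;
}
```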
@@ -129,8 +187,21 @@ template [[host_name("scatter" name "_" #nidx)]] \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
template [[host_name("scatter_1d_index" name "_" #nidx)]] \
[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
const device src_t *updates [[buffer(1)]], \
device mlx_atomic<src_t> *out [[buffer(2)]], \
const constant int* out_shape [[buffer(3)]], \
const constant size_t* out_strides [[buffer(4)]], \
const constant size_t& upd_size [[buffer(5)]], \
const constant bool& upd_col_contiguous [[buffer(6)]], \
IDX_ARG(idx_t) \
uint2 gid [[thread_position_in_grid]]);

#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)

// Special case NINDEX=0
#define instantiate_scatter_nd0(name, type) \