Up to 10x faster scatter. (#709)

* Faster scatter.

Add specialization for 1-d index tensors.

* Address review comments.

- Check for row contiguity of the index and update tensors
  instead of checking strides.
- Add support for the 1-d specialization with a col contiguous update
  tensor, along with a test (see the sketch below).
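
A minimal sketch of the kind of call that should land on the new 1-d index
fast path (shapes are illustrative, not taken from the benchmark; assumes
row-contiguous index and update arrays):

    import mlx.core as mx

    # Scatter along axis 0 with a single 1-d index array: the case the new
    # scatter_1d_index kernel specializes. idx and x are row contiguous, so
    # each update row can be written in raster order.
    dst = mx.zeros((100_000, 64))
    idx = mx.random.randint(0, 100_000, (1_000_000,))
    x = mx.random.normal((1_000_000, 64))
    dst[idx] = x
    mx.eval(dst)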

* Nit1

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Nit2

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Commit: 972d9a3aea (parent 7dcdd88e27)
Author: Vijay Krish
Date: 2024-02-21 11:09:30 -08:00, committed by GitHub
4 changed files with 244 additions and 83 deletions


@@ -7,12 +7,14 @@ import torch
 from time_utils import measure_runtime


-def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
+def benchmark_scatter_mlx(dst_shape, x_shape, idx_shapes):
     def scatter(dst, x, idx):
-        dst[idx] = x
+        dst[*idx] = x
         mx.eval(dst)

-    idx = mx.random.randint(0, dst_shape[0] - 1, idx_shape)
+    idx = []
+    for idx_shape in idx_shapes:
+        idx.append(mx.random.randint(0, dst_shape[0] - 1, idx_shape))
     x = mx.random.normal(x_shape).astype(mx.float32)
     dst = mx.random.normal(dst_shape).astype(mx.float32)

@@ -20,13 +22,15 @@ def benchmark_scatter_mlx(dst_shape, x_shape, idx_shape):
     print(f"MLX: {runtime:.3f}ms")


-def benchmark_scatter_torch(dst_shape, x_shape, idx_shape, device):
+def benchmark_scatter_torch(dst_shape, x_shape, idx_shapes, device):
     def gather(dst, x, idx, device):
-        dst[idx] = x
+        dst[*idx] = x
         if device == torch.device("mps"):
             torch.mps.synchronize()

-    idx = torch.randint(0, dst_shape[0] - 1, idx_shape).to(device)
+    idx = []
+    for idx_shape in idx_shapes:
+        idx.append(torch.randint(0, dst_shape[0] - 1, idx_shape).to(device))
     x = torch.randn(x_shape, dtype=torch.float32).to(device)
     dst = torch.randn(dst_shape, dtype=torch.float32).to(device)

@@ -45,9 +49,45 @@ if __name__ == "__main__":
     else:
         device = torch.device("mps")

-    dst_shapes = [(10, 64), (100_000, 64), (1_000_000, 64)]
-    idx_shapes = [(1_000_000,), (1_000_000,), (100_000,)]
-    x_shapes = [(1_000_000, 64), (1_000_000, 64), (100_000, 64)]
+    dst_shapes = [
+        (10, 64),
+        (100_000, 64),
+        (1_000_000, 64),
+        (100_000,),
+        (2_000_00,),
+        (20_000_000,),
+        (10000, 64),
+        (100, 64),
+        (100, 10_000, 64),
+        (10, 100, 100, 21),
+        (1_000, 1_000, 10),
+    ]
+    idx_shapes = [
+        [(1_000_000,)],
+        [(1_000_000,)],
+        [(100_000,)],
+        [(1_000_000,)],
+        [(20_000_000,)],
+        [(20_000_000,)],
+        [(1000000,)],
+        [(10000000,)],
+        [(1_000,)],
+        [(10_000,)],
+        [(1_000,), (1_000,)],
+    ]
+    x_shapes = [
+        (1_000_000, 64),
+        (1_000_000, 64),
+        (100_000, 64),
+        (1_000_000,),
+        (20_000_000,),
+        (20_000_000,),
+        (1000000, 64),
+        (10000000, 64),
+        (1_000, 10_000, 64),
+        (10_000, 100, 100, 21),
+        (1_000, 10),
+    ]

     for dst_shape, x_shape, idx_shape in zip(dst_shapes, x_shapes, idx_shapes):
         print("=" * 20)


@@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Get kernel name
   std::ostringstream kname;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  kname << "scatter" << type_to_name(out) << idx_type_name;
+  int idx_ndim = nidx ? inputs[1].ndim() : 0;
+  bool index_nd1_specialization = (idx_ndim == 1);
+
+  // Bail from fast path (1d index specialization) if scatter dims aren't
+  // the outermost dims and contiguous since update access won't be raster
+  // order.
+  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
+    index_nd1_specialization &= (axes_[i] == i);
+  }
+
+  // Bail from fast path (1d index specialization) if any of the dims are
+  // broadcasted, since we can't rely on linear indexing in that case.
+  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
+    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  }
+
+  if (index_nd1_specialization) {
+    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
+  } else {
+    kname << "scatter" << type_to_name(out) << idx_type_name;
+  }
+
   switch (reduce_type_) {
     case Scatter::None:
       kname << "_none";

@@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder->setComputePipelineState(kernel);

-  // Collect all idx shapes and strides into one place
-  int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  std::vector<int> idx_shapes;
-  std::vector<size_t> idx_strides;
-
-  for (int i = 0; i < nidx; ++i) {
-    idx_shapes.insert(
-        idx_shapes.end(),
-        inputs[i + 1].shape().begin(),
-        inputs[i + 1].shape().end());
-
-    idx_strides.insert(
-        idx_strides.end(),
-        inputs[i + 1].strides().begin(),
-        inputs[i + 1].strides().end());
-  }
-
   // Set all the buffers
   set_array_buffer(compute_encoder, upd, 1);
   set_array_buffer(compute_encoder, out, 2);

   // Set update info
-  size_t upd_ndim = upd.ndim();
+  uint upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  if (upd_ndim == 0) {
-    // Need placeholders so Metal doesn't complain
-    int shape_ = 0;
-    size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 3);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
-  } else {
-    compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
-    compute_encoder->setBytes(
-        upd.strides().data(), upd_ndim * sizeof(size_t), 4);
-  }
-  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
-
-  // Set output info
-  size_t out_ndim = out.ndim();
-  if (out_ndim == 0) {
-    // Need placeholders so Metal doesn't complain
-    int shape_ = 0;
-    size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 7);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
-  } else {
-    compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
-    compute_encoder->setBytes(
-        out.strides().data(), out_ndim * sizeof(size_t), 8);
-  }
-  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
-
-  // Set index info
-  if (idx_ndim == 0) {
-    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
-    // error in the metal API.
-    idx_shapes.push_back(0);
-    idx_strides.push_back(0);
-  }
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
-
-  // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
-  }
-
-  // Launch grid
-  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+  if (index_nd1_specialization) {
+    bool upd_col_contiguous = upd.flags().col_contiguous;
+    compute_encoder->setBytes(
+        out.shape().data(), out.shape().size() * sizeof(int), 3);
+    compute_encoder->setBytes(
+        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
+    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
+    compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);
+
+    // Set index buffers
+    for (int i = 1; i < nidx + 1; ++i) {
+      set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    }
+
+    // Launch grid
+    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+
+  } else {
+    // Collect all idx shapes and strides into one place
+    std::vector<int> idx_shapes;
+    std::vector<size_t> idx_strides;
+
+    for (int i = 0; i < nidx; ++i) {
+      idx_shapes.insert(
+          idx_shapes.end(),
+          inputs[i + 1].shape().begin(),
+          inputs[i + 1].shape().end());
+
+      idx_strides.insert(
+          idx_strides.end(),
+          inputs[i + 1].strides().begin(),
+          inputs[i + 1].strides().end());
+    }
+
+    if (upd_ndim == 0) {
+      // Need placeholders so Metal doesn't complain
+      int shape_ = 0;
+      size_t stride_ = 0;
+      compute_encoder->setBytes(&shape_, sizeof(int), 3);
+      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
+    } else {
+      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
+      compute_encoder->setBytes(
+          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
+    }
+    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
+    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+
+    // Set output info
+    size_t out_ndim = out.ndim();
+    if (out_ndim == 0) {
+      // Need placeholders so Metal doesn't complain
+      int shape_ = 0;
+      size_t stride_ = 0;
+      compute_encoder->setBytes(&shape_, sizeof(int), 7);
+      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+    } else {
+      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
+      compute_encoder->setBytes(
+          out.strides().data(), out_ndim * sizeof(size_t), 8);
+    }
+    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
+    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+
+    // Set index info
+    if (idx_ndim == 0) {
+      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+      // error in the metal API.
+      idx_shapes.push_back(0);
+      idx_strides.push_back(0);
+    }
+    compute_encoder->setBytes(
+        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
+    compute_encoder->setBytes(
+        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
+    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+
+    // Set index buffers
+    for (int i = 1; i < nidx + 1; ++i) {
+      set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    }
+
+    // Launch grid
+    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
 }

 } // namespace mlx::core
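
In Python-like pseudocode, the host-side gating above amounts to roughly the
following (a sketch only; eligible_for_1d_fast_path and the array attributes
are illustrative stand-ins for the C++ checks, not an MLX Python API):

    def eligible_for_1d_fast_path(axes, indices, updates):
        # Every index array must be 1-d for the scatter_1d_index kernel.
        if any(idx.ndim != 1 for idx in indices):
            return False
        # The scatter axes must be the outermost dims, in order, so that
        # update access stays in raster order.
        if any(ax != i for i, ax in enumerate(axes)):
            return False
        # No broadcasted index/update arrays: the kernel relies on linear
        # (contiguous) indexing into these buffers.
        return all(a.row_contiguous for a in (*indices, updates))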


@@ -13,6 +13,58 @@ using namespace metal;
 // Scatter kernel
 /////////////////////////////////////////////////////////////////////

+template <typename T, typename IdxT, typename Op, int NIDX> \
+METAL_FUNC void scatter_1d_index_impl(
+    const device T *updates [[buffer(1)]],
+    device mlx_atomic<T> *out [[buffer(2)]],
+    const constant int* out_shape [[buffer(3)]],
+    const constant size_t* out_strides [[buffer(4)]],
+    const constant size_t& upd_size [[buffer(5)]],
+    const constant bool& upd_col_contiguous [[buffer(6)]],
+    const thread array<const device IdxT*, NIDX>& idx_buffers,
+    uint2 gid [[thread_position_in_grid]]) {
+
+  Op op;
+
+  uint out_idx = 0;
+  for (int i = 0; i < NIDX; i++) {
+    auto idx_val = offset_neg_idx(
+        idx_buffers[i][gid.y], out_shape[i]);
+    out_idx += idx_val * out_strides[i];
+  }
+
+  if (!upd_col_contiguous) {
+    op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx + gid.x);
+  } else {
+    op.atomic_update(out, updates[gid.x * upd_size + gid.y], out_idx + gid.x);
+  }
+}
+
+#define make_scatter_1d_index(IDX_ARG, IDX_ARR) \
+template <typename T, typename IdxT, typename Op, int NIDX> \
+[[kernel]] void scatter_1d_index( \
+    const device T *updates [[buffer(1)]], \
+    device mlx_atomic<T> *out [[buffer(2)]], \
+    const constant int* out_shape [[buffer(3)]], \
+    const constant size_t* out_strides [[buffer(4)]], \
+    const constant size_t& upd_size [[buffer(5)]], \
+    const constant bool& upd_col_contiguous [[buffer(6)]], \
+    IDX_ARG(IdxT) \
+    uint2 gid [[thread_position_in_grid]]) { \
+  \
+  const array<const device IdxT*, NIDX> idx_buffers = {IDX_ARR()}; \
+  \
+  return scatter_1d_index_impl<T, IdxT, Op, NIDX>( \
+      updates, \
+      out, \
+      out_shape, \
+      out_strides, \
+      upd_size, \
+      upd_col_contiguous, \
+      idx_buffers, \
+      gid); \
+  \
+}
+
 template <typename T, typename IdxT, typename Op, int NIDX>
 METAL_FUNC void scatter_impl(

@@ -46,10 +98,14 @@ METAL_FUNC void scatter_impl(
     out_idx += idx_val * out_strides[ax];
   }

-  auto out_offset = elem_to_loc(
-      ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
+  if (upd_size > 1) {
+    auto out_offset = elem_to_loc(
+        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
+    out_idx += out_offset;
+  }

   auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
-  op.atomic_update(out, updates[upd_idx], out_idx + out_offset);
+  op.atomic_update(out, updates[upd_idx], out_idx);
 }

@@ -92,7 +148,9 @@ template <typename T, typename IdxT, typename Op, int NIDX> \
       gid); \
 }

-#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n)
+#define make_scatter(n) \
+  make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) \
+  make_scatter_1d_index(IDX_ARG_ ##n, IDX_ARR_ ##n)

 make_scatter(0)
 make_scatter(1)

@@ -129,8 +187,21 @@ template [[host_name("scatter" name "_" #nidx)]] \
     IDX_ARG(idx_t) \
     uint2 gid [[thread_position_in_grid]]);

+#define instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG) \
+template [[host_name("scatter_1d_index" name "_" #nidx)]] \
+[[kernel]] void scatter_1d_index<src_t, idx_t, op_t, nidx>( \
+    const device src_t *updates [[buffer(1)]], \
+    device mlx_atomic<src_t> *out [[buffer(2)]], \
+    const constant int* out_shape [[buffer(3)]], \
+    const constant size_t* out_strides [[buffer(4)]], \
+    const constant size_t& upd_size [[buffer(5)]], \
+    const constant bool& upd_col_contiguous [[buffer(6)]], \
+    IDX_ARG(idx_t) \
+    uint2 gid [[thread_position_in_grid]]);
+
 #define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \
-  instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)
+  instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) \
+  instantiate_scatter6(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx)

 // Special case NINDEX=0
 #define instantiate_scatter_nd0(name, type) \
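
As a reading aid, the row-contiguous branch of scatter_1d_index_impl with a
single index array and the plain-assignment (Scatter::None) op computes
roughly the following (a NumPy sketch; it ignores the atomic update and the
col-contiguous update branch):

    import numpy as np

    def scatter_1d_index_ref(out, idx, updates):
        # (y, x) plays the role of (gid.y, gid.x): y selects an index entry,
        # x walks the upd_size contiguous elements of that update row.
        upd_size = updates.size // idx.size          # elements per update row
        out2d = out.reshape(out.shape[0], -1)        # row-major destination view
        upd2d = updates.reshape(idx.size, upd_size)  # row-contiguous updates
        for y, i in enumerate(idx):
            i = i + out.shape[0] if i < 0 else i     # offset_neg_idx behavior
            out2d[i, :] = upd2d[y]                   # the gid.x loop, vectorized
        return out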


@@ -1858,6 +1858,14 @@ TEST_CASE("test scatter") {
   out = scatter(in, inds, updates, 0);
   CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());

+  // Array scatters with col contiguous updates
+  in = zeros({4, 4}, float32);
+  inds = array({0, 1, 2, 3});
+  updates = transpose(reshape(arange(16, float32), {4, 1, 4}));
+  out = scatter(in, inds, updates, 0);
+  CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))
+            .item<bool>());
+
   // Irregular strided index and reduce collision test
   in = zeros({10}, float32);
   inds = broadcast_to(array(3), {10});

@@ -1877,10 +1885,10 @@ TEST_CASE("test scatter") {
   // Irregularly strided updates test
   in = ones({3, 3});
-  updates = broadcast_to(array({0, 0, 0}), {1, 3, 3});
+  updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});
   inds = array({0});
   out = scatter(in, inds, updates, 0);
-  CHECK(array_equal(out, zeros({3, 3})).item<bool>());
+  CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());

   // Along different axis
   in = zeros({2, 3});
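
For readers checking the expected values in the new col-contiguous-updates
test, here is a NumPy analogue (a sketch; NumPy's fancy-index assignment
stands in for MLX's scatter along axis 0):

    import numpy as np

    dst = np.zeros((4, 4), dtype=np.float32)
    inds = np.array([0, 1, 2, 3])
    # Transposing makes the update buffer column (Fortran) contiguous,
    # which is the layout the upd_col_contiguous branch handles.
    updates = np.arange(16, dtype=np.float32).reshape(4, 1, 4).T
    dst[inds] = updates.reshape(4, 4)
    expected = np.arange(16, dtype=np.float32).reshape(4, 4).T
    assert np.array_equal(dst, expected)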