diff --git a/.circleci/config.yml b/.circleci/config.yml index 537f15969..5f26778c4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -94,8 +94,7 @@ jobs: command: | source env/bin/activate LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu - # TODO: Reenable when Circle CI can run gpu jobs - # DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu + LOW_MEMORY=1 DEVICE=gpu python3.9 -m xmlrunner discover -v python/tests -o test-results/gpu # TODO: Reenable when extension api becomes stable # - run: # name: Build example extension @@ -110,8 +109,9 @@ jobs: mkdir -p build && cd build && cmake .. && make -j - run: name: Run CPP tests - #command: METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests - command: DEVICE=cpu ./build/tests/tests + command: | + DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests + DEVICE=cpu ./build/tests/tests build_release: parameters: @@ -225,7 +225,9 @@ workflows: build_and_test: when: and: - - equal: [ main, << pipeline.git.branch >> ] + - matches: + pattern: "^(?!pull/)[-\\w]+$" + value: << pipeline.git.branch >> - not: << pipeline.parameters.nightly_build >> - not: << pipeline.parameters.weekly_build >> - not: << pipeline.parameters.test_release >> diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index e50441d48..1fb2bd46f 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -215,15 +215,6 @@ MTL::ComputeCommandEncoder* Device::get_command_encoder(int index) { return eit->second; } -MTL::ArgumentEncoder* Device::argument_encoder( - const std::vector& arg_descs) const { - // NB array here is already autoreleased but the returned argument - // encoder is owned by the caller and must be released/autoreleased - NS::Array* arg_desc_arr = NS::Array::array( - reinterpret_cast(arg_descs.data()), arg_descs.size()); - return device_->newArgumentEncoder(arg_desc_arr); -} - void Device::register_library( const std::string& lib_name, const std::string& lib_path) { diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index cf2256846..6908f8684 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -51,6 +51,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto compute_encoder = d.get_command_encoder(s.index); auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); size_t slice_size = 1; for (auto s : slice_sizes_) { @@ -63,91 +64,50 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto group_dims = get_block_dims(dim0, dim1, 1); MTL::Size grid_dims = MTL::Size(dim0, dim1, 1); - compute_encoder->setComputePipelineState(kernel); + // Collect all idx shapes and strides into one place + std::vector idx_shapes; + std::vector idx_strides; - // Make the argument buffer to store the indices for the - // `Indices` struct in kernels/indexing.metal - std::vector arg_descs(4); - arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[0]->setIndex(0); - arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[0]->setArrayLength(nidx); - - // Shapes - arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[1]->setIndex(nidx + 1); - - // Strides - arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[2]->setIndex(nidx + 2); - - // Indices ndim - arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); - arg_descs[3]->setIndex(nidx + 3); - - // Get the argument encoder - auto arg_enc = d.argument_encoder(arg_descs); - - // Allocate and fill buffers for shapes and strides - auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); - auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); for (int i = 0; i < nidx; ++i) { - std::copy( + idx_shapes.insert( + idx_shapes.end(), inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end(), - static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); - std::copy( + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end(), - static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + inputs[i + 1].strides().end()); } - // Allocate the argument buffer - auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); - - // Register data with the encoder - arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); - for (int i = 0; i < nidx; ++i) { - set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); - } - if (idx_ndim > 0) { - arg_enc->setBuffer( - static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); - compute_encoder->useResource( - static_cast(idx_shapes_buf.ptr()), - MTL::ResourceUsageRead); - arg_enc->setBuffer( - static_cast(idx_strides_buf.ptr()), 0, nidx + 2); - compute_encoder->useResource( - static_cast(idx_strides_buf.ptr()), - MTL::ResourceUsageRead); - } - *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; - // Set all the buffers set_array_buffer(compute_encoder, src, 0); - compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 1); - set_array_buffer(compute_encoder, out, 2); + set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); - compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); - compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 7); + // Set source info + compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2); + compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(&ndim, sizeof(size_t), 4); + compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5); + compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6); + // Set index info + // + // We don't need to check for empty idx_shapes because gather has a + // idx_ndim == 0 specialization + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 7); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 8); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 9); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid compute_encoder->dispatchThreads(grid_dims, group_dims); - - // Cleanup temporaries - arg_enc->release(); - d.get_command_buffer(s.index)->addCompletedHandler( - [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { - allocator::free(arg_buf); - allocator::free(idx_shapes_buf); - allocator::free(idx_strides_buf); - }); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -214,77 +174,33 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setComputePipelineState(kernel); - // Make the argument buffer to store the indices for the - // `Indices` struct in kernels/indexing.metal - std::vector arg_descs(4); - arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[0]->setIndex(0); - arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[0]->setArrayLength(nidx); - - // Shapes - arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[1]->setIndex(nidx + 1); - - // Strides - arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); - arg_descs[2]->setIndex(nidx + 2); - - // Indices ndim - arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); - arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); - arg_descs[3]->setIndex(nidx + 3); - - // Get the argument encoder - auto arg_enc = d.argument_encoder(arg_descs); - - // Allocate and fill buffers for shapes and strides + // Collect all idx shapes and strides into one place int idx_ndim = nidx ? inputs[1].ndim() : 0; - auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); - auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); + std::vector idx_shapes; + std::vector idx_strides; + for (int i = 0; i < nidx; ++i) { - std::copy( + idx_shapes.insert( + idx_shapes.end(), inputs[i + 1].shape().begin(), - inputs[i + 1].shape().end(), - static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); - std::copy( + inputs[i + 1].shape().end()); + + idx_strides.insert( + idx_strides.end(), inputs[i + 1].strides().begin(), - inputs[i + 1].strides().end(), - static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + inputs[i + 1].strides().end()); } - // Allocate the argument buffer - auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); + // Set all the buffers + set_array_buffer(compute_encoder, upd, 1); + set_array_buffer(compute_encoder, out, 2); - // Register data with the encoder - arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); - for (int i = 0; i < nidx; ++i) { - set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); - } - if (idx_ndim > 0) { - arg_enc->setBuffer( - static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); - compute_encoder->useResource( - static_cast(idx_shapes_buf.ptr()), - MTL::ResourceUsageRead); - arg_enc->setBuffer( - static_cast(idx_strides_buf.ptr()), 0, nidx + 2); - compute_encoder->useResource( - static_cast(idx_strides_buf.ptr()), - MTL::ResourceUsageRead); - } - *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; - - compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 0); + // Set update info size_t upd_ndim = upd.ndim(); size_t upd_size = 1; for (int i = idx_ndim; i < upd.ndim(); ++i) { upd_size *= upd.shape(i); } - set_array_buffer(compute_encoder, upd, 1); - set_array_buffer(compute_encoder, out, 2); if (upd_ndim == 0) { // Need placeholders so Metal doesn't compalain int shape_ = 0; @@ -299,6 +215,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + // Set output info size_t out_ndim = out.ndim(); if (out_ndim == 0) { // Need placeholders so Metal doesn't compalain @@ -314,18 +231,28 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + // Set index info + if (idx_ndim == 0) { + // Add a 0 in idx_shapes and strides to avoid the missing buffer binding + // error in the metal API. + idx_shapes.push_back(0); + idx_strides.push_back(0); + } + compute_encoder->setBytes( + idx_shapes.data(), idx_shapes.size() * sizeof(int), 11); + compute_encoder->setBytes( + idx_strides.data(), idx_strides.size() * sizeof(size_t), 12); + compute_encoder->setBytes(&idx_ndim, sizeof(int), 13); + + // Set index buffers + for (int i = 1; i < nidx + 1; ++i) { + set_array_buffer(compute_encoder, inputs[i], 20 + i); + } + + // Launch grid MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1); MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1); compute_encoder->dispatchThreads(grid_dims, group_dims); - - // Cleanup temporaries - arg_enc->release(); - d.get_command_buffer(s.index)->addCompletedHandler( - [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { - allocator::free(arg_buf); - allocator::free(idx_shapes_buf); - allocator::free(idx_strides_buf); - }); } } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 2d271abb4..12e09deaa 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -6,6 +6,7 @@ set( ${CMAKE_CURRENT_SOURCE_DIR}/complex.h ${CMAKE_CURRENT_SOURCE_DIR}/defines.h ${CMAKE_CURRENT_SOURCE_DIR}/erf.h + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.h ${CMAKE_CURRENT_SOURCE_DIR}/reduce.h ${CMAKE_CURRENT_SOURCE_DIR}/utils.h ) @@ -26,7 +27,8 @@ set( "softmax" "sort" "unary" - "indexing" + "gather" + "scatter" ) function(build_kernel_base TARGET SRCFILE DEPS) diff --git a/mlx/backend/metal/kernels/gather.metal b/mlx/backend/metal/kernels/gather.metal new file mode 100644 index 000000000..793b2af62 --- /dev/null +++ b/mlx/backend/metal/kernels/gather.metal @@ -0,0 +1,187 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/indexing.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Gather kernel +///////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void gather_impl( + const device T *src [[buffer(0)]], + device T *out [[buffer(1)]], + const constant int *src_shape [[buffer(2)]], + const constant size_t *src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int *slice_sizes [[buffer(5)]], + const constant int *axes [[buffer(6)]], + const thread Indices& indices, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + + auto ind_idx = index.x; + auto ind_offset = index.y; + + size_t src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + size_t idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = ind_idx * indices.strides[indices.ndim * i]; + } else { + idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += idx_val * src_strides[ax]; + } + + auto src_offset = elem_to_loc( + ind_offset, slice_sizes, src_strides, src_ndim); + + size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; + out[out_idx] = src[src_offset + src_idx]; + +} + +#define make_gather_impl(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void gather( \ + const device T *src [[buffer(0)]], \ + device T *out [[buffer(1)]], \ + const constant int *src_shape [[buffer(2)]], \ + const constant size_t *src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int *slice_sizes [[buffer(5)]], \ + const constant int *axes [[buffer(6)]], \ + const constant int *idx_shapes [[buffer(7)]], \ + const constant size_t *idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(IdxT) \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]) { \ + \ + Indices idxs{ \ + {{IDX_ARR()}}, \ + idx_shapes, \ + idx_strides, \ + idx_ndim}; \ + \ + return gather_impl( \ + src, \ + out, \ + src_shape, \ + src_strides, \ + src_ndim, \ + slice_sizes, \ + axes, \ + idxs, \ + index, \ + grid_dim); \ +} + +#define make_gather(n) make_gather_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) + +make_gather(0) +make_gather(1) +make_gather(2) +make_gather(3) +make_gather(4) +make_gather(5) +make_gather(6) +make_gather(7) +make_gather(8) +make_gather(9) +make_gather(10) + +///////////////////////////////////////////////////////////////////// +// Gather instantiations +///////////////////////////////////////////////////////////////////// + +#define instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG, nd, nd_name) \ +template [[host_name("gather" name "_" #nidx "" #nd_name)]] \ +[[kernel]] void gather( \ + const device src_t *src [[buffer(0)]], \ + device src_t *out [[buffer(1)]], \ + const constant int *src_shape [[buffer(2)]], \ + const constant size_t *src_strides [[buffer(3)]], \ + const constant size_t& src_ndim [[buffer(4)]], \ + const constant int *slice_sizes [[buffer(5)]], \ + const constant int *axes [[buffer(6)]], \ + const constant int *idx_shapes [[buffer(7)]], \ + const constant size_t *idx_strides [[buffer(8)]], \ + const constant int& idx_ndim [[buffer(9)]], \ + IDX_ARG(idx_t) \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); + +#define instantiate_gather5(name, src_t, idx_t, nidx, nd, nd_name) \ + instantiate_gather6(name, src_t, idx_t, nidx, IDX_ARG_ ##nidx, nd, nd_name) + +#define instantiate_gather4(name, src_t, idx_t, nidx) \ + instantiate_gather5(name, src_t, idx_t, nidx, 0, _0) \ + instantiate_gather5(name, src_t, idx_t, nidx, 1, _1) \ + instantiate_gather5(name, src_t, idx_t, nidx, 2, ) + + +// Special for case NIDX=0 +instantiate_gather4("bool_", bool, bool, 0) +instantiate_gather4("uint8", uint8_t, bool, 0) +instantiate_gather4("uint16", uint16_t, bool, 0) +instantiate_gather4("uint32", uint32_t, bool, 0) +instantiate_gather4("uint64", uint64_t, bool, 0) +instantiate_gather4("int8", int8_t, bool, 0) +instantiate_gather4("int16", int16_t, bool, 0) +instantiate_gather4("int32", int32_t, bool, 0) +instantiate_gather4("int64", int64_t, bool, 0) +instantiate_gather4("float16", half, bool, 0) +instantiate_gather4("float32", float, bool, 0) +instantiate_gather4("bfloat16", bfloat16_t, bool, 0) + +#define instantiate_gather3(name, src_type, ind_type) \ + instantiate_gather4(name, src_type, ind_type, 1) \ + instantiate_gather4(name, src_type, ind_type, 2) \ + instantiate_gather4(name, src_type, ind_type, 3) \ + instantiate_gather4(name, src_type, ind_type, 4) \ + instantiate_gather4(name, src_type, ind_type, 5) \ + instantiate_gather4(name, src_type, ind_type, 6) \ + instantiate_gather4(name, src_type, ind_type, 7) \ + instantiate_gather4(name, src_type, ind_type, 8) \ + instantiate_gather4(name, src_type, ind_type, 9) \ + instantiate_gather4(name, src_type, ind_type, 10) + +#define instantiate_gather(name, src_type) \ + instantiate_gather3(#name "bool_", src_type, bool) \ + instantiate_gather3(#name "uint8", src_type, uint8_t) \ + instantiate_gather3(#name "uint16", src_type, uint16_t) \ + instantiate_gather3(#name "uint32", src_type, uint32_t) \ + instantiate_gather3(#name "uint64", src_type, uint64_t) \ + instantiate_gather3(#name "int8", src_type, int8_t) \ + instantiate_gather3(#name "int16", src_type, int16_t) \ + instantiate_gather3(#name "int32", src_type, int32_t) \ + instantiate_gather3(#name "int64", src_type, int64_t) + +instantiate_gather(bool_, bool) +instantiate_gather(uint8, uint8_t) +instantiate_gather(uint16, uint16_t) +instantiate_gather(uint32, uint32_t) +instantiate_gather(uint64, uint64_t) +instantiate_gather(int8, int8_t) +instantiate_gather(int16, int16_t) +instantiate_gather(int32, int32_t) +instantiate_gather(int64, int64_t) +instantiate_gather(float16, half) +instantiate_gather(float32, float) +instantiate_gather(bfloat16, bfloat16_t) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.h b/mlx/backend/metal/kernels/indexing.h new file mode 100644 index 000000000..c2b37f3ff --- /dev/null +++ b/mlx/backend/metal/kernels/indexing.h @@ -0,0 +1,54 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Indexing utils +///////////////////////////////////////////////////////////////////// + +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant size_t* strides; + const int ndim; +}; + +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? idx + size : idx; + } +} + +#define IDX_ARG_N(idx_t, n) const device idx_t *idx##n [[buffer(n)]], + +#define IDX_ARG_0(idx_t) +#define IDX_ARG_1(idx_t) IDX_ARG_0(idx_t) IDX_ARG_N(idx_t, 21) +#define IDX_ARG_2(idx_t) IDX_ARG_1(idx_t) IDX_ARG_N(idx_t, 22) +#define IDX_ARG_3(idx_t) IDX_ARG_2(idx_t) IDX_ARG_N(idx_t, 23) +#define IDX_ARG_4(idx_t) IDX_ARG_3(idx_t) IDX_ARG_N(idx_t, 24) +#define IDX_ARG_5(idx_t) IDX_ARG_4(idx_t) IDX_ARG_N(idx_t, 25) +#define IDX_ARG_6(idx_t) IDX_ARG_5(idx_t) IDX_ARG_N(idx_t, 26) +#define IDX_ARG_7(idx_t) IDX_ARG_6(idx_t) IDX_ARG_N(idx_t, 27) +#define IDX_ARG_8(idx_t) IDX_ARG_7(idx_t) IDX_ARG_N(idx_t, 28) +#define IDX_ARG_9(idx_t) IDX_ARG_8(idx_t) IDX_ARG_N(idx_t, 29) +#define IDX_ARG_10(idx_t) IDX_ARG_9(idx_t) IDX_ARG_N(idx_t, 30) + +#define IDX_ARR_N(n) idx##n, + +#define IDX_ARR_0() +#define IDX_ARR_1() IDX_ARR_0() IDX_ARR_N(21) +#define IDX_ARR_2() IDX_ARR_1() IDX_ARR_N(22) +#define IDX_ARR_3() IDX_ARR_2() IDX_ARR_N(23) +#define IDX_ARR_4() IDX_ARR_3() IDX_ARR_N(24) +#define IDX_ARR_5() IDX_ARR_4() IDX_ARR_N(25) +#define IDX_ARR_6() IDX_ARR_5() IDX_ARR_N(26) +#define IDX_ARR_7() IDX_ARR_6() IDX_ARR_N(27) +#define IDX_ARR_8() IDX_ARR_7() IDX_ARR_N(28) +#define IDX_ARR_9() IDX_ARR_8() IDX_ARR_N(29) +#define IDX_ARR_10() IDX_ARR_9() IDX_ARR_N(30) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/indexing.metal b/mlx/backend/metal/kernels/indexing.metal deleted file mode 100644 index 7b6e2399a..000000000 --- a/mlx/backend/metal/kernels/indexing.metal +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/reduce.h" -#include "mlx/backend/metal/kernels/utils.h" - -using namespace metal; - -///////////////////////////////////////////////////////////////////// -// Gather kernel -///////////////////////////////////////////////////////////////////// - -template -struct Indices { - const array buffers [[id(0)]]; - device int* shapes [[id(NIDX + 1)]]; - device size_t* strides [[id(NIDX + 2)]]; - const int ndim [[id(NIDX + 3)]]; -}; - -template -inline size_t offset_neg_idx(IdxT idx, size_t size) { - return (idx < 0) ? idx + size : idx; -} - -template <> -inline size_t offset_neg_idx(bool idx, size_t) { - return idx; -} - -template <> -inline size_t offset_neg_idx(uint32_t idx, size_t) { - return idx; -} - -// IDX_NDIM is the number of dimensions of the indices arrays. Compile-time -// special case for 0 and 1. Anything >= 2 uses the general case -template -[[kernel]] void gather( - const device T *src [[buffer(0)]], - const constant Indices& indices [[buffer(1)]], - device T *out [[buffer(2)]], - const constant int *src_shape [[buffer(3)]], - const constant size_t *src_strides [[buffer(4)]], - const constant size_t& src_ndim [[buffer(5)]], - const constant int *slice_sizes [[buffer(6)]], - const constant int *axes [[buffer(7)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - - auto ind_idx = index.x; - auto ind_offset = index.y; - - size_t src_idx = 0; - for (int i = 0; i < NIDX; ++i) { - size_t idx_loc; - if (IDX_NDIM == 0) { - idx_loc = 0; - } else if (IDX_NDIM == 1) { - idx_loc = ind_idx * indices.strides[indices.ndim * i]; - } else { - idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - } - auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += idx_val * src_strides[ax]; - } - - auto src_offset = elem_to_loc( - ind_offset, slice_sizes, src_strides, src_ndim); - - size_t out_idx = index.y + static_cast(grid_dim.y) * index.x; - out[out_idx] = src[src_offset + src_idx]; -} - -#define instantiate_gather4(name, src_type, ind_type, nindex) \ -template [[host_name("gather" name "_" #nindex "_0")]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ -template [[host_name("gather" name "_" #nindex "_1")]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); \ -template [[host_name("gather" name "_" #nindex)]] \ -[[kernel]] void gather( \ - const device src_type *src [[buffer(0)]], \ - const constant Indices& indices [[buffer(1)]], \ - device src_type *out [[buffer(2)]], \ - const constant int *src_shape [[buffer(3)]], \ - const constant size_t *src_strides [[buffer(4)]], \ - const constant size_t& src_ndim [[buffer(5)]], \ - const constant int *slice_sizes [[buffer(6)]], \ - const constant int* axes [[buffer(7)]], \ - uint2 index [[thread_position_in_grid]], \ - uint2 grid_dim [[threads_per_grid]]); - - -// Special for case NIDX=0 -instantiate_gather4("bool_", bool, bool, 0) -instantiate_gather4("uint8", uint8_t, bool, 0) -instantiate_gather4("uint16", uint16_t, bool, 0) -instantiate_gather4("uint32", uint32_t, bool, 0) -instantiate_gather4("uint64", uint64_t, bool, 0) -instantiate_gather4("int8", int8_t, bool, 0) -instantiate_gather4("int16", int16_t, bool, 0) -instantiate_gather4("int32", int32_t, bool, 0) -instantiate_gather4("int64", int64_t, bool, 0) -instantiate_gather4("float16", half, bool, 0) -instantiate_gather4("float32", float, bool, 0) -instantiate_gather4("bfloat16", bfloat16_t, bool, 0) - -#define instantiate_gather3(name, src_type, ind_type) \ - instantiate_gather4(name, src_type, ind_type, 1) \ - instantiate_gather4(name, src_type, ind_type, 2) \ - instantiate_gather4(name, src_type, ind_type, 3) \ - instantiate_gather4(name, src_type, ind_type, 4) \ - instantiate_gather4(name, src_type, ind_type, 5) \ - instantiate_gather4(name, src_type, ind_type, 6) \ - instantiate_gather4(name, src_type, ind_type, 7) \ - instantiate_gather4(name, src_type, ind_type, 8) \ - instantiate_gather4(name, src_type, ind_type, 9) \ - instantiate_gather4(name, src_type, ind_type, 10) - -#define instantiate_gather(name, src_type) \ - instantiate_gather3(#name "bool_", src_type, bool) \ - instantiate_gather3(#name "uint8", src_type, uint8_t) \ - instantiate_gather3(#name "uint16", src_type, uint16_t) \ - instantiate_gather3(#name "uint32", src_type, uint32_t) \ - instantiate_gather3(#name "uint64", src_type, uint64_t) \ - instantiate_gather3(#name "int8", src_type, int8_t) \ - instantiate_gather3(#name "int16", src_type, int16_t) \ - instantiate_gather3(#name "int32", src_type, int32_t) \ - instantiate_gather3(#name "int64", src_type, int64_t) - -instantiate_gather(bool_, bool) -instantiate_gather(uint8, uint8_t) -instantiate_gather(uint16, uint16_t) -instantiate_gather(uint32, uint32_t) -instantiate_gather(uint64, uint64_t) -instantiate_gather(int8, int8_t) -instantiate_gather(int16, int16_t) -instantiate_gather(int32, int32_t) -instantiate_gather(int64, int64_t) -instantiate_gather(float16, half) -instantiate_gather(float32, float) -instantiate_gather(bfloat16, bfloat16_t) - -///////////////////////////////////////////////////////////////////// -// Scatter kernel -///////////////////////////////////////////////////////////////////// - -template -[[kernel]] void scatter( - const device Indices& indices [[buffer(0)]], - const device T *updates [[buffer(1)]], - device mlx_atomic *out [[buffer(2)]], - const device int *upd_shape [[buffer(3)]], - const device size_t *upd_strides [[buffer(4)]], - const device size_t& upd_ndim [[buffer(5)]], - const device size_t& upd_size [[buffer(6)]], - const device int *out_shape [[buffer(7)]], - const device size_t *out_strides [[buffer(8)]], - const device size_t& out_ndim [[buffer(9)]], - const device int* axes [[buffer(10)]], - uint2 gid [[thread_position_in_grid]]) { - - Op op; - auto ind_idx = gid.y; - auto ind_offset = gid.x; - - size_t out_idx = 0; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx( - indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += idx_val * out_strides[ax]; - } - - auto out_offset = elem_to_loc( - ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); - auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); - op.atomic_update(out, updates[upd_idx], out_idx + out_offset); -} - -#define instantiate_scatter4(name, type, ind_type, op_type, nindex) \ -template [[host_name("scatter" name "_" #nindex)]] \ -[[kernel]] void scatter( \ - const device Indices& indices [[buffer(0)]], \ - const device type *updates [[buffer(1)]], \ - device mlx_atomic *out [[buffer(2)]], \ - const device int *upd_shape [[buffer(3)]], \ - const device size_t *upd_strides [[buffer(4)]], \ - const device size_t& upd_ndim [[buffer(5)]], \ - const device size_t& upd_size [[buffer(6)]], \ - const device int *out_shape [[buffer(7)]], \ - const device size_t *out_strides [[buffer(8)]], \ - const device size_t& out_ndim [[buffer(9)]], \ - const device int* axes [[buffer(10)]], \ - uint2 gid [[thread_position_in_grid]]); - -// Special case NINDEX=0 -#define instantiate_scatter_nd0(name, type) \ - instantiate_scatter4(#name "none", type, bool, None, 0) \ - instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ - instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ - instantiate_scatter4(#name "_max", type, bool, Max, 0) \ - instantiate_scatter4(#name "_min", type, bool, Min, 0) - -#define instantiate_scatter3(name, type, ind_type, op_type) \ - instantiate_scatter4(name, type, ind_type, op_type, 1) \ - instantiate_scatter4(name, type, ind_type, op_type, 2) \ - instantiate_scatter4(name, type, ind_type, op_type, 3) \ - instantiate_scatter4(name, type, ind_type, op_type, 4) \ - instantiate_scatter4(name, type, ind_type, op_type, 5) \ - instantiate_scatter4(name, type, ind_type, op_type, 6) \ - instantiate_scatter4(name, type, ind_type, op_type, 7) \ - instantiate_scatter4(name, type, ind_type, op_type, 8) \ - instantiate_scatter4(name, type, ind_type, op_type, 9) \ - instantiate_scatter4(name, type, ind_type, op_type, 10) - -#define instantiate_scatter2(name, type, ind_type) \ - instantiate_scatter3(name "_none", type, ind_type, None) \ - instantiate_scatter3(name "_sum", type, ind_type, Sum) \ - instantiate_scatter3(name "_prod", type, ind_type, Prod) \ - instantiate_scatter3(name "_max", type, ind_type, Max) \ - instantiate_scatter3(name "_min", type, ind_type, Min) - -#define instantiate_scatter(name, type) \ - instantiate_scatter2(#name "bool_", type, bool) \ - instantiate_scatter2(#name "uint8", type, uint8_t) \ - instantiate_scatter2(#name "uint16", type, uint16_t) \ - instantiate_scatter2(#name "uint32", type, uint32_t) \ - instantiate_scatter2(#name "uint64", type, uint64_t) \ - instantiate_scatter2(#name "int8", type, int8_t) \ - instantiate_scatter2(#name "int16", type, int16_t) \ - instantiate_scatter2(#name "int32", type, int32_t) \ - instantiate_scatter2(#name "int64", type, int64_t) - -// TODO uint64 and int64 unsupported -instantiate_scatter_nd0(bool_, bool) -instantiate_scatter_nd0(uint8, uint8_t) -instantiate_scatter_nd0(uint16, uint16_t) -instantiate_scatter_nd0(uint32, uint32_t) -instantiate_scatter_nd0(int8, int8_t) -instantiate_scatter_nd0(int16, int16_t) -instantiate_scatter_nd0(int32, int32_t) -instantiate_scatter_nd0(float16, half) -instantiate_scatter_nd0(float32, float) -instantiate_scatter_nd0(bfloat16, bfloat16_t) - -instantiate_scatter(bool_, bool) -instantiate_scatter(uint8, uint8_t) -instantiate_scatter(uint16, uint16_t) -instantiate_scatter(uint32, uint32_t) -instantiate_scatter(int8, int8_t) -instantiate_scatter(int16, int16_t) -instantiate_scatter(int32, int32_t) -instantiate_scatter(float16, half) -instantiate_scatter(float32, float) -instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/scatter.metal b/mlx/backend/metal/kernels/scatter.metal new file mode 100644 index 000000000..7a94be7da --- /dev/null +++ b/mlx/backend/metal/kernels/scatter.metal @@ -0,0 +1,194 @@ +// Copyright © 2023-2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/indexing.h" +#include "mlx/backend/metal/kernels/reduce.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +///////////////////////////////////////////////////////////////////// +// Scatter kernel +///////////////////////////////////////////////////////////////////// + + +template +METAL_FUNC void scatter_impl( + const device T *updates [[buffer(1)]], + device mlx_atomic *out [[buffer(2)]], + const constant int *upd_shape [[buffer(3)]], + const constant size_t *upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int *out_shape [[buffer(7)]], + const constant size_t *out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + + Op op; + auto ind_idx = gid.y; + auto ind_offset = gid.x; + + size_t out_idx = 0; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx( + indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += idx_val * out_strides[ax]; + } + + auto out_offset = elem_to_loc( + ind_offset, upd_shape + indices.ndim, out_strides, out_ndim); + auto upd_idx = elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim); + op.atomic_update(out, updates[upd_idx], out_idx + out_offset); +} + +#define make_scatter_impl(IDX_ARG, IDX_ARR) \ +template \ +[[kernel]] void scatter( \ + const device T *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int *upd_shape [[buffer(3)]], \ + const constant size_t *upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int *out_shape [[buffer(7)]], \ + const constant size_t *out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int *idx_shapes [[buffer(11)]], \ + const constant size_t *idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(IdxT) \ + uint2 gid [[thread_position_in_grid]]) { \ + \ + Indices idxs{ \ + {{IDX_ARR()}}, \ + idx_shapes, \ + idx_strides, \ + idx_ndim}; \ + \ + return scatter_impl( \ + updates, \ + out, \ + upd_shape, \ + upd_strides, \ + upd_ndim, \ + upd_size, \ + out_shape, \ + out_strides, \ + out_ndim, \ + axes, \ + idxs, \ + gid); \ +} + +#define make_scatter(n) make_scatter_impl(IDX_ARG_ ##n, IDX_ARR_ ##n) + +make_scatter(0) +make_scatter(1) +make_scatter(2) +make_scatter(3) +make_scatter(4) +make_scatter(5) +make_scatter(6) +make_scatter(7) +make_scatter(8) +make_scatter(9) +make_scatter(10) + +///////////////////////////////////////////////////////////////////// +// Scatter instantiations +///////////////////////////////////////////////////////////////////// + +#define instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG) \ +template [[host_name("scatter" name "_" #nidx)]] \ +[[kernel]] void scatter( \ + const device src_t *updates [[buffer(1)]], \ + device mlx_atomic *out [[buffer(2)]], \ + const constant int *upd_shape [[buffer(3)]], \ + const constant size_t *upd_strides [[buffer(4)]], \ + const constant size_t& upd_ndim [[buffer(5)]], \ + const constant size_t& upd_size [[buffer(6)]], \ + const constant int *out_shape [[buffer(7)]], \ + const constant size_t *out_strides [[buffer(8)]], \ + const constant size_t& out_ndim [[buffer(9)]], \ + const constant int* axes [[buffer(10)]], \ + const constant int *idx_shapes [[buffer(11)]], \ + const constant size_t *idx_strides [[buffer(12)]], \ + const constant int& idx_ndim [[buffer(13)]], \ + IDX_ARG(idx_t) \ + uint2 gid [[thread_position_in_grid]]); + +#define instantiate_scatter4(name, src_t, idx_t, op_t, nidx) \ + instantiate_scatter5(name, src_t, idx_t, op_t, nidx, IDX_ARG_ ##nidx) + +// Special case NINDEX=0 +#define instantiate_scatter_nd0(name, type) \ + instantiate_scatter4(#name "none", type, bool, None, 0) \ + instantiate_scatter4(#name "_sum", type, bool, Sum, 0) \ + instantiate_scatter4(#name "_prod", type, bool, Prod, 0) \ + instantiate_scatter4(#name "_max", type, bool, Max, 0) \ + instantiate_scatter4(#name "_min", type, bool, Min, 0) + +#define instantiate_scatter3(name, type, ind_type, op_type) \ + instantiate_scatter4(name, type, ind_type, op_type, 1) \ + instantiate_scatter4(name, type, ind_type, op_type, 2) \ + instantiate_scatter4(name, type, ind_type, op_type, 3) \ + instantiate_scatter4(name, type, ind_type, op_type, 4) \ + instantiate_scatter4(name, type, ind_type, op_type, 5) \ + instantiate_scatter4(name, type, ind_type, op_type, 6) \ + instantiate_scatter4(name, type, ind_type, op_type, 7) \ + instantiate_scatter4(name, type, ind_type, op_type, 8) \ + instantiate_scatter4(name, type, ind_type, op_type, 9) \ + instantiate_scatter4(name, type, ind_type, op_type, 10) + +#define instantiate_scatter2(name, type, ind_type) \ + instantiate_scatter3(name "_none", type, ind_type, None) \ + instantiate_scatter3(name "_sum", type, ind_type, Sum) \ + instantiate_scatter3(name "_prod", type, ind_type, Prod) \ + instantiate_scatter3(name "_max", type, ind_type, Max) \ + instantiate_scatter3(name "_min", type, ind_type, Min) + +#define instantiate_scatter(name, type) \ + instantiate_scatter2(#name "bool_", type, bool) \ + instantiate_scatter2(#name "uint8", type, uint8_t) \ + instantiate_scatter2(#name "uint16", type, uint16_t) \ + instantiate_scatter2(#name "uint32", type, uint32_t) \ + instantiate_scatter2(#name "uint64", type, uint64_t) \ + instantiate_scatter2(#name "int8", type, int8_t) \ + instantiate_scatter2(#name "int16", type, int16_t) \ + instantiate_scatter2(#name "int32", type, int32_t) \ + instantiate_scatter2(#name "int64", type, int64_t) + +// TODO uint64 and int64 unsupported +instantiate_scatter_nd0(bool_, bool) +instantiate_scatter_nd0(uint8, uint8_t) +instantiate_scatter_nd0(uint16, uint16_t) +instantiate_scatter_nd0(uint32, uint32_t) +instantiate_scatter_nd0(int8, int8_t) +instantiate_scatter_nd0(int16, int16_t) +instantiate_scatter_nd0(int32, int32_t) +instantiate_scatter_nd0(float16, half) +instantiate_scatter_nd0(float32, float) +instantiate_scatter_nd0(bfloat16, bfloat16_t) + +instantiate_scatter(bool_, bool) +instantiate_scatter(uint8, uint8_t) +instantiate_scatter(uint16, uint16_t) +instantiate_scatter(uint32, uint32_t) +instantiate_scatter(int8, int8_t) +instantiate_scatter(int16, int16_t) +instantiate_scatter(int32, int32_t) +instantiate_scatter(float16, half) +instantiate_scatter(float32, float) +instantiate_scatter(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index f7c672c9f..363632a30 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -9,20 +9,6 @@ namespace mlx::core { namespace { -void set_array_buffer( - MTL::ComputeCommandEncoder* compute_encoder, - MTL::ArgumentEncoder* enc, - const array& a, - int idx) { - auto a_buf = static_cast(a.buffer().ptr()); - auto offset = a.data() - - static_cast(const_cast(a_buf)->contents()); - enc->setBuffer(a_buf, offset, idx); - // MTL::Resource usage through argument buffer needs to be explicitly - // flagged to enable hazard tracking - compute_encoder->useResource(a_buf, MTL::ResourceUsageRead); -} - void set_array_buffer( MTL::ComputeCommandEncoder* enc, const array& a,