From cec8661113dfd8da8ffcd37419bdb504a02ebff2 Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 20 Mar 2024 10:39:25 -0700 Subject: [PATCH] Add a SliceUpdate op and primitive (#850) * Enable copy to work with int64 strides * Fix uniform buffer indices or copy kernel arguments * Update utils.h * Remove manual unrolling of elem to loc loop * GPU copy updated to handle negative strides * Add slice update primitive --- mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/copy.cpp | 413 +++++++++++++++------- mlx/backend/common/copy.h | 13 +- mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/primitives.cpp | 129 +++++-- mlx/backend/common/utils.h | 43 ++- mlx/backend/metal/copy.cpp | 75 ++-- mlx/backend/metal/copy.h | 24 +- mlx/backend/metal/kernels/copy.metal | 186 +++++----- mlx/backend/metal/kernels/utils.h | 246 +++++++------ mlx/backend/metal/matmul.cpp | 6 +- mlx/backend/metal/primitives.cpp | 68 +++- mlx/backend/metal/utils.h | 35 +- mlx/backend/no_metal/primitives.cpp | 1 + mlx/ops.cpp | 212 +++++------ mlx/ops.h | 17 + mlx/primitives.cpp | 108 ++++++ mlx/primitives.h | 38 ++ mlx/utils.cpp | 9 + mlx/utils.h | 3 +- tests/ops_tests.cpp | 25 ++ 21 files changed, 1147 insertions(+), 506 deletions(-) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 4ee0c0e89..5567c0785 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -69,6 +69,7 @@ DEFAULT(Select) DEFAULT(Sigmoid) DEFAULT(Sign) DEFAULT(Slice) +DEFAULT(SliceUpdate) DEFAULT_MULTI(Split) DEFAULT(Sort) DEFAULT(StopGradient) diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp index cc37e767a..53956041a 100644 --- a/mlx/backend/common/copy.cpp +++ b/mlx/backend/common/copy.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
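// ---------------------------------------------------------------------------
// Illustrative usage sketch for the slice_update op added by this patch (a
// hypothetical example, not a hunk of the diff). It assumes the
// std::vector<int> index types declared for slice_update in mlx/ops.h below
// and the public "mlx/mlx.h" umbrella header.
//
//   #include "mlx/mlx.h"
//   using namespace mlx::core;
//
//   array src = zeros({4, 4});
//   array upd = ones({2, 2});
//   // Returns a copy of src with the region src[1:3, 1:3] replaced by upd;
//   // for this overload the slice stride defaults to 1 in every dimension.
//   array out = slice_update(src, upd, /* start */ {1, 1}, /* stop */ {3, 3});
// ---------------------------------------------------------------------------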
#include @@ -25,121 +25,196 @@ void copy_vector(const array& src, array& dst) { std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); } -template -void copy_general_dim1(const array& src, array& dst) { +template +void copy_general_dim1( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); - size_t src_idx = 0; - size_t dst_idx = 0; - for (size_t i = 0; i < src.shape()[0]; ++i) { + stride_t src_idx = i_offset; + stride_t dst_idx = 0; + for (int i = 0; i < data_shape[0]; ++i) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += src.strides()[0]; + src_idx += i_strides[0]; } } template -void copy_general_dim2(const array& src, array& dst) { +inline void copy_general_dim1(const array& src, array& dst) { + return copy_general_dim1( + src, dst, src.shape(), src.strides(), 0); +} + +template +void copy_general_dim2( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); - size_t src_idx = 0; - size_t dst_idx = 0; - for (size_t i = 0; i < src.shape()[0]; ++i) { - for (size_t j = 0; j < src.shape()[1]; ++j) { + stride_t src_idx = i_offset; + stride_t dst_idx = 0; + for (int i = 0; i < data_shape[0]; ++i) { + for (int j = 0; j < data_shape[1]; ++j) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += src.strides()[1]; + src_idx += i_strides[1]; } - src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; + src_idx += i_strides[0] - i_strides[1] * data_shape[1]; } } template -void copy_general_dim3(const array& src, array& dst) { +inline void copy_general_dim2(const array& src, array& dst) { + return copy_general_dim2( + src, dst, src.shape(), src.strides(), 0); +} + +template +void copy_general_dim3( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); - size_t src_idx = 0; - size_t dst_idx = 0; - for (size_t i = 0; i < src.shape()[0]; ++i) { - for (size_t j = 0; j < src.shape()[1]; ++j) { - for (size_t k = 0; k < src.shape()[2]; ++k) { + stride_t src_idx = i_offset; + stride_t dst_idx = 0; + for (int i = 0; i < data_shape[0]; ++i) { + for (int j = 0; j < data_shape[1]; ++j) { + for (int k = 0; k < data_shape[2]; ++k) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += src.strides()[2]; + src_idx += i_strides[2]; } - src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; + src_idx += i_strides[1] - i_strides[2] * data_shape[2]; } - src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; + src_idx += i_strides[0] - i_strides[1] * data_shape[1]; } } template -void copy_general_dim4(const array& src, array& dst) { +inline void copy_general_dim3(const array& src, array& dst) { + return copy_general_dim3( + src, dst, src.shape(), src.strides(), 0); +} + +template +void copy_general_dim4( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { const SrcT* src_ptr = src.data(); DstT* dst_ptr = dst.data(); - size_t src_idx = 0; - size_t dst_idx = 0; - for (size_t i = 0; i < src.shape()[0]; ++i) { - for (size_t j = 0; j < src.shape()[1]; ++j) { - for (size_t k = 0; k < src.shape()[2]; ++k) { - for (size_t ii = 0; ii < src.shape()[3]; ++ii) { + stride_t src_idx = i_offset; + stride_t dst_idx = 
0; + for (int i = 0; i < data_shape[0]; ++i) { + for (int j = 0; j < data_shape[1]; ++j) { + for (int k = 0; k < data_shape[2]; ++k) { + for (int ii = 0; ii < data_shape[3]; ++ii) { dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); - src_idx += src.strides()[3]; + src_idx += i_strides[3]; } - src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3]; + src_idx += i_strides[2] - i_strides[3] * data_shape[3]; } - src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; + src_idx += i_strides[1] - i_strides[2] * data_shape[2]; } - src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; + src_idx += i_strides[0] - i_strides[1] * data_shape[1]; } } template -void copy_general(const array& src, array& dst) { +inline void copy_general_dim4(const array& src, array& dst) { + return copy_general_dim4( + src, dst, src.shape(), src.strides(), 0); +} + +template +void copy_general( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + int64_t i_offset) { switch (src.ndim()) { case 1: - copy_general_dim1(src, dst); + copy_general_dim1( + src, dst, data_shape, i_strides, i_offset); return; case 2: - copy_general_dim2(src, dst); + copy_general_dim2( + src, dst, data_shape, i_strides, i_offset); return; case 3: - copy_general_dim3(src, dst); + copy_general_dim3( + src, dst, data_shape, i_strides, i_offset); return; case 4: - copy_general_dim4(src, dst); + copy_general_dim4( + src, dst, data_shape, i_strides, i_offset); return; } - auto src_ptr = src.data(); + auto src_ptr = src.data() + i_offset; auto dst_ptr = dst.data(); for (size_t i = 0; i < dst.size(); ++i) { - size_t src_elem = elem_to_loc(i, src.shape(), src.strides()); + stride_t src_elem = elem_to_loc(i, data_shape, i_strides); dst_ptr[i] = static_cast(src_ptr[src_elem]); } } -template +template +inline void copy_general(const array& src, array& dst) { + return copy_general( + src, dst, src.shape(), src.strides(), 0); +} + +template +inline void copy_general( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + int64_t i_offset, + int64_t o_offset) { + return copy_general( + src, dst, data_shape, i_strides, i_offset); +} + +template inline void copy_general_general_dims( const array& src, array& dst, - size_t offset_src, - size_t offset_dst) { + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + stride_t i_offset, + stride_t o_offset) { if constexpr (D > 1) { int axis = src.ndim() - D; - auto stride_src = src.strides()[axis]; - auto stride_dst = dst.strides()[axis]; - auto N = src.shape(axis); + auto stride_src = i_strides[axis]; + auto stride_dst = o_strides[axis]; + auto N = data_shape[axis]; for (int i = 0; i < N; i++) { - copy_general_general_dims( - src, dst, offset_src, offset_dst); - offset_src += stride_src; - offset_dst += stride_dst; + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); + i_offset += stride_src; + o_offset += stride_dst; } } else { int axis = src.ndim() - 1; - auto stride_src = src.strides()[axis]; - auto stride_dst = dst.strides()[axis]; - auto N = src.shape(axis); - const SrcT* src_ptr = src.data() + offset_src; - DstT* dst_ptr = dst.data() + offset_dst; + auto stride_src = i_strides[axis]; + auto stride_dst = o_strides[axis]; + auto N = data_shape[axis]; + const SrcT* src_ptr = src.data() + i_offset; + DstT* dst_ptr = dst.data() + o_offset; for (int i = 0; i < N; i++) { *dst_ptr = 
static_cast(*src_ptr); src_ptr += stride_src; @@ -148,37 +223,56 @@ inline void copy_general_general_dims( } } -template -void copy_general_general(const array& src, array& dst) { +template +void copy_general_general( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + stride_t i_offset, + stride_t o_offset) { switch (src.ndim()) { case 1: - copy_general_general_dims(src, dst, 0, 0); + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 2: - copy_general_general_dims(src, dst, 0, 0); + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 3: - copy_general_general_dims(src, dst, 0, 0); + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 4: - copy_general_general_dims(src, dst, 0, 0); + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; case 5: - copy_general_general_dims(src, dst, 0, 0); + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, i_offset, o_offset); return; } int size = std::accumulate( - src.shape().begin() - 5, src.shape().end(), 1, std::multiplies()); + data_shape.begin() - 5, data_shape.end(), 1, std::multiplies()); for (int i = 0; i < src.size(); i += size) { - size_t offset_src = elem_to_loc(i, src.shape(), src.strides()); - size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides()); - copy_general_general_dims(src, dst, offset_src, offset_dst); + stride_t src_offset = i_offset + elem_to_loc(i, data_shape, i_strides); + stride_t dst_offset = o_offset + elem_to_loc(i, dst.shape(), o_strides); + copy_general_general_dims( + src, dst, data_shape, i_strides, o_strides, src_offset, dst_offset); } } template -void copy(const array& src, array& dst, CopyType ctype) { +inline void copy_general_general(const array& src, array& dst) { + return copy_general_general( + src, dst, src.shape(), src.strides(), dst.strides(), 0, 0); +} + +template +void copy(const array& src, array& dst, CopyType ctype, Args... args) { switch (ctype) { case CopyType::Scalar: copy_single(src, dst); @@ -187,54 +281,103 @@ void copy(const array& src, array& dst, CopyType ctype) { copy_vector(src, dst); return; case CopyType::General: - copy_general(src, dst); + copy_general(src, dst, args...); return; case CopyType::GeneralGeneral: - copy_general_general(src, dst); + copy_general_general(src, dst, args...); } } -template -void copy(const array& src, array& dst, CopyType ctype) { +template +void copy(const array& src, array& dst, CopyType ctype, Args... 
args) { switch (dst.dtype()) { case bool_: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case uint8: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case uint16: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case uint32: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case uint64: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case int8: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case int16: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case int32: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case int64: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case float16: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case float32: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case bfloat16: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); break; case complex64: - copy(src, dst, ctype); + copy(src, dst, ctype, args...); + break; + } +} + +template +inline void copy_inplace_dispatch( + const array& src, + array& dst, + CopyType ctype, + Args... args) { + switch (src.dtype()) { + case bool_: + copy(src, dst, ctype, args...); + break; + case uint8: + copy(src, dst, ctype, args...); + break; + case uint16: + copy(src, dst, ctype, args...); + break; + case uint32: + copy(src, dst, ctype, args...); + break; + case uint64: + copy(src, dst, ctype, args...); + break; + case int8: + copy(src, dst, ctype, args...); + break; + case int16: + copy(src, dst, ctype, args...); + break; + case int32: + copy(src, dst, ctype, args...); + break; + case int64: + copy(src, dst, ctype, args...); + break; + case float16: + copy(src, dst, ctype, args...); + break; + case float32: + copy(src, dst, ctype, args...); + break; + case bfloat16: + copy(src, dst, ctype, args...); + break; + case complex64: + copy(src, dst, ctype, args...); break; } } @@ -242,47 +385,7 @@ void copy(const array& src, array& dst, CopyType ctype) { } // namespace void copy_inplace(const array& src, array& dst, CopyType ctype) { - switch (src.dtype()) { - case bool_: - copy(src, dst, ctype); - break; - case uint8: - copy(src, dst, ctype); - break; - case uint16: - copy(src, dst, ctype); - break; - case uint32: - copy(src, dst, ctype); - break; - case uint64: - copy(src, dst, ctype); - break; - case int8: - copy(src, dst, ctype); - break; - case int16: - copy(src, dst, ctype); - break; - case int32: - copy(src, dst, ctype); - break; - case int64: - copy(src, dst, ctype); - break; - case float16: - copy(src, dst, ctype); - break; - case float32: - copy(src, dst, ctype); - break; - case bfloat16: - copy(src, dst, ctype); - break; - case complex64: - copy(src, dst, ctype); - break; - } + return copy_inplace_dispatch(src, dst, ctype); } void copy(const array& src, array& dst, CopyType ctype) { @@ -312,4 +415,62 @@ void copy(const array& src, array& dst, CopyType ctype) { copy_inplace(src, dst, ctype); } +template +void copy_inplace( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype) { + switch (ctype) { + case CopyType::General: + case CopyType::GeneralGeneral: + return copy_inplace_dispatch( + src, + dst, + ctype, + data_shape, + i_strides, + o_strides, + i_offset, + o_offset); + + case CopyType::Scalar: + case CopyType::Vector: + return copy_inplace_dispatch(src, dst, ctype); + } +} + +template <> 
+void copy_inplace( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype) { + switch (ctype) { + case CopyType::General: + case CopyType::GeneralGeneral: + return copy_inplace_dispatch( + src, + dst, + ctype, + data_shape, + i_strides, + o_strides, + i_offset, + o_offset); + + case CopyType::Scalar: + case CopyType::Vector: + return copy_inplace_dispatch(src, dst, ctype); + } +} + } // namespace mlx::core diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h index 0affddec3..b0106257a 100644 --- a/mlx/backend/common/copy.h +++ b/mlx/backend/common/copy.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -26,4 +26,15 @@ enum class CopyType { void copy(const array& src, array& dst, CopyType ctype); void copy_inplace(const array& src, array& dst, CopyType ctype); +template +void copy_inplace( + const array& src, + array& dst, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype); + } // namespace mlx::core diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 4f43f1965..83fb86da9 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -94,6 +94,7 @@ DEFAULT(Sign) DEFAULT(Sin) DEFAULT(Sinh) DEFAULT(Slice) +DEFAULT(SliceUpdate) DEFAULT(Softmax) DEFAULT(Sort) DEFAULT_MULTI(Split) diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index ef00721aa..f8a5e7936 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -651,36 +651,33 @@ void Sinh::eval(const std::vector& inputs, array& out) { } } -void Slice::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - if (out.size() == 0) { - out.set_data(nullptr); - return; - } - auto& in = inputs[0]; - auto strides = in.strides(); - auto flags = in.flags(); - size_t data_offset = 0; +std::tuple> Slice::prepare_slice( + const array& in) { + int64_t data_offset = 0; + bool copy_needed = false; + std::vector inp_strides(in.ndim(), 0); for (int i = 0; i < in.ndim(); ++i) { data_offset += start_indices_[i] * in.strides()[i]; - strides[i] *= strides_[i]; + inp_strides[i] = in.strides()[i] * strides_[i]; + + copy_needed |= strides_[i] < 0; } + return std::make_tuple(copy_needed, data_offset, inp_strides); +} + +void Slice::shared_buffer_slice( + const array& in, + const std::vector& out_strides, + size_t data_offset, + array& out) { // Compute row/col contiguity - size_t data_size = 1; - size_t f_stride = 1; - size_t b_stride = 1; - flags.row_contiguous = true; - flags.col_contiguous = true; - for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) { - flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1; - flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1; - f_stride *= out.shape(i); - b_stride *= out.shape(ri); - if (strides[i] > 0) { - data_size *= out.shape(i); - } - } + auto [data_size, is_row_contiguous, is_col_contiguous] = + check_contiguity(out.shape(), out_strides); + + auto flags = in.flags(); + flags.row_contiguous = is_row_contiguous; + flags.col_contiguous = is_col_contiguous; if (data_size == 1) { // Broadcasted scalar array is contiguous. 
@@ -694,7 +691,87 @@ void Slice::eval(const std::vector& inputs, array& out) { flags.contiguous &= flags.row_contiguous || flags.col_contiguous; } - out.copy_shared_buffer(in, strides, flags, data_size, data_offset); + out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset); +} + +void Slice::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + + // Calculate out strides, initial offset and if copy needs to be made + auto [copy_needed, data_offset, inp_strides] = prepare_slice(in); + + // Do copy if needed + if (copy_needed) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + std::vector ostrides{out.strides().begin(), out.strides().end()}; + copy_inplace( + /* const array& src = */ in, + /* array& dst = */ out, + /* const std::vector& data_shape = */ out.shape(), + /* const std::vector& i_strides = */ inp_strides, + /* const std::vector& o_strides = */ ostrides, + /* int64_t i_offset = */ data_offset, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::General); + } else { + std::vector ostrides{inp_strides.begin(), inp_strides.end()}; + shared_buffer_slice(in, ostrides, data_offset, out); + } +} + +std::tuple> SliceUpdate::prepare_slice( + const array& in) { + int64_t data_offset = 0; + std::vector inp_strides(in.ndim(), 0); + for (int i = 0; i < in.ndim(); ++i) { + data_offset += start_indices_[i] * in.strides()[i]; + inp_strides[i] = in.strides()[i] * strides_[i]; + } + + return std::make_tuple(data_offset, inp_strides); +} + +void SliceUpdate::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + // Check if materialization is needed + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); + + // Calculate out strides, initial offset and if copy needs to be made + auto [data_offset, out_strides] = prepare_slice(out); + + // Do copy + std::vector upd_strides{upd.strides().begin(), upd.strides().end()}; + copy_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const std::vector& data_shape = */ upd.shape(), + /* const std::vector& i_strides = */ upd_strides, + /* const std::vector& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral); } void Split::eval( diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 8023789dc..24ded411b 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once @@ -8,11 +8,12 @@ namespace mlx::core { -inline size_t elem_to_loc( +template +inline stride_t elem_to_loc( int elem, const std::vector& shape, - const std::vector& strides) { - size_t loc = 0; + const std::vector& strides) { + stride_t loc = 0; for (int i = shape.size() - 1; i >= 0; --i) { auto q_and_r = ldiv(elem, shape[i]); loc += q_and_r.rem * strides[i]; @@ -34,10 +35,11 @@ inline size_t elem_to_loc(int elem, const array& a) { // // When multiple arrays are passed they should all have the same shape. The // collapsed axes are also the same so one shape is returned. 
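// A worked example for the helper documented above (illustrative, assuming a
// single row-major array; not a hunk of the diff):
//
//   collapse_contiguous_dims({10, 1, 10}, {{10, 10, 1}})
//     -> shape {100}, strides {{1}}
//
// Every adjacent axis pair satisfies strides[i] * shape[i] == strides[i - 1],
// so all three axes merge into one axis of size 100 whose stride is that of
// the innermost axis.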
-inline std::tuple, std::vector>> +template +inline std::tuple, std::vector>> collapse_contiguous_dims( const std::vector& shape, - const std::vector> strides) { + const std::vector> strides) { // Make a vector that has axes separated with -1. Collapse all axes between // -1. std::vector to_collapse; @@ -45,7 +47,7 @@ collapse_contiguous_dims( to_collapse.push_back(0); for (int i = 1; i < shape.size(); i++) { bool contiguous = true; - for (const std::vector& st : strides) { + for (const std::vector& st : strides) { if (st[i] * shape[i] != st[i - 1]) { contiguous = false; } @@ -62,7 +64,7 @@ collapse_contiguous_dims( } std::vector out_shape; - std::vector> out_strides(strides.size()); + std::vector> out_strides(strides.size()); for (int i = 0; i < to_collapse.size(); i++) { int current_shape = shape[to_collapse[i]]; while (to_collapse[++i] != -1) { @@ -70,7 +72,7 @@ collapse_contiguous_dims( } out_shape.push_back(current_shape); for (int j = 0; j < strides.size(); j++) { - const std::vector& st = strides[j]; + const std::vector& st = strides[j]; out_strides[j].push_back(st[to_collapse[i - 1]]); } } @@ -94,4 +96,27 @@ collapse_contiguous_dims(Arrays... xs) { std::vector{std::forward(xs)...}); } +template +inline auto check_contiguity( + const std::vector& shape, + const std::vector& strides) { + size_t data_size = 1; + size_t f_stride = 1; + size_t b_stride = 1; + bool is_row_contiguous = true; + bool is_col_contiguous = true; + + for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { + is_row_contiguous &= strides[i] == f_stride || shape[i] == 1; + is_col_contiguous &= strides[ri] == b_stride || shape[ri] == 1; + f_stride *= shape[i]; + b_stride *= shape[ri]; + if (strides[i] > 0) { + data_size *= shape[i]; + } + } + + return std::make_tuple(data_size, is_row_contiguous, is_col_contiguous); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index e023165c0..0885f5691 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include @@ -37,15 +37,22 @@ void copy_gpu(const array& in, array& out, CopyType ctype) { copy_gpu(in, out, ctype, out.primitive().stream()); } +template void copy_gpu_inplace( const array& in, array& out, + const std::vector& data_shape, + const std::vector& strides_in_pre, + const std::vector& strides_out_pre, + int64_t inp_offset, + int64_t out_offset, CopyType ctype, const Stream& s) { // Try to collapse contiguous dims - auto [shape, strides] = collapse_contiguous_dims(in, out); - auto& strides_in = strides[0]; - auto& strides_out = strides[1]; + auto [shape, strides] = collapse_contiguous_dims( + data_shape, std::vector{strides_in_pre, strides_out_pre}); + auto& strides_in_ = strides[0]; + auto& strides_out_ = strides[1]; auto& d = metal::device(s.device); std::ostringstream kname; @@ -72,39 +79,44 @@ void copy_gpu_inplace( auto compute_encoder = d.get_command_encoder(s.index); compute_encoder->setComputePipelineState(kernel); bool donate_in = in.data_shared_ptr() == nullptr; - set_array_buffer(compute_encoder, donate_in ? out : in, 0); - set_array_buffer(compute_encoder, out, 1); + + inp_offset *= size_of(in.dtype()); + out_offset *= size_of(out.dtype()); + + set_array_buffer(compute_encoder, donate_in ? 
out : in, inp_offset, 0); + set_array_buffer(compute_encoder, out, out_offset, 1); if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { - size_t ndim = shape.size(); + int ndim = shape.size(); + std::vector strides_in{strides_in_.begin(), strides_in_.end()}; + std::vector strides_out{strides_out_.begin(), strides_out_.end()}; + if (ndim > 3) { - compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); - compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3); - if (ctype == CopyType::GeneralGeneral) { - compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4); - } - } else { - // The shape is implicit in the grid for <= 3D - compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2); - if (ctype == CopyType::GeneralGeneral) { - compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3); - } + set_vector_bytes(compute_encoder, shape, ndim, 2); + } + set_vector_bytes(compute_encoder, strides_in, ndim, 3); + if (ctype == CopyType::GeneralGeneral) { + set_vector_bytes(compute_encoder, strides_out, ndim, 4); } if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { - compute_encoder->setBytes( - &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4); + compute_encoder->setBytes(&ndim, sizeof(int), 5); } int dim0 = ndim > 0 ? shape[ndim - 1] : 1; int dim1 = ndim > 1 ? shape[ndim - 2] : 1; - int rest = in.size() / (dim0 * dim1); + + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + int rest = data_size / (dim0 * dim1); // NB assuming thread_group_size is a power of 2 larger than 32 x 32 NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); if (thread_group_size != 1024) { throw std::runtime_error("[Metal::copy] Must use 1024 sized block"); } + auto group_dims = get_block_dims(dim0, dim1, rest); MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); compute_encoder->dispatchThreads(grid_dims, group_dims); @@ -120,4 +132,25 @@ void copy_gpu_inplace( } } +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + return copy_gpu_inplace( + in, out, in.shape(), in.strides(), out.strides(), 0, 0, ctype, s); +} + +void copy_gpu_inplace( + const array& in, + array& out, + const std::vector& istride, + int64_t ioffset, + CopyType ctype, + const Stream& s) { + std::vector ostrides{out.strides().begin(), out.strides().end()}; + return copy_gpu_inplace( + in, out, in.shape(), istride, ostrides, ioffset, 0, ctype, s); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/copy.h b/mlx/backend/metal/copy.h index c08ae2cf7..c810868f5 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/metal/copy.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
#pragma once @@ -7,12 +7,34 @@ namespace mlx::core { +// Generic copy inplace +template +void copy_gpu_inplace( + const array& in, + array& out, + const std::vector& data_shape, + const std::vector& i_strides, + const std::vector& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype, + const Stream& s); + void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& src, array& out, CopyType ctype); + void copy_gpu_inplace( const array& src, array& out, CopyType ctype, const Stream& s); +void copy_gpu_inplace( + const array& in, + array& out, + const std::vector& istride, + int64_t ioffset, + CopyType ctype, + const Stream& s); + } // namespace mlx::core diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index 9808cc958..4cfc3b68f 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -1,29 +1,29 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/utils.h" template [[kernel]] void copy_s( - device const T* src, - device U* dst, + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], uint index [[thread_position_in_grid]]) { dst[index] = static_cast(src[0]); } template [[kernel]] void copy_v( - device const T* src, - device U* dst, + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], uint index [[thread_position_in_grid]]) { dst[index] = static_cast(src[index]); } template [[kernel]] void copy_g_nd1( - device const T* src, - device U* dst, - constant const size_t& src_stride, + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], uint index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_1(index, src_stride); dst[index] = static_cast(src[src_idx]); @@ -31,61 +31,61 @@ template template [[kernel]] void copy_g_nd2( - device const T* src, - device U* dst, - constant const size_t src_strides[2], + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_2(index, src_strides); - size_t dst_idx = index.x + (size_t)grid_dim.x * index.y; + int64_t dst_idx = index.x + (int64_t)grid_dim.x * index.y; dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_g_nd3( - device const T* src, - device U* dst, - constant const size_t src_strides[3], + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_3(index, src_strides); - size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_g_nd( - device const T* src, - device U* dst, - constant const int src_shape[DIM], - constant const size_t src_strides[DIM], + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); - size_t dst_idx 
= index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_g( - device const T* src, - device U* dst, - constant const int* src_shape, - constant const size_t* src_strides, - constant const int& ndim, + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int& ndim [[buffer(5)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); - size_t dst_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + int64_t dst_idx = index.x + (int64_t)grid_dim.x * (index.y + (int64_t)grid_dim.y * index.z); dst[dst_idx] = static_cast(src[src_idx]); } template [[kernel]] void copy_gg_nd1( - device const T* src, - device U* dst, - constant const size_t& src_stride, - constant const size_t& dst_stride, + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], uint index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_1(index, src_stride); auto dst_idx = elem_to_loc_1(index, dst_stride); @@ -94,10 +94,10 @@ template template [[kernel]] void copy_gg_nd2( - device const T* src, - device U* dst, - constant const size_t src_strides[2], - constant const size_t dst_strides[2], + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], uint2 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_2(index, src_strides); auto dst_idx = elem_to_loc_2(index, dst_strides); @@ -106,10 +106,10 @@ template template [[kernel]] void copy_gg_nd3( - device const T* src, - device U* dst, - constant const size_t src_strides[3], - constant const size_t dst_strides[3], + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_3(index, src_strides); auto dst_idx = elem_to_loc_3(index, dst_strides); @@ -118,11 +118,11 @@ template template [[kernel]] void copy_gg_nd( - device const T* src, - device U* dst, - constant const int src_shape[DIM], - constant const size_t src_strides[DIM], - constant const size_t dst_strides[DIM], + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], uint3 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc_nd(index, src_shape, src_strides); auto dst_idx = elem_to_loc_nd(index, src_shape, dst_strides); @@ -131,12 +131,12 @@ template template [[kernel]] void copy_gg( - device const T* src, - device U* dst, - constant const int* src_shape, - constant const size_t* src_strides, - constant const size_t* dst_strides, - constant const int& ndim, + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim 
[[buffer(5)]], uint3 index [[thread_position_in_grid]]) { auto src_idx = elem_to_loc(index, src_shape, src_strides, ndim); auto dst_idx = elem_to_loc(index, src_shape, dst_strides, ndim); @@ -146,70 +146,70 @@ template #define instantiate_copy(name, itype, otype, ctype) \ template [[host_name(name)]] \ [[kernel]] void copy_##ctype( \ - device const itype* src, \ - device otype* dst, \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ uint index [[thread_position_in_grid]]); #define instantiate_copy_g_dim(name, itype, otype, dims) \ template [[host_name(name "_" #dims)]] \ [[kernel]] void copy_g_nd( \ - device const itype* src, \ - device otype* dst, \ - constant const int src_shape[dims], \ - constant const size_t src_strides[dims], \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); \ template [[host_name("g" name "_" #dims)]] \ [[kernel]] void copy_gg_nd( \ - device const itype* src, \ - device otype* dst, \ - constant const int src_shape[dims], \ - constant const size_t src_strides[dims], \ - constant const size_t dst_strides[dims], \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ uint3 index [[thread_position_in_grid]]); #define instantiate_copy_g_nd(name, itype, otype) \ template [[host_name(name "_1")]] \ [[kernel]] void copy_g_nd1( \ - device const itype* src, \ - device otype* dst, \ - constant const size_t& src_stride, \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t& src_stride [[buffer(3)]], \ uint index [[thread_position_in_grid]]); \ template [[host_name(name "_2")]] \ [[kernel]] void copy_g_nd2( \ - device const itype* src, \ - device otype* dst, \ - constant const size_t src_strides[2], \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ uint2 index [[thread_position_in_grid]], \ uint2 grid_dim [[threads_per_grid]]); \ template [[host_name(name "_3")]] \ [[kernel]] void copy_g_nd3( \ - device const itype* src, \ - device otype* dst, \ - constant const size_t src_strides[3], \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); \ template [[host_name("g" name "_1")]] \ [[kernel]] void copy_gg_nd1( \ - device const itype* src, \ - device otype* dst, \ - constant const size_t& src_stride, \ - constant const size_t& dst_stride, \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t& src_stride [[buffer(3)]], \ + constant const int64_t& dst_stride [[buffer(4)]], \ uint index [[thread_position_in_grid]]); \ template [[host_name("g" name "_2")]] \ [[kernel]] void copy_gg_nd2( \ - device const itype* src, \ - device otype* dst, \ - constant const size_t src_strides[2], \ - constant const size_t dst_strides[2], \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ 
uint2 index [[thread_position_in_grid]]); \ template [[host_name("g" name "_3")]] \ [[kernel]] void copy_gg_nd3( \ - device const itype* src, \ - device otype* dst, \ - constant const size_t src_strides[3], \ - constant const size_t dst_strides[3], \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ uint3 index [[thread_position_in_grid]]); \ instantiate_copy_g_dim(name, itype, otype, 4) \ instantiate_copy_g_dim(name, itype, otype, 5) @@ -218,21 +218,21 @@ template #define instantiate_copy_g(name, itype, otype) \ template [[host_name(name)]] \ [[kernel]] void copy_g( \ - device const itype* src, \ - device otype* dst, \ - constant const int* src_shape, \ - constant const size_t* src_strides, \ - constant const int& ndim, \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int& ndim [[buffer(5)]], \ uint3 index [[thread_position_in_grid]], \ uint3 grid_dim [[threads_per_grid]]); \ template [[host_name("g" name)]] \ [[kernel]] void copy_gg( \ - device const itype* src, \ - device otype* dst, \ - constant const int* src_shape, \ - constant const size_t* src_strides, \ - constant const size_t* dst_strides, \ - constant const int& ndim, \ + device const itype* src [[buffer(0)]], \ + device otype* dst [[buffer(1)]], \ + constant const int* src_shape [[buffer(2)]], \ + constant const int64_t* src_strides [[buffer(3)]], \ + constant const int64_t* dst_strides [[buffer(4)]], \ + constant const int& ndim [[buffer(5)]], \ uint3 index [[thread_position_in_grid]]); #define instantiate_copy_all(tname, itype, otype) \ diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index 9c3d20b30..641df11f0 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. 
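// A worked example for the elem_to_loc indexing helpers defined below
// (illustrative, not a hunk of the diff). They map a row-major linear element
// index onto a strided memory location; with this patch they are templated so
// the copy kernels can pass signed int64_t strides (e.g. for negatively
// strided slices).
//
//   shape = {2, 3}, strides = {1, 2}   // column-major layout of a 2 x 3 array
//   elem  = 4                          // (row 1, col 1) in row-major order
//   loc   = (4 % 3) * strides[1] = 2;  elem /= 3 -> 1
//   loc  += (1 % 2) * strides[0] = 1   // total loc = 3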
#pragma once @@ -65,12 +65,18 @@ struct Limits { // Indexing utils /////////////////////////////////////////////////////////////////////////////// -inline size_t elem_to_loc( +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template +METAL_FUNC stride_t elem_to_loc( uint elem, device const int* shape, - device const size_t* strides, + device const stride_t* strides, int ndim) { - size_t loc = 0; + stride_t loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; @@ -78,12 +84,13 @@ inline size_t elem_to_loc( return loc; } -inline size_t elem_to_loc( +template +METAL_FUNC stride_t elem_to_loc( uint elem, constant const int* shape, - constant const size_t* strides, + constant const stride_t* strides, int ndim) { - size_t loc = 0; + stride_t loc = 0; for (int i = ndim - 1; i >= 0 && elem > 0; --i) { loc += (elem % shape[i]) * strides[i]; elem /= shape[i]; @@ -91,52 +98,59 @@ inline size_t elem_to_loc( return loc; } -template -inline uint3 elem_to_loc_3_nd( +// Non templated version to handle arbitrary dims +template +METAL_FUNC stride_t elem_to_loc( uint3 elem, - constant const int shape[NDIM], - constant const size_t a_strides[NDIM], - constant const size_t b_strides[NDIM], - constant const size_t c_strides[NDIM]) { - uint3 loc = { - static_cast( - elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), - static_cast( - elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]), - static_cast( - elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])}; - for (int d = NDIM - 3; d >= 0; --d) { - uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - loc.z += l * c_strides[d]; + constant const int* shape, + constant const stride_t* strides, + int ndim) { + stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; elem.z /= shape[d]; } return loc; } +/////////////////////////////////////////////////////////////////////////////// +// Single Array with fixed N dims + +template +METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) { + return elem * stride; +} + +template +METAL_FUNC stride_t +elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) { + return elem.x * strides[1] + elem.y * strides[0]; +} + +template +METAL_FUNC stride_t +elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) { + return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +} + template -inline uint2 elem_to_loc_2_nd( - uint3 elem, - constant const int shape[NDIM], - constant const size_t a_strides[NDIM], - constant const size_t b_strides[NDIM]) { - uint2 loc = { - static_cast( - elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), - static_cast( - elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])}; - for (int d = NDIM - 3; d >= 0; --d) { - uint l = elem.z % shape[d]; - loc.x += l * a_strides[d]; - loc.y += l * b_strides[d]; - elem.z /= shape[d]; +METAL_FUNC size_t elem_to_loc_nd( + uint elem, + device const int* shape, + device const size_t* strides) { + size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1]; + + MLX_MTL_PRAGMA_UNROLL + for (int d = NDIM - 2; d >= 0; --d) { + elem /= shape[d + 1]; + loc += (elem % shape[d]) * strides[d]; } + return loc; } template -inline size_t elem_to_loc_nd( +METAL_FUNC 
size_t elem_to_loc_nd( uint3 elem, constant const int shape[NDIM], constant const size_t strides[NDIM]) { @@ -148,33 +162,59 @@ inline size_t elem_to_loc_nd( return loc; } -inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) { - return elem * stride; +template +METAL_FUNC int64_t elem_to_loc_nd( + uint elem, + constant const int shape[NDIM], + constant const int64_t strides[NDIM]) { + int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1]; + + MLX_MTL_PRAGMA_UNROLL + for (int d = NDIM - 2; d >= 0; --d) { + elem /= shape[d + 1]; + loc += (elem % shape[d]) * strides[d]; + } + + return loc; } -inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) { - return elem.x * strides[1] + elem.y * strides[0]; -} - -inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) { - return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; -} - -// Non templated version to handle arbitrary dims -inline size_t elem_to_loc( +template +METAL_FUNC int64_t elem_to_loc_nd( uint3 elem, - constant const int* shape, - constant const size_t* strides, - int ndim) { - size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; - for (int d = ndim - 3; d >= 0; --d) { + constant const int shape[NDIM], + constant const int64_t strides[NDIM]) { + int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2]; + for (int d = NDIM - 3; d >= 0; --d) { loc += (elem.z % shape[d]) * strides[d]; elem.z /= shape[d]; } return loc; } -inline uint3 elem_to_loc_3_nd( +/////////////////////////////////////////////////////////////////////////////// +// Multiple Arrays with generic dims + +METAL_FUNC uint2 elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + uint2 loc = { + static_cast( + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + static_cast( + elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +METAL_FUNC uint3 elem_to_loc_3_nd( uint3 elem, constant const int* shape, constant const size_t* a_strides, @@ -198,18 +238,21 @@ inline uint3 elem_to_loc_3_nd( return loc; } -inline uint2 elem_to_loc_2_nd( +/////////////////////////////////////////////////////////////////////////////// +// Multiple Arrays with fixed N dims + +template +METAL_FUNC uint2 elem_to_loc_2_nd( uint3 elem, - constant const int* shape, - constant const size_t* a_strides, - constant const size_t* b_strides, - int ndim) { + constant const int shape[NDIM], + constant const size_t a_strides[NDIM], + constant const size_t b_strides[NDIM]) { uint2 loc = { static_cast( - elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), static_cast( - elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; - for (int d = ndim - 3; d >= 0; --d) { + elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])}; + for (int d = NDIM - 3; d >= 0; --d) { uint l = elem.z % shape[d]; loc.x += l * a_strides[d]; loc.y += l * b_strides[d]; @@ -219,55 +262,26 @@ inline uint2 elem_to_loc_2_nd( } template -inline uint elem_to_loc_nd( - uint elem, - device const int* shape, - device const size_t* strides); - -template <> -inline uint elem_to_loc_nd<1>( - uint elem, - device const int* shape, - device const size_t* strides) { - return (elem % 
shape[0]) * strides[0]; -} - -template <> -inline uint elem_to_loc_nd<2>( - uint elem, - device const int* shape, - device const size_t* strides) { - uint loc = (elem % shape[1]) * strides[1]; - elem /= shape[1]; - loc += (elem % shape[0]) * strides[0]; - return loc; -} - -template <> -inline uint elem_to_loc_nd<3>( - uint elem, - device const int* shape, - device const size_t* strides) { - uint loc = (elem % shape[2]) * strides[2]; - elem /= shape[2]; - loc += (elem % shape[1]) * strides[1]; - elem /= shape[1]; - loc += (elem % shape[0]) * strides[0]; - return loc; -} - -template <> -inline uint elem_to_loc_nd<4>( - uint elem, - device const int* shape, - device const size_t* strides) { - uint loc = (elem % shape[3]) * strides[3]; - elem /= shape[3]; - loc += (elem % shape[2]) * strides[2]; - elem /= shape[2]; - loc += (elem % shape[1]) * strides[1]; - elem /= shape[1]; - loc += (elem % shape[0]) * strides[0]; +METAL_FUNC uint3 elem_to_loc_3_nd( + uint3 elem, + constant const int shape[NDIM], + constant const size_t a_strides[NDIM], + constant const size_t b_strides[NDIM], + constant const size_t c_strides[NDIM]) { + uint3 loc = { + static_cast( + elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), + static_cast( + elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]), + static_cast( + elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])}; + for (int d = NDIM - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + loc.z += l * c_strides[d]; + elem.z /= shape[d]; + } return loc; } diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 2c127ed8e..51033997f 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -206,7 +206,7 @@ inline auto collapse_batches(const array& a, const array& b) { std::vector B_bstride{b.strides().begin(), b.strides().end() - 2}; auto [batch_shape, batch_strides] = - collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride}); + collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); auto A_batch_stride = batch_strides[0]; auto B_batch_stride = batch_strides[1]; @@ -237,8 +237,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) { std::vector B_bstride{b.strides().begin(), b.strides().end() - 2}; std::vector C_bstride{c.strides().begin(), c.strides().end() - 2}; - auto [batch_shape, batch_strides] = - collapse_contiguous_dims(A_bshape, {A_bstride, B_bstride, C_bstride}); + auto [batch_shape, batch_strides] = collapse_contiguous_dims( + A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); auto A_batch_stride = batch_strides[0]; auto B_batch_stride = batch_strides[1]; diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 6908b6905..b7eeed0e1 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -865,7 +865,73 @@ void Sqrt::eval_gpu(const std::vector& inputs, array& out) { } void Slice::eval_gpu(const std::vector& inputs, array& out) { - eval(inputs, out); + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + + // Calculate out strides, initial offset and if copy needs to be made + auto [copy_needed, data_offset, inp_strides] = prepare_slice(in); + + // Do copy if needed + if (copy_needed) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + std::vector ostrides{out.strides().begin(), out.strides().end()}; + copy_gpu_inplace( + /* const array& in = */ in, + 
/* array& out = */ out, + /* const std::vector& data_shape = */ out.shape(), + /* const std::vector& i_strides = */ inp_strides, + /* const std::vector& o_strides = */ ostrides, + /* int64_t i_offset = */ data_offset, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::General, + /* const Stream& s = */ stream()); + } else { + std::vector ostrides{inp_strides.begin(), inp_strides.end()}; + shared_buffer_slice(in, ostrides, data_offset, out); + } +} + +void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + if (upd.size() == 0) { + out.copy_shared_buffer(in); + return; + } + + // Check if materialization is needed + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); + + // Calculate out strides, initial offset and if copy needs to be made + auto [data_offset, out_strides] = prepare_slice(out); + + // Do copy + std::vector upd_strides{upd.strides().begin(), upd.strides().end()}; + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const std::vector& data_shape = */ upd.shape(), + /* const std::vector& i_strides = */ upd_strides, + /* const std::vector& o_strides = */ out_strides, + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ data_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ stream()); } void StopGradient::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h index a8e4cfd44..10aea8ab0 100644 --- a/mlx/backend/metal/utils.h +++ b/mlx/backend/metal/utils.h @@ -9,16 +9,43 @@ namespace mlx::core { namespace { -void set_array_buffer( - MTL::ComputeCommandEncoder* enc, - const array& a, - int idx) { +inline void +set_array_buffer(MTL::ComputeCommandEncoder* enc, const array& a, int idx) { auto a_buf = static_cast(a.buffer().ptr()); auto offset = a.data() - static_cast(const_cast(a_buf)->contents()); enc->setBuffer(a_buf, offset, idx); } +inline void set_array_buffer( + MTL::ComputeCommandEncoder* enc, + const array& a, + int64_t offset, + int idx) { + auto a_buf = static_cast(a.buffer().ptr()); + auto base_offset = a.data() - + static_cast(const_cast(a_buf)->contents()); + base_offset += offset; + enc->setBuffer(a_buf, base_offset, idx); +} + +template +inline void set_vector_bytes( + MTL::ComputeCommandEncoder* enc, + const std::vector& vec, + size_t nelems, + int idx) { + enc->setBytes(vec.data(), nelems * sizeof(T), idx); +} + +template +inline void set_vector_bytes( + MTL::ComputeCommandEncoder* enc, + const std::vector& vec, + int idx) { + return set_vector_bytes(enc, vec, vec.size(), idx); +} + std::string type_to_name(const array& a) { std::string tname; switch (a.dtype()) { diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index cbb971384..cd5bef5c4 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -87,6 +87,7 @@ NO_GPU(Sign) NO_GPU(Sin) NO_GPU(Sinh) NO_GPU(Slice) +NO_GPU(SliceUpdate) NO_GPU(Softmax) NO_GPU(Sort) NO_GPU_MULTI(Split) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f86f55f58..e643170a6 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -445,6 +445,60 @@ array expand_dims( return reshape(a, out_shape, s); } +// Slice helper +namespace { + +inline auto 
normalize_slice( + const std::vector& shape, + std::vector& start, + std::vector& stop, + std::vector& strides) { + std::vector out_shape(shape.size()); + bool has_neg_strides = false; + + for (int i = 0; i < shape.size(); ++i) { + // Following numpy docs + // Negative i and j are interpreted as n + i and n + j where n is + // the number of elements in the corresponding dimension. Negative + // k makes stepping go towards smaller indices + + auto n = shape[i]; + auto s = start[i]; + s = s < 0 ? s + n : s; + auto e = stop[i]; + e = e < 0 ? e + n : e; + + // Note: -ve strides require start >= stop + if (strides[i] < 0) { + has_neg_strides = true; + + // Clamp to bounds + auto st = std::min(s, n - 1); + auto ed = std::max(-1, e); + + start[i] = st; + stop[i] = ed > st ? st : ed; + + auto str = -strides[i]; + out_shape[i] = (start[i] - stop[i] + str - 1) / str; + + } else { + // Clamp to bounds + auto st = std::max(0, std::min(s, n)); + auto ed = std::max(0, std::min(e, n)); + + start[i] = st; + stop[i] = ed < st ? st : ed; + + out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; + } + } + + return std::make_pair(has_neg_strides, out_shape); +} + +} // namespace + array slice( const array& a, std::vector start, @@ -459,113 +513,13 @@ array slice( throw std::invalid_argument(msg.str()); } - std::vector negatively_strided_axes; - std::vector> negatively_strided_slices; - std::vector out_shape(a.ndim()); - for (int i = 0; i < a.ndim(); ++i) { - // Following numpy docs - // Negative i and j are interpreted as n + i and n + j where n is - // the number of elements in the corresponding dimension. Negative - // k makes stepping go towards smaller indices + auto [has_neg_strides, out_shape] = + normalize_slice(a.shape(), start, stop, strides); - auto n = a.shape(i); - auto s = start[i]; - s = s < 0 ? s + n : s; - auto e = stop[i]; - e = e < 0 ? e + n : e; - - // Note: We pass positive strides to the primitive and then flip - // the axes later as needed - if (strides[i] < 0) { - negatively_strided_axes.push_back(i); - auto st = std::min(s, n - 1); - auto ed = std::max(e, -1); - negatively_strided_slices.push_back({st, ed, strides[i]}); - start[i] = 0; - stop[i] = n; - strides[i] = 1; - } else { - start[i] = s; - stop[i] = e < s ? 
s : e; - } - - // Clamp to bounds - start[i] = std::max(0, std::min(start[i], n)); - stop[i] = std::max(0, std::min(stop[i], n)); - - out_shape[i] = (stop[i] - start[i] + strides[i] - 1) / strides[i]; - } - - // If strides are negative, slice and then make a copy with axes flipped - if (negatively_strided_axes.size() > 0) { - // First, take the slice of the positively strided axes - auto out = array( - out_shape, - a.dtype(), - std::make_unique( - to_stream(s), - std::move(start), - std::move(stop), - std::move(strides)), - {a}); - - std::vector indices; - std::vector slice_sizes = out.shape(); - std::vector t_axes(out.ndim(), -1); - std::vector out_reshape(out.ndim(), -1); - - int n_axes = negatively_strided_axes.size(); - for (int i = 0; i < n_axes; i++) { - // Get axis and corresponding slice - auto ax = negatively_strided_axes[i]; - auto sl = negatively_strided_slices[i]; - - // Get indices for the slice - auto ax_idx = arange(sl[0], sl[1], sl[2], s); - - // Reshape indices for broadcast as needed - std::vector ax_idx_shape(n_axes, 1); - ax_idx_shape[i] = ax_idx.size(); - ax_idx = reshape(ax_idx, ax_idx_shape, s); - - // Add indices to list - indices.push_back(ax_idx); - - // Set slice size for axis - slice_sizes[ax] = 1; - - // Gather moves the axis up, remainder needs to be squeezed - out_reshape[i] = indices[i].size(); - - // Gather moves the axis up, needs to be transposed - t_axes[ax] = i; - } - - // Prepare out_reshape to squeeze gathered dims - // Prepare to transpose dims as needed - int j = n_axes; - for (int i = 0; j < out.ndim() && i < out.ndim(); i++) { - if (t_axes[i] < 0) { - t_axes[i] = j; - out_reshape[j] = out_shape[i]; - j++; - } - } - - // Gather - out = gather(out, indices, negatively_strided_axes, slice_sizes, s); - - // Squeeze dims - out = reshape(out, out_reshape, s); - - // Transpose dims - out = transpose(out, t_axes, s); - - return out; - } - if (out_shape == a.shape()) { + if (!has_neg_strides && out_shape == a.shape()) { return a; } + return array( out_shape, a.dtype(), @@ -582,6 +536,56 @@ array slice( return slice(a, start, stop, std::vector(a.ndim(), 1), to_stream(s)); } +/** Update a slice from the source array */ +array slice_update( + const array& src, + const array& update, + std::vector start, + std::vector stop, + std::vector strides, + StreamOrDevice s /* = {} */) { + // Check dimensions + if (start.size() != src.ndim() || stop.size() != src.ndim() || + strides.size() != src.ndim()) { + std::ostringstream msg; + msg << "[slice] Invalid number of indices or strides for " + << "array with dimension " << src.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + + // Process slice dimensions + auto [has_neg_strides, upd_shape] = + normalize_slice(src.shape(), start, stop, strides); + + // Broadcast update shape to slice shape + auto upd_shape_broadcast = broadcast_shapes(upd_shape, update.shape()); + auto update_broadcasted = broadcast_to(update, upd_shape_broadcast, s); + + // If the entire src is the slice, just return the update + if (!has_neg_strides && upd_shape == src.shape()) { + return astype(update_broadcasted, src.dtype(), s); + } + + return array( + src.shape(), + src.dtype(), + std::make_unique( + to_stream(s), std::move(start), std::move(stop), std::move(strides)), + {src, update}); +} + +/** Update a slice from the source array with stride 1 in each dimension */ +array slice_update( + const array& src, + const array& update, + std::vector start, + std::vector stop, + StreamOrDevice s /* = {} */) { + auto strides = 
std::vector<int>(src.ndim(), 1);
+  return slice_update(
+      src, update, std::move(start), std::move(stop), std::move(strides), s);
+}
+
 std::vector<array> split(
     const array& a,
     const std::vector<int>& indices,
diff --git a/mlx/ops.h b/mlx/ops.h
index 263eef35c..0df68ecc9 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -177,6 +177,23 @@ array slice(
     const std::vector<int>& stop,
     StreamOrDevice s = {});

+/** Update a slice from the source array */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    std::vector<int> strides,
+    StreamOrDevice s = {});
+
+/** Update a slice from the source array with stride 1 in each dimension */
+array slice_update(
+    const array& src,
+    const array& update,
+    std::vector<int> start,
+    std::vector<int> stop,
+    StreamOrDevice s = {});
+
 /** Split an array into sub-arrays along a given axis. */
 std::vector<array>
 split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index f2b7f16af..e911d19dd 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2849,6 +2849,114 @@ bool Slice::is_equivalent(const Primitive& other) const {
       end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
 }

+std::pair<std::vector<array>, std::vector<int>> SliceUpdate::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  assert(inputs.size() == 2);
+  assert(axes.size() == 2);
+
+  auto start = start_indices_;
+  auto stop = end_indices_;
+  auto strides = strides_;
+
+  auto src = inputs[0];
+  auto upd = inputs[1];
+
+  auto src_ax = axes[0];
+  auto upd_ax = axes[1];
+
+  // No vmapping needed
+  if (src_ax == -1 && upd_ax == -1) {
+    return {{slice_update(src, upd, start, stop, strides, stream())}, {-1}};
+  }
+
+  // Broadcast src
+  if (src_ax == -1) {
+    src = expand_dims(src, upd_ax, stream());
+    auto shape = src.shape();
+    shape[upd_ax] = upd.shape(upd_ax);
+    src = broadcast_to(src, shape, stream());
+    src_ax = upd_ax;
+  }
+
+  // Broadcast upd
+  if (upd_ax == -1) {
+    upd = expand_dims(upd, src_ax, stream());
+    upd_ax = src_ax;
+  }
+
+  if (src_ax != upd_ax) {
+    upd = moveaxis(upd, upd_ax, src_ax, stream());
+  }
+
+  start.insert(start.begin() + src_ax, 0);
+  stop.insert(stop.begin() + src_ax, src.shape(src_ax));
+  strides.insert(strides.begin() + src_ax, 1);
+
+  return {{slice_update(src, upd, start, stop, strides, stream())}, {src_ax}};
+}
+
+std::vector<array> SliceUpdate::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>&) {
+  // Check inputs
+  assert(primals.size() == 2);
+
+  auto& cotan = cotangents[0];
+  auto& src = primals[0];
+  auto& upd = primals[1];
+
+  std::vector<array> vjps;
+
+  for (int num : argnums) {
+    // Vjp for source
+    if (num == 0) {
+      auto grad = slice_update(
+          cotan,
+          zeros_like(upd, stream()),
+          start_indices_,
+          end_indices_,
+          strides_,
+          stream());
+
+      vjps.push_back(grad);
+    }
+    // Vjp for updates
+    else {
+      auto grad =
+          slice(cotan, start_indices_, end_indices_, strides_, stream());
+
+      vjps.push_back(grad);
+    }
+  }
+
+  return vjps;
+}
+
+std::vector<array> SliceUpdate::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  // Check inputs
+  assert(primals.size() == 2);
+  return {slice_update(
+      tangents[0],
+      tangents[1],
+      start_indices_,
+      end_indices_,
+      strides_,
+      stream())};
+}
+
+bool SliceUpdate::is_equivalent(const Primitive& other) const {
+  const SliceUpdate& s_other = static_cast<const SliceUpdate&>(other);
+  return (
+      start_indices_ == s_other.start_indices_ &&
+      end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
+}
+
 std::pair<std::vector<array>, std::vector<int>> Softmax::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index ebb11b04d..8a0a06894 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1660,6 +1660,44 @@ class Slice : public UnaryPrimitive {
   std::vector<int> strides_;

   void eval(const std::vector<array>& inputs, array& out);
+
+  std::tuple<int64_t, std::vector<int64_t>> prepare_slice(
+      const array& in);
+  void shared_buffer_slice(
+      const array& in,
+      const std::vector<size_t>& out_strides,
+      size_t data_offset,
+      array& out);
+};
+
+class SliceUpdate : public UnaryPrimitive {
+ public:
+  explicit SliceUpdate(
+      Stream stream,
+      const std::vector<int>& start_indices,
+      const std::vector<int>& end_indices,
+      const std::vector<int>& strides)
+      : UnaryPrimitive(stream),
+        start_indices_(start_indices),
+        end_indices_(end_indices),
+        strides_(strides){};
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(SliceUpdate)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  std::vector<int> start_indices_;
+  std::vector<int> end_indices_;
+  std::vector<int> strides_;
+
+  void eval(const std::vector<array>& inputs, array& out);
+
+  std::tuple<int64_t, std::vector<int64_t>> prepare_slice(const array& in);
 };

 class Softmax : public UnaryPrimitive {
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index c6365beb9..07ef696f0 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -329,4 +329,13 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
   return os;
 }

+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
+  os << "(";
+  for (int i = 0; i < v.size(); ++i) {
+    os << v[i] << ((i == v.size() - 1) ? "" : ",");
+  }
+  os << ")";
+  return os;
+}
+
 } // namespace mlx::core
diff --git a/mlx/utils.h b/mlx/utils.h
index ebcca3a1e..b2eedbdaf 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -1,4 +1,4 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.

 #pragma once

@@ -76,6 +76,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
 std::ostream& operator<<(std::ostream& os, array a);
 std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
 std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
 inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
   return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
 }
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index c4305cf92..204f1ffd3 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -198,6 +198,31 @@ TEST_CASE("test slice") {
   CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
 }

+TEST_CASE("test slice update") {
+  array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32);
+  array y = array(
+      {
+          1.,
+          2.,
+          3.,
+          4.,
+      },
+      {4},
+      float32);
+
+  auto out = slice_update(x, y, {2}, {6}, {1});
+  CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());
+
+  out = slice_update(x, y, {5}, {1}, {-1});
+  CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());
+
+  x = reshape(x, {2, 4});
+  out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
+  out = reshape(out, {8});
+  CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());
+  CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
+}
+
 TEST_CASE("test split") {
   array x = array(1);
   CHECK_THROWS(split(x, 0));
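
Note on broadcasting: the broadcast_to call added to slice_update in mlx/ops.cpp means the update only has to match the slice shape up to broadcasting. The doctest-style sketch below is illustrative only and is not part of the patch; it assumes the same harness as tests/ops_tests.cpp (doctest, "mlx/mlx.h", using namespace mlx::core), and the test name is made up.

TEST_CASE("test slice update broadcast (sketch)") {
  // y has shape {4}; the target slice of x has shape {2, 4}, so y is
  // broadcast across the leading axis before it is written into x.
  array x = zeros({2, 4}, float32);
  array y = array({1., 2., 3., 4.}, {4}, float32);

  auto out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
  CHECK(array_equal(out, broadcast_to(y, {2, 4})).item<bool>());
}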
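
Note on gradients: SliceUpdate::vjp in mlx/primitives.cpp splits the incoming cotangent between the two inputs; the source receives the cotangent with the updated region zeroed out, and the update receives the sliced cotangent. The following check of that rule is a hypothetical sketch outside the patch, again assuming the ops_tests.cpp harness.

TEST_CASE("test slice update vjp relation (sketch)") {
  array cotan = ones({8}, float32);

  // Gradient w.r.t. the source: zero out the region the update overwrote,
  // mirroring slice_update(cotan, zeros_like(upd), start, stop, strides).
  auto cotan_src = slice_update(cotan, zeros({4}, float32), {2}, {6}, {1});
  CHECK(array_equal(
            cotan_src, array({1., 1., 0., 0., 0., 0., 1., 1.}, {8}, float32))
            .item<bool>());

  // Gradient w.r.t. the update: read back the region that was written,
  // mirroring slice(cotan, start, stop, strides).
  auto cotan_upd = slice(cotan, {2}, {6}, {1});
  CHECK(array_equal(cotan_upd, ones({4}, float32)).item<bool>());
}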