From 516ded618b3ca41dfab22d59218f85e7c2d22196 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 7 Jan 2025 14:02:16 -0800 Subject: [PATCH] Dynamic slicing (#1741) * dynamic slice and slice update * python bindings + tests + fix set item * fix compile issue * comment * fix jit --- docs/src/python/ops.rst | 2 + mlx/array.h | 12 +-- mlx/backend/common/primitives.cpp | 80 +++++++++++++- mlx/backend/metal/copy.cpp | 33 +++++- mlx/backend/metal/copy.h | 10 +- mlx/backend/metal/jit_kernels.cpp | 32 ++++++ mlx/backend/metal/kernels.h | 6 ++ mlx/backend/metal/kernels/copy.h | 75 +++++++++++++ mlx/backend/metal/kernels/copy.metal | 52 +++++---- mlx/backend/metal/nojit_kernels.cpp | 8 ++ mlx/backend/metal/primitives.cpp | 122 ++++++++++++++++++++- mlx/backend/no_cpu/primitives.cpp | 2 + mlx/backend/no_metal/primitives.cpp | 2 + mlx/export.cpp | 2 + mlx/ops.cpp | 110 ++++++++++++++++++- mlx/ops.h | 28 ++++- mlx/primitives.cpp | 155 +++++++++++++++++++++++++-- mlx/primitives.h | 45 ++++++++ python/src/indexing.cpp | 39 +++---- python/src/ops.cpp | 77 +++++++++++++ python/tests/test_array.py | 10 ++ python/tests/test_ops.py | 13 +++ tests/autograd_tests.cpp | 21 ++++ tests/linalg_tests.cpp | 2 +- tests/ops_tests.cpp | 46 +++++++- tests/random_tests.cpp | 2 +- tests/vmap_tests.cpp | 30 ++++++ 27 files changed, 941 insertions(+), 75 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 9e50feebb..248028575 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -145,6 +145,8 @@ Operations sign sin sinh + slice + slice_update softmax sort split diff --git a/mlx/array.h b/mlx/array.h index d4ed48b6c..6ad0e578a 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -35,29 +35,29 @@ class array { explicit array(const std::complex& val, Dtype dtype = complex64); template - array( + explicit array( It data, Shape shape, Dtype dtype = TypeToDtype::value_type>()); template - array(std::initializer_list data, Dtype dtype = TypeToDtype()); + explicit array(std::initializer_list data, Dtype dtype = TypeToDtype()); /* Special case so empty lists default to float32. */ - array(std::initializer_list data); + explicit array(std::initializer_list data); /* Special case so array({}, type) is an empty array. */ - array(std::initializer_list data, Dtype dtype); + explicit array(std::initializer_list data, Dtype dtype); template - array( + explicit array( std::initializer_list data, Shape shape, Dtype dtype = TypeToDtype()); /* Build an array from a buffer */ - array( + explicit array( allocator::Buffer data, Shape shape, Dtype dtype, diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 8c831a9b5..9ea015cd5 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -29,6 +29,35 @@ void reshape(const array& in, array& out) { } } +int64_t compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes) { + auto compute_offset = [&strides, &axes](const auto* indices) { + int64_t offset = 0; + for (int i = 0; i < axes.size(); ++i) { + offset += indices[i] * strides[axes[i]]; + } + return offset; + }; + switch (indices.dtype()) { + case int8: + case uint8: + return compute_offset(indices.data()); + case int16: + case uint16: + return compute_offset(indices.data()); + case int32: + case uint32: + return compute_offset(indices.data()); + case int64: + case uint64: + return compute_offset(indices.data()); + default: + throw std::runtime_error("Invalid indices type."); + } +} + void Abs::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; @@ -519,6 +548,54 @@ void Slice::eval(const std::vector& inputs, array& out) { shared_buffer_slice(in, ostrides, data_offset, data_size, out); } +void DynamicSlice::eval_cpu(const std::vector& inputs, array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + auto& in = inputs[0]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto i_offset = compute_dynamic_offset(inputs[1], in.strides(), axes_); + copy_inplace( + /* const array& src = */ in, + /* array& dst = */ out, + /* const Shape& data_shape = */ out.shape(), + /* const Strides& i_strides = */ in.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ i_offset, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral); +} + +void DynamicSliceUpdate::eval_cpu( + const std::vector& inputs, + array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + + // Copy or move src to dst + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype); + + auto o_offset = compute_dynamic_offset(inputs[2], out.strides(), axes_); + copy_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const std::vector& data_shape = */ upd.shape(), + /* const std::vector& i_strides = */ upd.strides(), + /* const std::vector& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ o_offset, + /* CopyType ctype = */ CopyType::GeneralGeneral); +} + void SliceUpdate::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { @@ -544,12 +621,11 @@ void SliceUpdate::eval(const std::vector& inputs, array& out) { auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); // Do copy - Strides upd_strides{upd.strides().begin(), upd.strides().end()}; copy_inplace( /* const array& src = */ upd, /* array& dst = */ out, /* const std::vector& data_shape = */ upd.shape(), - /* const std::vector& i_strides = */ upd_strides, + /* const std::vector& i_strides = */ upd.strides(), /* const std::vector& o_strides = */ out_strides, /* int64_t i_offset = */ 0, /* int64_t o_offset = */ data_offset, diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index f808c52ed..16d2db362 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -52,7 +52,9 @@ void copy_gpu_inplace( int64_t inp_offset, int64_t out_offset, CopyType ctype, - const Stream& s) { + const Stream& s, + const std::optional& dynamic_i_offset /* = std::nullopt */, + const std::optional& dynamic_o_offset /* = std::nullopt */) { if (out.size() == 0) { return; } @@ -80,6 +82,7 @@ void copy_gpu_inplace( } else { large = out.data_size() > UINT32_MAX; } + bool dynamic = dynamic_i_offset || dynamic_o_offset; auto& d = metal::device(s.device); int work_per_thread = 1; std::string kernel_name; @@ -107,9 +110,17 @@ void copy_gpu_inplace( if (large) { kernel_name += "large"; } + if (dynamic) { + kernel_name += "_dynamic"; + if (ctype != CopyType::GeneralGeneral) { + throw std::runtime_error( + "[Copy::eval_gpu] Dynamic output offset requires GeneralGeneral copy"); + } + } } concatenate(kernel_name, "_copy", type_to_name(in), type_to_name(out)); - auto kernel = get_copy_kernel(d, kernel_name, in, out); + auto kernel = dynamic ? get_dynamic_copy_kernel(d, kernel_name, in, out) + : get_copy_kernel(d, kernel_name, in, out); auto& compute_encoder = d.get_command_encoder(s.index); compute_encoder.set_compute_pipeline_state(kernel); @@ -145,6 +156,18 @@ void copy_gpu_inplace( compute_encoder.set_bytes(ndim, 5); dim0 = (dim0 + work_per_thread - 1) / work_per_thread; } + if (dynamic) { + if (dynamic_i_offset) { + compute_encoder.set_input_array(*dynamic_i_offset, 6); + } else { + compute_encoder.set_bytes(0ll, 6); + } + if (dynamic_o_offset) { + compute_encoder.set_input_array(*dynamic_o_offset, 7); + } else { + compute_encoder.set_bytes(0ll, 7); + } + } // NB assuming thread_group_size is a power of 2 larger than 32 x 32 if (thread_group_size != 1024) { @@ -179,13 +202,13 @@ void copy_gpu_inplace( void copy_gpu_inplace( const array& in, array& out, - const Strides& istride, - int64_t ioffset, + const Strides& i_strides, + int64_t i_offset, CopyType ctype, const Stream& s) { assert(in.shape() == out.shape()); return copy_gpu_inplace( - in, out, in.shape(), istride, out.strides(), ioffset, 0, ctype, s); + in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); } void fill_gpu(const array& val, array& out, const Stream& s) { diff --git a/mlx/backend/metal/copy.h b/mlx/backend/metal/copy.h index 2568f9afa..37c60df42 100644 --- a/mlx/backend/metal/copy.h +++ b/mlx/backend/metal/copy.h @@ -17,13 +17,15 @@ void copy_gpu_inplace( int64_t i_offset, int64_t o_offset, CopyType ctype, - const Stream& s); + const Stream& s, + const std::optional& dynamic_i_offset = std::nullopt, + const std::optional& dynamic_o_offset = std::nullopt); void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); void copy_gpu(const array& src, array& out, CopyType ctype); void copy_gpu_inplace( - const array& src, + const array& in, array& out, CopyType ctype, const Stream& s); @@ -31,8 +33,8 @@ void copy_gpu_inplace( void copy_gpu_inplace( const array& in, array& out, - const Strides& istride, - int64_t ioffset, + const Strides& i_strides, + int64_t i_offset, CopyType ctype, const Stream& s); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index f364cb8ee..78560bb2a 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -218,6 +218,38 @@ MTL::ComputePipelineState* get_copy_kernel( return d.get_kernel(kernel_name, lib); } +MTL::ComputePipelineState* get_dynamic_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out) { + std::string lib_name = kernel_name.substr(kernel_name.find("_") + 1); + auto lib = d.get_library(lib_name, [&]() { + std::string kernel_source = metal::utils(); + kernel_source += metal::copy(); + auto in_type = get_type_string(in.dtype()); + auto out_type = get_type_string(out.dtype()); + kernel_source += get_template_definition( + "gg1_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type, "int"); + kernel_source += get_template_definition( + "gg2_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type, "int"); + kernel_source += get_template_definition( + "gg3_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type, "int"); + kernel_source += get_template_definition( + "ggn2_" + lib_name, "copy_gg_dynamic", in_type, out_type, 2, "int"); + kernel_source += get_template_definition( + "gg1large_" + lib_name, "copy_gg_dynamic_nd1", in_type, out_type); + kernel_source += get_template_definition( + "gg2large_" + lib_name, "copy_gg_dynamic_nd2", in_type, out_type); + kernel_source += get_template_definition( + "gg3large_" + lib_name, "copy_gg_dynamic_nd3", in_type, out_type); + kernel_source += get_template_definition( + "ggn4large_" + lib_name, "copy_gg_dynamic", in_type, out_type, 4); + return kernel_source; + }); + return d.get_kernel(kernel_name, lib); +} + MTL::ComputePipelineState* get_softmax_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index 20e3bd907..dd6213754 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -45,6 +45,12 @@ MTL::ComputePipelineState* get_copy_kernel( const array& in, const array& out); +MTL::ComputePipelineState* get_dynamic_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array& in, + const array& out); + MTL::ComputePipelineState* get_softmax_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/kernels/copy.h b/mlx/backend/metal/kernels/copy.h index dddcda366..b1367cf4f 100644 --- a/mlx/backend/metal/kernels/copy.h +++ b/mlx/backend/metal/kernels/copy.h @@ -161,3 +161,78 @@ template idx.y += dst_xstride; } } + +template +[[kernel]] void copy_gg_dynamic_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint3 index [[thread_position_in_grid]]) { + src += src_offset; + dst += dst_offset; + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = src[idx.x]; + return; + } + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = src[idx.x]; + idx.x += src_xstride; + idx.y += dst_xstride; + } +} diff --git a/mlx/backend/metal/kernels/copy.metal b/mlx/backend/metal/kernels/copy.metal index 68e6dcec6..bbf268158 100644 --- a/mlx/backend/metal/kernels/copy.metal +++ b/mlx/backend/metal/kernels/copy.metal @@ -4,29 +4,37 @@ #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/copy.h" -#define instantiate_copy_all(tname, itype, otype) \ - instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ - instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ - instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ - instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ - instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ - instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \ - instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \ - instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ - instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \ - instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \ - instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) +#define instantiate_copy_all(tname, itype, otype) \ + instantiate_kernel("s_copy" #tname, copy_s, itype, otype) \ + instantiate_kernel("v_copy" #tname, copy_v, itype, otype) \ + instantiate_kernel("s2_copy" #tname, copy_s2, itype, otype) \ + instantiate_kernel("v2_copy" #tname, copy_v2, itype, otype) \ + instantiate_kernel("g1_copy" #tname, copy_g_nd1, itype, otype, int) \ + instantiate_kernel("g2_copy" #tname, copy_g_nd2, itype, otype, int) \ + instantiate_kernel("g3_copy" #tname, copy_g_nd3, itype, otype, int) \ + instantiate_kernel("gn2_copy" #tname, copy_g, itype, otype, 2, int) \ + instantiate_kernel("g1large_copy" #tname, copy_g_nd1, itype, otype) \ + instantiate_kernel("g2large_copy" #tname, copy_g_nd2, itype, otype) \ + instantiate_kernel("g3large_copy" #tname, copy_g_nd3, itype, otype) \ + instantiate_kernel("gn4large_copy" #tname, copy_g, itype, otype, 4) -#define instantiate_copy_same(tname, type) \ - instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ - instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ - instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \ - instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \ - instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \ - instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \ - instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \ - instantiate_kernel("gn4large_copy" #tname, copy_g, type, type, 4) \ - instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) +#define instantiate_copy_same(tname, type) \ + instantiate_kernel("gg1_copy" #tname, copy_gg_nd1, type, type, int) \ + instantiate_kernel("gg2_copy" #tname, copy_gg_nd2, type, type, int) \ + instantiate_kernel("gg3_copy" #tname, copy_gg_nd3, type, type, int) \ + instantiate_kernel("ggn2_copy" #tname, copy_gg, type, type, 2, int) \ + instantiate_kernel("gg1large_copy" #tname, copy_gg_nd1, type, type) \ + instantiate_kernel("gg2large_copy" #tname, copy_gg_nd2, type, type) \ + instantiate_kernel("gg3large_copy" #tname, copy_gg_nd3, type, type) \ + instantiate_kernel("ggn4large_copy" #tname, copy_gg, type, type, 4) \ + instantiate_kernel("gg1_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type, int) \ + instantiate_kernel("gg2_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type, int) \ + instantiate_kernel("gg3_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type, int) \ + instantiate_kernel("ggn2_dynamic_copy" #tname, copy_gg_dynamic, type, type, 2, int) \ + instantiate_kernel("gg1large_dynamic_copy" #tname, copy_gg_dynamic_nd1, type, type) \ + instantiate_kernel("gg2large_dynamic_copy" #tname, copy_gg_dynamic_nd2, type, type) \ + instantiate_kernel("gg3large_dynamic_copy" #tname, copy_gg_dynamic_nd3, type, type) \ + instantiate_kernel("ggn4large_dynamic_copy" #tname, copy_gg_dynamic, type, type, 4) #define instantiate_copy_itype(itname, itype) \ instantiate_copy_same(itname ##itname, itype) \ diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 03b31197b..ff561374d 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -56,6 +56,14 @@ MTL::ComputePipelineState* get_copy_kernel( return d.get_kernel(kernel_name); } +MTL::ComputePipelineState* get_dynamic_copy_kernel( + metal::Device& d, + const std::string& kernel_name, + const array&, + const array&) { + return d.get_kernel(kernel_name); +} + MTL::ComputePipelineState* get_softmax_kernel( metal::Device& d, const std::string& kernel_name, diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 4f1518d9a..f796af805 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -4,6 +4,7 @@ #include #include +#include "mlx/backend/common/compiled.h" #include "mlx/backend/common/load.h" #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" @@ -43,6 +44,59 @@ void reshape(const array& in, array& out, Stream s) { shared_buffer_reshape(in, out_strides, out); } } +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + Stream s) { + auto& d = metal::device(s.device); + + // Kernel to compute offset here. + array offset({1}, int64, nullptr, {}); + bool donate = indices.is_donatable() && + (indices.data_size() * indices.itemsize()) >= offset.itemsize(); + if (donate) { + offset.move_shared_buffer(indices); + } else { + offset.set_data(allocator::malloc_or_wait(offset.itemsize())); + } + d.add_temporary(offset, s.index); + + auto dtype = indices.dtype(); + std::string lib_name = "compute_dynamic_offset_" + type_to_name(dtype); + auto lib = d.get_library(lib_name, [dtype]() { + return fmt::format( + R"( + [[kernel]] void compute_dynamic_offset_{0}( + constant const {1}* indices [[buffer(0)]], + device int64_t& offset [[buffer(1)]], + constant const int64_t* strides [[buffer(2)]], + constant const int* axes [[buffer(3)]], + constant const int& n_axes [[buffer(4)]], + uint index [[thread_position_in_grid]]) {{ + int64_t acc = 0; + for (int i = 0; i < n_axes; ++i) {{ + acc += indices[i] * strides[axes[i]]; + }} + offset = acc; + }})", + type_to_name(dtype), + get_type_string(dtype)); + }); + auto kernel = d.get_kernel(lib_name, lib); + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(donate ? offset : indices, 0); + compute_encoder.set_output_array(offset, 1); + compute_encoder.set_vector_bytes(strides, 2); + compute_encoder.set_vector_bytes(axes, 3); + int n_axes = axes.size(); + compute_encoder.set_bytes(n_axes, 4); + MTL::Size dims = MTL::Size(1, 1, 1); + compute_encoder.dispatch_threads(dims, dims); + return offset; +} void Arange::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 0); @@ -356,6 +410,72 @@ void Slice::eval_gpu(const std::vector& inputs, array& out) { slice_gpu(in, out, start_indices_, strides_, stream()); } +void DynamicSlice::eval_gpu(const std::vector& inputs, array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& start = inputs[1]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto s = stream(); + auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s); + copy_gpu_inplace( + /* const array& src = */ in, + /* array& dst = */ out, + /* const Shape& data_shape = */ out.shape(), + /* const Strides& i_strides = */ in.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ s, + /* const std::optional& dynamic_i_offset = */ in_offset, + /* const std::optional& dynamic_o_offset = */ std::nullopt); +} + +void DynamicSliceUpdate::eval_gpu( + const std::vector& inputs, + array& out) { + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + + auto& in = inputs[0]; + auto& upd = inputs[1]; + auto& start_indices = inputs[2]; + + if (upd.size() == 0) { + move_or_copy(in, out); + return; + } + + // Copy or donate input to output + auto s = stream(); + auto& d = metal::device(s.device); + auto ctype = in.flags().contiguous && in.size() == in.data_size() + ? CopyType::Vector + : CopyType::General; + copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, s); + + auto out_offset = + compute_dynamic_offset(start_indices, out.strides(), axes_, s); + copy_gpu_inplace( + /* const array& src = */ upd, + /* array& dst = */ out, + /* const Shape& data_shape = */ upd.shape(), + /* const Strides& i_strides = */ upd.strides(), + /* const Strides& o_strides = */ out.strides(), + /* int64_t i_offset = */ 0, + /* int64_t o_offset = */ 0, + /* CopyType ctype = */ CopyType::GeneralGeneral, + /* const Stream& s = */ s, + /* const std::optional& dynamic_i_offset = */ std::nullopt, + /* const std::optional& dynamic_o_offset = */ out_offset); +} + void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 2); if (out.size() == 0) { @@ -371,13 +491,11 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { return; } - // Check if materialization is needed auto ctype = in.flags().contiguous && in.size() == in.data_size() ? CopyType::Vector : CopyType::General; copy_gpu(in, out, in.data_size() == 1 ? CopyType::Scalar : ctype, stream()); - // Calculate out strides, initial offset and if copy needs to be made auto [data_offset, out_strides] = prepare_slice(in, start_indices_, strides_); // Do copy diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 9db4d3983..751a23418 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -48,6 +48,8 @@ NO_CPU_MULTI(CustomTransforms) NO_CPU_MULTI(Depends) NO_CPU(Divide) NO_CPU_MULTI(DivMod) +NO_CPU(DynamicSlice) +NO_CPU(DynamicSliceUpdate) NO_CPU(NumberOfElements) NO_CPU(Remainder) NO_CPU_MULTI(Eigh) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index f7a34c8e6..b864fc7f2 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -49,6 +49,8 @@ NO_GPU_MULTI(CustomTransforms) NO_GPU_MULTI(Depends) NO_GPU(Divide) NO_GPU_MULTI(DivMod) +NO_GPU(DynamicSlice) +NO_GPU(DynamicSliceUpdate) NO_GPU(NumberOfElements) NO_GPU(Remainder) NO_GPU(Equal) diff --git a/mlx/export.cpp b/mlx/export.cpp index ae826e61c..3ab6213cf 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -253,6 +253,8 @@ struct PrimitiveFactory { SERIALIZE_PRIMITIVE(Depends), SERIALIZE_PRIMITIVE(Divide), SERIALIZE_PRIMITIVE(DivMod), + SERIALIZE_PRIMITIVE(DynamicSlice), + SERIALIZE_PRIMITIVE(DynamicSliceUpdate), SERIALIZE_PRIMITIVE(Equal, "NaNEqual"), SERIALIZE_PRIMITIVE(Erf), SERIALIZE_PRIMITIVE(ErfInv), diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ad0a64697..4362b1d08 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -647,6 +647,52 @@ normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) { return std::make_pair(has_neg_strides, out_shape); } +void normalize_dynamic_slice_inputs( + const array& a, + const array& start, + std::vector& axes, + const std::string prefix) { + if (start.size() > a.ndim()) { + std::ostringstream msg; + msg << prefix << " Invalid number of starting positions for " + << "array with dimension " << a.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + if (start.ndim() > 1) { + std::ostringstream msg; + msg << prefix << " array of starting indices " + << "must be zero or one dimensional but has dimension " << start.ndim() + << "."; + throw std::invalid_argument(msg.str()); + } + if (start.size() != axes.size()) { + std::ostringstream msg; + msg << prefix << " Number of starting indices " << start.size() + << " does not match number of axes " << axes.size() << "."; + throw std::invalid_argument(msg.str()); + } + if (!issubdtype(start.dtype(), integer)) { + std::ostringstream msg; + msg << prefix << " Start indices must be integers, got type " + << start.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + for (auto& ax : axes) { + auto new_ax = ax < 0 ? ax + a.ndim() : ax; + if (new_ax < 0 || new_ax >= a.ndim()) { + std::ostringstream msg; + msg << prefix << " Invalid axis " << ax << " for array with dimension " + << a.ndim() << "."; + throw std::invalid_argument(msg.str()); + } + ax = new_ax; + } + std::set dims(axes.begin(), axes.end()); + if (dims.size() != axes.size()) { + throw std::invalid_argument(prefix + " Repeat axes not allowed."); + } +} + } // namespace array slice( @@ -687,6 +733,38 @@ array slice( a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s)); } +array slice( + const array& a, + const array& start, + std::vector axes, + Shape slice_size, + StreamOrDevice s /* = {} */) { + normalize_dynamic_slice_inputs(a, start, axes, "[slice]"); + + // Check the slice_size + if (slice_size.size() != a.ndim()) { + std::ostringstream msg; + msg << "[slice] Invalid slice size for array with " << a.ndim() + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + for (int i = 0; i < a.ndim(); ++i) { + if (slice_size[i] > a.shape(i)) { + std::ostringstream msg; + msg << "[slice] Invalid slice size " << slice_size + << " for array with shape " << a.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + auto out_shape = slice_size; + return array( + std::move(out_shape), + a.dtype(), + std::make_shared( + to_stream(s), std::move(axes), std::move(slice_size)), + {a, start}); +} + /** Update a slice from the source array */ array slice_update( const array& src, @@ -699,7 +777,7 @@ array slice_update( if (start.size() != src.ndim() || stop.size() != src.ndim() || strides.size() != src.ndim()) { std::ostringstream msg; - msg << "[slice] Invalid number of indices or strides for " + msg << "[slice_update] Invalid number of indices or strides for " << "array with dimension " << src.ndim() << "."; throw std::invalid_argument(msg.str()); } @@ -734,6 +812,36 @@ array slice_update( src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s); } +/** Update a slice from the source array */ +array slice_update( + const array& src, + const array& update, + const array& start, + std::vector axes, + StreamOrDevice s /* = {} */) { + normalize_dynamic_slice_inputs(src, start, axes, "[slice_update]"); + + // Broadcast update with unspecified axes + auto up_shape = update.shape(); + auto dim_diff = std::max(src.ndim() - update.ndim(), 0ul); + up_shape.insert( + up_shape.begin(), src.shape().begin(), src.shape().begin() + dim_diff); + for (int d = dim_diff; d < src.ndim(); ++d) { + up_shape[d] = std::min(up_shape[d], src.shape(d)); + } + for (auto ax : axes) { + if (ax < dim_diff) { + up_shape[ax] = 1; + } + } + auto upd = broadcast_to(astype(update, src.dtype(), s), up_shape, s); + return array( + src.shape(), + src.dtype(), + std::make_shared(to_stream(s), std::move(axes)), + {src, upd, start}); +} + std::vector split( const array& a, const Shape& indices, diff --git a/mlx/ops.h b/mlx/ops.h index aadd187cb..26bff5ec8 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -164,11 +164,27 @@ array slice( Shape stop, Shape strides, StreamOrDevice s = {}); +inline array slice( + const array& a, + std::initializer_list start, + Shape stop, + Shape strides, + StreamOrDevice s = {}) { + return slice(a, Shape(start), std::move(stop), std::move(strides), s); +} /** Slice an array with a stride of 1 in each dimension. */ array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {}); -/** Update a slice from the source array */ +/** Slice an array with dynamic starting indices. */ +array slice( + const array& a, + const array& start, + std::vector axes, + Shape slice_size, + StreamOrDevice s = {}); + +/** Update a slice from the source array. */ array slice_update( const array& src, const array& update, @@ -177,7 +193,7 @@ array slice_update( Shape strides, StreamOrDevice s = {}); -/** Update a slice from the source array with stride 1 in each dimension */ +/** Update a slice from the source array with stride 1 in each dimension. */ array slice_update( const array& src, const array& update, @@ -185,6 +201,14 @@ array slice_update( Shape stop, StreamOrDevice s = {}); +/** Update a slice from the source array with dynamic starting indices. */ +array slice_update( + const array& src, + const array& update, + const array& start, + std::vector axes, + StreamOrDevice s = {}); + /** Split an array into sub-arrays along a given axis. */ std::vector split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index f68c69b2b..5b7bf238b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3715,22 +3715,18 @@ std::vector SliceUpdate::vjp( for (int num : argnums) { // Vjp for source if (num == 0) { - auto grad = slice_update( + vjps.push_back(slice_update( cotan, zeros_like(upd, stream()), start_indices_, end_indices_, strides_, - stream()); - - vjps.push_back(grad); + stream())); } // Vjp fpr updates else { - auto grad = - slice(cotan, start_indices_, end_indices_, strides_, stream()); - - vjps.push_back(grad); + vjps.push_back( + slice(cotan, start_indices_, end_indices_, strides_, stream())); } } @@ -3753,12 +3749,153 @@ std::vector SliceUpdate::jvp( } bool SliceUpdate::is_equivalent(const Primitive& other) const { - const SliceUpdate& s_other = static_cast(other); + const auto& s_other = static_cast(other); return ( start_indices_ == s_other.start_indices_ && end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_); } +std::pair, std::vector> DynamicSlice::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto& in = inputs[0]; + auto& start = inputs[1]; + auto vax = axes[0]; + if (axes[1] >= 0) { + throw std::invalid_argument( + "[DynamicSlice::vmap] vmap over start indices not yet supported."); + } + auto slice_size = slice_size_; + auto slice_axes = axes_; + if (vax >= 0) { + for (auto& ax : slice_axes) { + if (ax >= vax) { + ax++; + } + } + slice_size.insert(slice_size.begin() + vax, in.shape(vax)); + } + return { + {slice( + in, start, std::move(slice_axes), std::move(slice_size), stream())}, + {vax}}; +} + +std::vector DynamicSlice::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + if (argnums[0] == 1 || argnums.size() > 1) { + throw std::invalid_argument( + "[DynamicSlice::vjp] Not supported for start indices."); + } + auto out = zeros_like(primals[0], stream()); + return {slice_update(out, cotangents[0], primals[1], axes_, stream())}; +} + +std::vector DynamicSlice::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector&) { + return {slice(tangents[0], primals[1], axes_, slice_size_, stream())}; +} + +bool DynamicSlice::is_equivalent(const Primitive& other) const { + const auto& s_other = static_cast(other); + return (axes_ == s_other.axes_ && slice_size_ == s_other.slice_size_); +} + +std::vector DynamicSlice::output_shapes(const std::vector&) { + return {slice_size_}; +} + +std::pair, std::vector> DynamicSliceUpdate::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto src = inputs[0]; + auto upd = inputs[1]; + auto& start = inputs[2]; + auto src_ax = axes[0]; + auto upd_ax = axes[1]; + if (axes[2] >= 0) { + throw std::runtime_error( + "[DynamicSliceUpdate::vmap] vmap over start indices not yet supported."); + } + // No vmapping needed + if (src_ax == -1 && upd_ax == -1) { + return {{slice_update(src, upd, start, axes_, stream())}, {-1}}; + } + + // Broadcast src + if (src_ax == -1) { + src = expand_dims(src, upd_ax, stream()); + auto shape = src.shape(); + shape[upd_ax] = upd.shape(upd_ax); + src = broadcast_to(src, shape, stream()); + src_ax = upd_ax; + } + + // Broadcast upd + if (upd_ax == -1) { + upd = expand_dims(upd, src_ax, stream()); + upd_ax = src_ax; + } + + if (src_ax != upd_ax) { + upd = moveaxis(upd, upd_ax, src_ax, stream()); + } + + auto slice_axes = axes_; + for (auto& ax : slice_axes) { + if (ax >= src_ax) { + ax++; + } + } + return { + {slice_update(src, upd, start, std::move(slice_axes), stream())}, + {src_ax}}; +} + +std::vector DynamicSliceUpdate::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + auto& cotan = cotangents[0]; + auto& upd = primals[1]; + auto& start = primals[2]; + + std::vector vjps; + + for (int num : argnums) { + if (num == 0) { + // Vjp for source + vjps.push_back(slice_update( + cotan, zeros_like(upd, stream()), start, axes_, stream())); + } else if (num == 1) { + // Vjp fpr updates + vjps.push_back(slice(cotan, start, axes_, upd.shape(), stream())); + } else { + throw std::invalid_argument( + "[DynamicSliceUpdate::vjp] Not supported for start indices"); + } + } + return vjps; +} + +std::vector DynamicSliceUpdate::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector&) { + return {slice_update(tangents[0], tangents[1], primals[2], axes_, stream())}; +} + +bool DynamicSliceUpdate::is_equivalent(const Primitive& other) const { + const auto& s_other = static_cast(other); + return axes_ == s_other.axes_; +} + std::pair, std::vector> Softmax::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 74e50d4fb..97f0c2021 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2057,6 +2057,51 @@ class SliceUpdate : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class DynamicSlice : public UnaryPrimitive { + public: + explicit DynamicSlice(Stream stream, std::vector axes, Shape slice_size) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + slice_size_(std::move(slice_size)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(DynamicSlice) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_pair(axes_, slice_size_); + } + + private: + std::vector axes_; + Shape slice_size_; +}; + +class DynamicSliceUpdate : public UnaryPrimitive { + public: + explicit DynamicSliceUpdate(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(DynamicSliceUpdate) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return axes_; + } + + private: + std::vector axes_; +}; + class Softmax : public UnaryPrimitive { public: explicit Softmax(Stream stream, bool precise) diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 9770f529e..95c6e2a18 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -764,7 +764,7 @@ auto mlx_slice_update( const mx::array& src, const nb::object& obj, const ScalarOrArray& v) { - // Can't route to slice update if not slice or tuple + // Can't route to slice update if not slice, tuple, or int if (src.ndim() == 0 || (!nb::isinstance(obj) && !nb::isinstance(obj) && !nb::isinstance(obj))) { @@ -845,20 +845,14 @@ auto mlx_slice_update( return std::make_pair(true, broadcast_to(up, src.shape())); } - // Process entries - mx::Shape up_reshape(src.ndim()); - int ax = src.ndim() - 1; - int up_ax = up.ndim() - 1; - for (; ax >= non_none_indices; ax--) { - if (up_ax >= 0) { - up_reshape[ax] = up.shape(up_ax); - up_ax--; - } else { - up_reshape[ax] = 1; - } - } - - for (int i = indices.size() - 1; i >= 0; --i) { + int unspecified = src.ndim() - non_none_indices; + std::vector squeeze_dims; + std::vector expand_dims; + for (int i = indices.size() - 1, + ax = non_none_indices - 1, + upd_ax = upd.ndim() - unspecified - 1; + i >= 0; + --i) { auto& pyidx = indices[i]; if (nb::isinstance(pyidx)) { get_slice_params( @@ -867,19 +861,26 @@ auto mlx_slice_update( strides[ax], nb::cast(pyidx), src.shape(ax)); - up_reshape[ax] = (up_ax >= 0) ? up.shape(up_ax--) : 1; ax--; + upd_ax--; } else if (nb::isinstance(pyidx)) { int st = nb::cast(pyidx); - st = (st < 0) ? st + src.shape(ax) : st; + st = (st < 0) ? st + src.shape(i) : st; starts[ax] = st; stops[ax] = st + 1; - up_reshape[ax] = 1; + if (upd_ax >= 0) { + expand_dims.push_back(i - indices.size() - unspecified); + } ax--; + } else if (pyidx.is_none()) { + if (upd_ax-- >= 0) { + squeeze_dims.push_back(i - indices.size() - unspecified); + } } } - up = reshape(up, std::move(up_reshape)); + up = mx::squeeze( + mx::expand_dims(up, std::move(expand_dims)), std::move(squeeze_dims)); auto out = slice_update(src, up, starts, stops, strides); return std::make_pair(true, out); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index e0315a437..0147cd709 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5004,4 +5004,81 @@ void init_ops(nb::module_& m) { Returns: array: The imaginary part of ``a``. )pbdoc"); + m.def( + "slice", + [](const mx::array& a, + const mx::array& start_indices, + std::vector axes, + mx::Shape slice_size, + mx::StreamOrDevice s) { + return mx::slice( + a, start_indices, std::move(axes), std::move(slice_size), s); + }, + nb::arg(), + "start_indices"_a, + "axes"_a, + "slice_size"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Extract a sub-array from the input array. + + Args: + a (array): Input array + start_indices (array): The index location to start the slice at. + axes (tuple(int)): The axes corresponding to the indices in ``start_indices``. + slice_size (tuple(int)): The size of the slice. + + Returns: + array: The sliced output array. + + Example: + + >>> a = mx.array([[1, 2, 3], [4, 5, 6]]) + >>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2)) + array([[4, 5]], dtype=int32) + >>> + >>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1)) + array([[2], + [5]], dtype=int32) + )pbdoc"); + m.def( + "slice_update", + [](const mx::array& src, + const mx::array& update, + const mx::array& start_indices, + std::vector axes, + mx::StreamOrDevice s) { + return mx::slice_update(src, update, start_indices, axes, s); + }, + nb::arg(), + "update"_a, + "start_indices"_a, + "axes"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Update a sub-array of the input array. + + Args: + a (array): The input array to update + update (array): The update array. + start_indices (array): The index location to start the slice at. + axes (tuple(int)): The axes corresponding to the indices in ``start_indices``. + + Returns: + array: The output array with the same shape and type as the input. + + Example: + + >>> a = mx.zeros((3, 3)) + >>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array(1, 1), axes=(0, 1)) + array([[0, 0, 0], + [0, 1, 0], + [0, 1, 0]], dtype=float32) + )pbdoc"); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 1f4515b6b..86c061289 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1327,6 +1327,16 @@ class TestArray(mlx_tests.MLXTestCase): x[0, 0] = 1 self.assertTrue(mx.array_equal(x[0, 0], mx.ones((2, 2, 2, 2)))) + a = mx.zeros((2, 2, 2)) + with self.assertRaises(ValueError): + a[:, None, :] = mx.ones((2, 2, 2)) + + # Ok, doesn't throw + a[:, None, :] = mx.ones((2, 1, 2, 2)) + a[:, None, :] = mx.ones((2, 2)) + a[:, None, 0] = mx.ones((2,)) + a[:, None, 0] = mx.ones((1, 2)) + def test_array_at(self): a = mx.array(1) a = a.at[None].add(1) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 12be3d5fd..9e72a83c7 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2769,6 +2769,19 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(mx.imag(z).dtype, mx.float32) self.assertTrue(mx.array_equal(mx.imag(z), y)) + def test_dynamic_slicing(self): + x = mx.random.randint(0, 100, shape=(4, 4, 4)) + expected = x[1:, 2:, 3:] + out = mx.slice(x, mx.array([1, 2, 3]), (0, 1, 2), (3, 2, 1)) + self.assertTrue(mx.array_equal(expected, out)) + + x = mx.zeros(shape=(4, 4, 4)) + update = mx.random.randint(0, 100, shape=(3, 2, 1)) + out = mx.slice_update(x, update, mx.array([1, 2, 3]), (0, 1, 2)) + expected = mx.zeros_like(x) + expected[1:, 2:, 3:] = update + self.assertTrue(mx.array_equal(expected, out)) + if __name__ == "__main__": unittest.main() diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index e87b0ca06..c992c3c6d 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1291,3 +1291,24 @@ TEST_CASE("test grad types") { } } } + +TEST_CASE("test grad dynamic slices") { + { + auto fn = [](const array& x) { return slice(x, array({0}), {0}, {1, 2}); }; + auto x = array({1, 2, 3, 4}, {2, 2}); + auto out = vjp(fn, x, array({1, 1}, {1, 2})).second; + CHECK(array_equal(out, array({1, 1, 0, 0}, {2, 2})).item()); + } + { + auto fn = [](const std::vector& inputs) { + const auto& x = inputs[0]; + const auto& update = inputs[1]; + return std::vector{slice_update(x, update, array({0}), {0})}; + }; + auto x = zeros({2, 2}); + auto update = array({3.f, 4.f}, {1, 2}); + auto outs = vjp(fn, {x, update}, {ones({2, 2})}).second; + CHECK(allclose(outs[0], array({0.f, 0.f, 1.f, 1.f}, {2, 2})).item()); + CHECK(allclose(outs[1], ones({1, 2})).item()); + } +} diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index c5a1b8808..352d16ff2 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -250,7 +250,7 @@ TEST_CASE("test QR factorization") { // Unsupported types throw CHECK_THROWS(linalg::qr(array({0, 1}, {1, 2}))); - array A = array({{2., 3., 1., 2.}, {2, 2}}); + array A = array({2., 3., 1., 2.}, {2, 2}); auto [Q, R] = linalg::qr(A, Device::cpu); auto out = matmul(Q, R); CHECK(allclose(out, A).item()); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 3d038cd30..b8bc8b45e 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -353,6 +353,50 @@ TEST_CASE("test slice update") { CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item()); } +TEST_CASE("test dynamic slice") { + auto src = reshape(arange(6), {2, 3}); + CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1})); + CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1})); + CHECK_THROWS(slice(src, array({1}), {3}, {1, 1})); + CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1})); + + CHECK_THROWS(slice(src, array({1}), {0}, {2, 4})); + CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1})); + + auto out = slice(src, array({1}), {0}, {1, 2}); + auto expected = array({3, 4}, {1, 2}); + CHECK(array_equal(out, expected).item()); + + out = slice(src, array({1, 1}), {0, 1}, {1, 2}); + expected = array({4, 5}, {1, 2}); + CHECK(array_equal(out, expected).item()); +} + +TEST_CASE("test dynamic slice update") { + auto src = zeros({2, 3}, int32); + auto upd = ones({1, 2}, int32); + CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0})); + CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0})); + CHECK_THROWS(slice_update(src, upd, array({1}), {3})); + CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0})); + + upd = ones({4}, int32); + CHECK_THROWS(slice_update(src, upd, array({1}), {0})); + upd = ones({1, 4}, int32); + CHECK_THROWS(slice_update(src, upd, array({1}), {0})); + CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0})); + + upd = ones({1, 2}, int32); + auto out = slice_update(src, upd, array({1}), {0}); + auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3}); + CHECK(array_equal(out, expected).item()); + + upd = ones({1, 2}, int32); + out = slice_update(src, upd, array({1, 1}), {0, 1}); + expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3}); + CHECK(array_equal(out, expected).item()); +} + TEST_CASE("test split") { array x = array(1); CHECK_THROWS(split(x, 0)); @@ -720,7 +764,7 @@ TEST_CASE("test is inf") { CHECK_FALSE(any(isinf(z)).item()); array w = array({1.0f, inf, 2.0f}); - CHECK(array_equal({false, true, false}, isinf(w)).item()); + CHECK(array_equal(array({false, true, false}), isinf(w)).item()); array a(1.0f, bfloat16); CHECK_FALSE(isinf(a).item()); diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp index 1a9e1aa78..9d51c82b1 100644 --- a/tests/random_tests.cpp +++ b/tests/random_tests.cpp @@ -686,7 +686,7 @@ TEST_CASE("test laplace") { CHECK(std::abs(sample_variance - expected_variance) < 0.01); // Expected kurtosis of Laplace distribution is 3. - array fourth_pows = power(out - sample_mean, {4}); + array fourth_pows = power(out - sample_mean, array(4)); float sample_kurtosis = mean(fourth_pows).item() / std::pow(sample_variance, 2) - 3; float expected_kurtosis = 3.0; diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 88aac6991..38011b942 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -496,3 +496,33 @@ TEST_CASE("test vmap SVD") { CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)}); } } + +TEST_CASE("test vmap dynamic slices") { + { + auto fun = [](std::vector inputs) { + return std::vector{slice(inputs[0], array({1}), {0}, {2})}; + }; + auto x = reshape(arange(12), {3, 4}); + auto out = vmap(fun)({x})[0]; + CHECK(array_equal(out, array({1, 2, 5, 6, 9, 10}, {3, 2})).item()); + + out = vmap(fun, /* in_axes */ {1}, /* out_axes */ {1})({x})[0]; + CHECK(array_equal(out, array({4, 5, 6, 7, 8, 9, 10, 11}, {2, 4})) + .item()); + } + + { + auto fun = [](std::vector inputs) { + return std::vector{ + slice_update(inputs[0], inputs[1], array({1}), {0})}; + }; + auto x = zeros({2, 2}); + auto upd = ones({2, 1}); + + auto out = vmap(fun)({x, upd})[0]; + CHECK(array_equal(out, array({0, 1, 0, 1}, {2, 2})).item()); + + out = vmap(fun, /* in_axes */ {1, 0}, /* out_axes */ {1})({x, upd})[0]; + CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item()); + } +}