From ebfe64b92da3ed2625ed1a95ac8b7d46616aa576 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 23 Dec 2024 11:25:15 -0800 Subject: [PATCH] shapeless slice update and broadcast when possible (#1727) --- mlx/backend/metal/slicing.cpp | 2 +- mlx/compile.cpp | 39 -------------------- mlx/primitives.cpp | 67 ++++++----------------------------- mlx/primitives.h | 3 ++ python/src/indexing.cpp | 18 ++++++++-- python/tests/test_compile.py | 13 +++++++ 6 files changed, 43 insertions(+), 99 deletions(-) diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp index b37d9c316..3493ea858 100644 --- a/mlx/backend/metal/slicing.cpp +++ b/mlx/backend/metal/slicing.cpp @@ -14,7 +14,7 @@ void slice_gpu( const Shape& start_indices, const Shape& strides, const Stream& s) { - // Calculate out strides, initial offset and if copy needs to be made + // Calculate out strides and initial offset auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides); size_t data_end = 1; diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 38a4b52d0..35b460f2d 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -72,23 +72,6 @@ bool is_fusable(const Primitive& p) { is_noop(p); } -bool allows_shapeless(const Primitive& p) { - return typeid(p) == typeid(Arange) || typeid(p) == typeid(Compiled) || - is_unary(p) || is_binary(p) || is_noop(p) || is_reduction(p) || - typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) || - typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) || - typeid(p) == typeid(Partition) || typeid(p) == typeid(Select) || - typeid(p) == typeid(NumberOfElements) || typeid(p) == typeid(Gather) || - typeid(p) == typeid(Transpose) || typeid(p) == typeid(Concatenate) || - typeid(p) == typeid(Matmul) || typeid(p) == typeid(QuantizedMatmul) || - typeid(p) == typeid(Squeeze) || typeid(p) == typeid(ExpandDims) || - typeid(p) == typeid(Flatten) || typeid(p) == typeid(Unflatten) || - typeid(p) == typeid(fast::AffineQuantize) || - typeid(p) == typeid(fast::LayerNorm) || - typeid(p) == typeid(fast::RMSNorm) || typeid(p) == typeid(fast::RoPE) || - typeid(p) == typeid(fast::ScaledDotProductAttention); -} - Compiled::Compiled( Stream stream, std::vector inputs, @@ -800,24 +783,6 @@ std::vector compile_replace( return outputs; } -void compile_validate_shapeless(const std::vector& tape) { - for (auto& t : tape) { - if (!t.has_primitive()) { - continue; - } - auto& p = t.primitive(); - if (allows_shapeless(p)) { - continue; - } - - std::ostringstream msg; - msg << "[compile] Cannot compile primitive "; - p.print(msg); - msg << " with shapeless enabled."; - throw std::invalid_argument(msg.str()); - } -} - bool skip_compile() { return compile_mode() == CompileMode::disabled || !(compile_available_for_device(default_device())); @@ -877,10 +842,6 @@ std::function(const std::vector&)> compile( if (compile_mode() != CompileMode::no_fuse) { compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs); } - - if (shapeless) { - compile_validate_shapeless(entry.tape); - } } // At this point we must have a tape, now replace the placeholders diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index aa8f16c9f..87af9c654 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -740,6 +740,13 @@ bool Broadcast::is_equivalent(const Primitive& other) const { return shape_ == b_other.shape_; } +std::vector Broadcast::output_shapes(const std::vector& inputs) { + if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) { + throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape"); + } + return {shape_}; +} + std::vector Ceil::vjp( const std::vector& primals, const std::vector& cotangents, @@ -3585,63 +3592,9 @@ std::vector Slice::vjp( const std::vector&) { // Check inputs assert(primals.size() == 1); - - std::vector inds; - std::vector ind_axes; - std::vector single_inds; - std::vector single_ind_axes; - for (int i = 0; i < start_indices_.size(); ++i) { - auto start = start_indices_[i]; - auto end = end_indices_[i]; - auto stride = strides_[i]; - if (start == 0 && stride == 1) { - continue; - } - if (stride == 1) { - single_inds.push_back(array(start)); - single_ind_axes.push_back(i); - } else { - inds.push_back(arange(start, end, stride, stream())); - ind_axes.push_back(i); - } - } - - // Transpose and reshape cotangents - auto cotan = cotangents[0]; - if (!ind_axes.empty()) { - Shape cotan_shape; - for (auto ax : ind_axes) { - cotan_shape.push_back(cotan.shape(ax)); - } - std::vector cotan_axes(ind_axes); - for (int j = 0, i = 0; i < cotan.ndim(); ++i) { - if (j < ind_axes.size() && ind_axes[j] == i) { - cotan_shape.push_back(1); - j++; - } else { - cotan_shape.push_back(cotan.shape(i)); - cotan_axes.push_back(i); - } - } - cotan = - reshape(transpose(cotan, cotan_axes, stream()), cotan_shape, stream()); - } - - // Make indices broadcastable - Shape inds_shape(inds.size(), 1); - for (int i = 0; i < inds.size(); ++i) { - inds_shape[i] = inds[i].size(); - inds[i] = reshape(inds[i], inds_shape, stream()); - inds_shape[i] = 1; - } - - // Concatenate all the indices and axes - inds.insert(inds.end(), single_inds.begin(), single_inds.end()); - ind_axes.insert( - ind_axes.end(), single_ind_axes.begin(), single_ind_axes.end()); - - return {scatter_add( - zeros_like(primals[0], stream()), inds, cotan, ind_axes, stream())}; + auto out = zeros_like(primals[0], stream()); + return {slice_update( + out, cotangents[0], start_indices_, end_indices_, strides_, stream())}; } std::vector Slice::jvp( diff --git a/mlx/primitives.h b/mlx/primitives.h index 88b7a63ed..ec084a9ba 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -533,6 +533,8 @@ class Broadcast : public UnaryPrimitive { DEFINE_PRINT(Broadcast) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + private: Shape shape_; @@ -1943,6 +1945,7 @@ class SliceUpdate : public UnaryPrimitive { DEFINE_GRADS() DEFINE_PRINT(SliceUpdate) bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() private: Shape start_indices_; diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 40aa2eabb..9770f529e 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -766,7 +766,8 @@ auto mlx_slice_update( const ScalarOrArray& v) { // Can't route to slice update if not slice or tuple if (src.ndim() == 0 || - (!nb::isinstance(obj) && !nb::isinstance(obj))) { + (!nb::isinstance(obj) && !nb::isinstance(obj) && + !nb::isinstance(obj))) { return std::make_pair(false, src); } if (nb::isinstance(obj)) { @@ -777,7 +778,6 @@ auto mlx_slice_update( } } } - // Should be able to route to slice update // Pre process tuple @@ -797,6 +797,20 @@ auto mlx_slice_update( mx::Shape starts(src.ndim(), 0); mx::Shape stops = src.shape(); mx::Shape strides(src.ndim(), 1); + if (nb::isinstance(obj)) { + if (src.ndim() < 1) { + std::ostringstream msg; + msg << "Too many indices for array with " << src.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + auto idx = nb::cast(obj); + idx = idx < 0 ? idx + stops[0] : idx; + starts[0] = idx; + stops[0] = idx + 1; + auto out = slice_update( + src, up, std::move(starts), std::move(stops), std::move(strides)); + return std::make_pair(true, out); + } // If it's just a simple slice, just do a slice update and return if (nb::isinstance(obj)) { diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 42646bfe1..9dfd234b4 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -817,6 +817,19 @@ class TestCompile(mlx_tests.MLXTestCase): fun = mx.compile(lambda a, b: a @ b, shapeless=True) self.assertTrue(mx.allclose(fun(a, b), a @ b)) + def test_shapeless_compile_slice_update(self): + def fun(x): + x[2] = mx.array([3.0]) + return x + + cfun = mx.compile(fun, shapeless=True) + + a = mx.array([0.0, 1.0, 2.0, 3.0]) + self.assertTrue(mx.allclose(cfun(a), fun(a))) + + a = mx.array([0.0, 1.0, 2.0, 3.0, 4.0]) + self.assertTrue(mx.allclose(cfun(a), fun(a))) + if __name__ == "__main__": unittest.main()