From d8fabaa12b4c55594412ef9e1bea2c9b9cc2b495 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 16 Jan 2024 13:33:55 -0800
Subject: [PATCH] Split multi output (#461)

* Multi-output split primitive

* Add the multi-output split to the ArrayIterator

* Add some grad tests for split
---
 mlx/array.cpp                             | 19 +++++++++
 mlx/array.h                               |  7 +--
 mlx/backend/accelerate/primitives.cpp     |  1 +
 mlx/backend/common/default_primitives.cpp |  1 +
 mlx/backend/common/primitives.cpp         | 52 +++++++++++++++++++++++
 mlx/backend/metal/primitives.cpp          |  6 +++
 mlx/backend/no_metal/primitives.cpp       |  1 +
 mlx/ops.cpp                               | 24 +++++++++++
 mlx/primitives.cpp                        | 26 ++++++++++++
 mlx/primitives.h                          | 22 ++++++++++
 python/tests/test_autograd.py             | 19 +++++++++
 tests/autograd_tests.cpp                  | 29 +++++++++++++
 12 files changed, 202 insertions(+), 5 deletions(-)

diff --git a/mlx/array.cpp b/mlx/array.cpp
index cc85a497b..7f8ee92ec 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -158,7 +158,26 @@ array::ArrayDesc::ArrayDesc(
   }
 }
 
+array::ArrayIterator::ArrayIterator(const array& arr, int idx)
+    : arr(arr), idx(idx) {
+  if (arr.ndim() == 0) {
+    throw std::invalid_argument("Cannot iterate over 0-d array.");
+  }
+
+  // Iterate using split
+  if (arr.shape(0) > 0 && arr.shape(0) <= 10) {
+    splits = split(arr, arr.shape(0));
+    for (auto& arr_i : splits) {
+      arr_i = squeeze(arr_i, 0);
+    }
+  }
+}
+
 array::ArrayIterator::reference array::ArrayIterator::operator*() const {
+  if (idx >= 0 && idx < splits.size()) {
+    return splits[idx];
+  }
+
   auto start = std::vector<int>(arr.ndim(), 0);
   auto end = arr.shape();
   auto shape = arr.shape();
diff --git a/mlx/array.h b/mlx/array.h
index fedfb8570..52c968c0e 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -127,11 +127,7 @@ class array {
     using value_type = const array;
     using reference = value_type;
 
-    explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) {
-      if (arr.ndim() == 0) {
-        throw std::invalid_argument("Cannot iterate over 0-d array.");
-      }
-    }
+    explicit ArrayIterator(const array& arr, int idx = 0);
 
     reference operator*() const;
 
@@ -155,6 +151,7 @@ class array {
    private:
     const array& arr;
     int idx;
+    std::vector<array> splits;
   };
 
   ArrayIterator begin() const {
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index e52a47964..8f2da02a2 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -60,6 +60,7 @@ DEFAULT(Scatter)
 DEFAULT(Sigmoid)
 DEFAULT(Sign)
 DEFAULT(Slice)
+DEFAULT_MULTI(Split)
 DEFAULT(Sort)
 DEFAULT(StopGradient)
 DEFAULT(Transpose)
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index 66f224624..cecf64cee 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -88,6 +88,7 @@ DEFAULT(Sinh)
 DEFAULT(Slice)
 DEFAULT(Softmax)
 DEFAULT(Sort)
+DEFAULT_MULTI(Split)
 DEFAULT(Square)
 DEFAULT(Sqrt)
 DEFAULT(StopGradient)
diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp
index 889a9841b..6aaf59b45 100644
--- a/mlx/backend/common/primitives.cpp
+++ b/mlx/backend/common/primitives.cpp
@@ -588,6 +588,58 @@ void Slice::eval(const std::vector<array>& inputs, array& out) {
   out.copy_shared_buffer(in, strides, flags, data_size, data_offset);
 }
 
+void Split::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+
+  auto& in = inputs[0];
+
+  auto compute_new_flags = [](const auto& shape,
+                              const auto& strides,
+                              size_t in_data_size,
+                              auto flags) {
+    size_t data_size = 1;
+    size_t f_stride = 1;
+    size_t b_stride = 1;
+    flags.row_contiguous = true;
+    flags.col_contiguous = true;
+    for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+      flags.col_contiguous &= strides[i] == f_stride || shape[i] == 1;
+      flags.row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+      f_stride *= shape[i];
+      b_stride *= shape[ri];
+      if (strides[i] > 0) {
+        data_size *= shape[i];
+      }
+    }
+
+    if (data_size == 1) {
+      // Broadcasted scalar array is contiguous.
+      flags.contiguous = true;
+    } else if (data_size == in_data_size) {
+      // Means we sliced a broadcasted dimension so leave the "no holes" flag
+      // alone.
+    } else {
+      // We sliced something. So either we are row or col contiguous or we
+      // punched a hole.
+      flags.contiguous &= flags.row_contiguous || flags.col_contiguous;
+    }
+
+    return std::pair{flags, data_size};
+  };
+
+  std::vector<int> indices(1, 0);
+  indices.insert(indices.end(), indices_.begin(), indices_.end());
+  for (int i = 0; i < indices.size(); i++) {
+    size_t offset = indices[i] * in.strides()[axis_];
+    auto [new_flags, data_size] = compute_new_flags(
+        outputs[i].shape(), in.strides(), in.data_size(), in.flags());
+    outputs[i].copy_shared_buffer(
+        in, in.strides(), new_flags, data_size, offset);
+  }
+}
+
 void Square::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index ecea66e85..d9e0619cd 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -727,6 +727,12 @@ void Sinh::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "sinh");
 }
 
+void Split::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  eval(inputs, outputs);
+}
+
 void Square::eval_gpu(const std::vector<array>& inputs, array& out) {
   unary_op(inputs, out, "square");
 }
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 90019f65b..a6902ba3a 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -80,6 +80,7 @@ NO_GPU(Sinh)
 NO_GPU(Slice)
 NO_GPU(Softmax)
 NO_GPU(Sort)
+NO_GPU_MULTI(Split)
 NO_GPU(Square)
 NO_GPU(Sqrt)
 NO_GPU(StopGradient)
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 87a7d5e96..7fcf17403 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1,4 +1,5 @@
 // Copyright © 2023 Apple Inc.
+#include <algorithm>
 #include <cmath>
 #include <numeric>
 #include <sstream>
@@ -573,6 +574,29 @@
         << " for array with shape " << a.shape() << ".";
     throw std::invalid_argument(msg.str());
   }
+
+  if (indices.empty()) {
+    return {a};
+  }
+
+  if (indices.size() < 10 &&
+      std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
+      indices[0] > 0 && indices.back() < a.shape(ax)) {
+    std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
+    std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
+    shapes[0][ax] = indices[0];
+    for (int i = 1; i < indices.size(); i++) {
+      shapes[i][ax] = indices[i] - indices[i - 1];
+    }
+    shapes.back()[ax] = a.shape(ax) - indices.back();
+
+    return array::make_arrays(
+        shapes,
+        dtypes,
+        std::make_shared<Split>(to_stream(s), indices, ax),
+        {a});
+  }
+
   std::vector<array> res;
   auto out_shape = a.shape();
   auto start_indices = std::vector<int>(a.ndim(), 0);
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 8eccd1f60..9876a2443 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2493,6 +2493,32 @@ bool Sort::is_equivalent(const Primitive& other) const {
   return axis_ == r_other.axis_;
 }
 
+std::pair<std::vector<array>, std::vector<int>> Split::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  return {
+      {split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes};
+}
+
+std::vector<array> Split::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums) {
+  return {concatenate(cotangents, axis_, stream())};
+}
+
+std::vector<array> Split::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  return split(tangents[0], indices_, axis_, stream());
+}
+
+bool Split::is_equivalent(const Primitive& other) const {
+  const Split& s_other = static_cast<const Split&>(other);
+  return axis_ == s_other.axis_ && indices_ == s_other.indices_;
+}
+
 std::vector<array> Square::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 85ffbf25e..1fb4cf8be 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -1421,6 +1421,28 @@ class Sort : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class Split : public Primitive {
+ public:
+  explicit Split(Stream stream, const std::vector<int>& indices, int axis)
+      : Primitive(stream), indices_(indices), axis_(axis){};
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(Split)
+  bool is_equivalent(const Primitive& other) const override;
+
+ private:
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+
+  std::vector<int> indices_;
+  int axis_;
+};
+
 class Square : public UnaryPrimitive {
  public:
  explicit Square(Stream stream) : UnaryPrimitive(stream){};
diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py
index 1279cfb40..c7edc8b76 100644
--- a/python/tests/test_autograd.py
+++ b/python/tests/test_autograd.py
@@ -339,6 +339,25 @@ class TestAutograd(mlx_tests.MLXTestCase):
         self.assertTrue(mx.allclose(vjps[0], mx.array([[4.0], [5.0], [6.0]])))
         self.assertTrue(mx.allclose(vjps[1], mx.array([[[5.0]]])))
 
+    def test_split_against_slice(self):
+        def f_split(x):
+            a, _, b = x.split(3, -1)
+            return (a * b).sum()
+
+        def f_slice(x):
+            step = x.shape[-1] // 3
+            a = x[..., :step]
+            b = x[..., -step:]
+            return (a * b).sum()
+
+        x = mx.random.uniform(shape=(100, 300))
+        mx.eval(x)
+
+        df1 = mx.grad(f_split)
+        df2 = mx.grad(f_slice)
+
+        self.assertTrue(mx.allclose(df1(x), df2(x)))
+
     def test_vjp_types(self):
         def fun(x):
             return x
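For reference, the gradient rule implemented by Split::vjp above (the per-piece cotangents are concatenated back along the split axis) can also be checked from Python. This is an illustrative sketch only, not part of the patch, and it assumes the public mx.split, mx.vjp, and mx.arange bindings:

import mlx.core as mx

x = mx.arange(6, dtype=mx.float32)

def fun(x):
    a, b, c = mx.split(x, 3)
    return a * b + c

# With a cotangent of ones on the two-element output, the gradient is the
# concatenation of the per-piece cotangents: [b, a, ones] -> [2, 3, 0, 1, 1, 1],
# the same values checked by the C++ test that follows.
outs, vjps = mx.vjp(fun, [x], [mx.ones((2,))])
print(vjps[0])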
diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp
index 6c31ee118..eb2905eb9 100644
--- a/tests/autograd_tests.cpp
+++ b/tests/autograd_tests.cpp
@@ -922,6 +922,35 @@ TEST_CASE("test concatenate grads") {
       array_equal(out[0], array({0.0f, 0.0f, 2.0f, 0.0f, 3.0f})).item<bool>());
 }
 
+TEST_CASE("test split grads") {
+  array x = arange(6, float32);
+  eval(x);
+
+  {
+    auto fn = [](const array& x) {
+      auto parts = split(x, 3);
+      return parts[0] * parts[1] + parts[2];
+    };
+    auto out = vjp(fn, {x}, {ones({2})}).second;
+
+    CHECK_EQ(out.size(), 6);
+    CHECK(array_equal(out, array({2.0f, 3.0f, 0.0f, 1.0f, 1.0f, 1.0f}))
+              .item<bool>());
+  }
+
+  {
+    auto fn = [](const array& x) {
+      auto parts = split(x, 3);
+      return parts[0] * parts[2];
+    };
+    auto out = vjp(fn, {x}, {ones({2})}).second;
+
+    CHECK_EQ(out.size(), 6);
+    CHECK(array_equal(out, array({4.0f, 5.0f, 0.0f, 0.0f, 0.0f, 1.0f}))
+              .item<bool>());
+  }
+}
+
 TEST_CASE("test comparison grads") {
   auto x = ones({3, 1});
   auto y = zeros({1, 3});
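A minimal usage sketch of the new fast path (illustrative only, not part of the patch; it assumes the Python mx.split binding): when the split indices are sorted, in range, and yield fewer than ten pieces, the op lowers to a single multi-output Split primitive rather than the older slice-based construction, and the values match slicing by hand.

import mlx.core as mx

x = mx.arange(10)
a, b, c = mx.split(x, [3, 7])  # three sibling outputs along axis 0
print(a.tolist(), b.tolist(), c.tolist())  # [0, 1, 2] [3, 4, 5, 6] [7, 8, 9]
assert b.tolist() == x[3:7].tolist()  # same result as explicit slicing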