From e88e474fd1e00d6c569d736a12f3a2059861d7eb Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 1 Feb 2024 11:30:28 -0800 Subject: [PATCH] Reduce vmap + some fixes (#601) --- mlx/ops.cpp | 77 +++++++++++++++++++++++++-------------- mlx/primitives.cpp | 46 ++++++++++++++++++++++- mlx/primitives.h | 1 + mlx/transforms.cpp | 24 ++++++++++-- python/tests/test_vmap.py | 46 ++++++++++++++++++++++- 5 files changed, 161 insertions(+), 33 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index b97bf5621..8cb54cc78 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -17,8 +17,7 @@ namespace { std::pair, std::vector> compute_reduce_shape( const std::vector& axes, - const std::vector& shape, - bool keepdims) { + const std::vector& shape) { std::set axes_set; auto ndim = shape.size(); for (auto ax : axes) { @@ -38,7 +37,7 @@ std::pair, std::vector> compute_reduce_shape( for (int i = 0; i < ndim; ++i) { if (axes_set.count(i) == 0) { out_shape.push_back(shape[i]); - } else if (keepdims) { + } else { out_shape.push_back(1); } } @@ -1217,13 +1216,16 @@ array all( if (axes.empty()) { return astype(a, bool_, s); } - auto [out_shape, sorted_axes] = - compute_reduce_shape(axes, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape()); + auto out = array( out_shape, bool_, std::make_unique(to_stream(s), Reduce::And, sorted_axes), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array all( @@ -1248,13 +1250,16 @@ array any( if (axes.empty()) { return astype(a, bool_, s); } - auto [out_shape, sorted_axes] = - compute_reduce_shape(axes, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape()); + auto out = array( out_shape, bool_, std::make_unique(to_stream(s), Reduce::Or, sorted_axes), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array any( @@ -1279,14 +1284,17 @@ array sum( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes] = - compute_reduce_shape(axes, a.shape(), keepdims); + auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape()); auto out_type = a.dtype() == bool_ ? int32 : a.dtype(); - return array( + auto out = array( out_shape, out_type, std::make_unique(to_stream(s), Reduce::Sum, sorted_axes), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array sum( @@ -1374,13 +1382,16 @@ array prod( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes] = - compute_reduce_shape(axes, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape()); + auto out = array( out_shape, a.dtype(), std::make_unique(to_stream(s), Reduce::Prod, sorted_axes), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array prod( @@ -1408,13 +1419,16 @@ array max( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes] = - compute_reduce_shape(axes, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape()); + auto out = array( out_shape, a.dtype(), std::make_unique(to_stream(s), Reduce::Max, sorted_axes), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array max( @@ -1442,13 +1456,16 @@ array min( if (axes.empty()) { return a; } - auto [out_shape, sorted_axes] = - compute_reduce_shape(axes, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape(axes, a.shape()); + auto out = array( out_shape, a.dtype(), std::make_unique(to_stream(s), Reduce::Min, sorted_axes), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array min( @@ -1477,14 +1494,17 @@ array argmin( throw std::invalid_argument( "[argmin] Cannot argmin reduce zero size array."); } - auto [out_shape, sorted_axes] = - compute_reduce_shape({axis}, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape()); + auto out = array( out_shape, uint32, std::make_unique( to_stream(s), ArgReduce::ArgMin, sorted_axes[0]), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) { @@ -1505,14 +1525,17 @@ array argmax( throw std::invalid_argument( "[argmax] Cannot argmax reduce zero size array."); } - auto [out_shape, sorted_axes] = - compute_reduce_shape({axis}, a.shape(), keepdims); - return array( + auto [out_shape, sorted_axes] = compute_reduce_shape({axis}, a.shape()); + auto out = array( out_shape, uint32, std::make_unique( to_stream(s), ArgReduce::ArgMax, sorted_axes[0]), {a}); + if (!keepdims) { + out = squeeze(out, sorted_axes, s); + } + return out; } /** Returns a sorted copy of the flattened array. */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 018cce265..a7e1d205d 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1,5 +1,4 @@ // Copyright © 2023-2024 Apple Inc. - #include #include #include @@ -361,6 +360,20 @@ bool ArgReduce::is_equivalent(const Primitive& other) const { return reduce_type_ == r_other.reduce_type_ && axis_ == r_other.axis_; } +std::pair, std::vector> ArgReduce::vmap( + const std::vector& inputs, + const std::vector& axes) { + int reduce_ax = axis_ + (axis_ >= axes[0]); + auto& in = inputs[0]; + std::vector out; + if (reduce_type_ == ArgReduce::ArgMin) { + out.push_back(argmin(in, reduce_ax, true, stream())); + } else { + out.push_back(argmax(in, reduce_ax, true, stream())); + } + return {out, axes}; +} + std::pair, std::vector> ArgSort::vmap( const std::vector& inputs, const std::vector& axes) { @@ -2153,7 +2166,36 @@ std::vector Reduce::vjp( std::pair, std::vector> Reduce::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::runtime_error("Reduce::vmap not yet implemented."); + auto ax = axes[0]; + auto reduce_axes = axes_; + for (auto& rax : reduce_axes) { + if (rax >= ax) { + rax++; + } + } + auto& in = inputs[0]; + std::vector out; + switch (reduce_type_) { + case Reduce::And: + out.push_back(all(in, reduce_axes, true, stream())); + break; + case Reduce::Or: + out.push_back(any(in, reduce_axes, true, stream())); + break; + case Reduce::Sum: + out.push_back(sum(in, reduce_axes, true, stream())); + break; + case Reduce::Prod: + out.push_back(prod(in, reduce_axes, true, stream())); + break; + case Reduce::Min: + out.push_back(min(in, reduce_axes, true, stream())); + break; + case Reduce::Max: + out.push_back(max(in, reduce_axes, true, stream())); + break; + } + return {out, axes}; } bool Reduce::is_equivalent(const Primitive& other) const { diff --git a/mlx/primitives.h b/mlx/primitives.h index 265095694..15d6c485e 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -341,6 +341,7 @@ class ArgReduce : public UnaryPrimitive { void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; + DEFINE_VMAP() DEFINE_PRINT(ArgReduce) bool is_equivalent(const Primitive& other) const override; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index aeafcac7d..368c6cff4 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -548,9 +548,8 @@ std::pair, std::vector> vmap_trace( "[vmap] The number of in axes must match the number of inputs."); } - // Run the function on placeholder inputs - // to get the original graph - std::vector s_inputs; + // Some error checking and get the vmap axis size + size_t vmap_ax_size; for (int i = 0; i < inputs.size(); ++i) { if (in_axes[i] != -1) { if (inputs[i].ndim() == 0) { @@ -563,7 +562,26 @@ std::pair, std::vector> vmap_trace( << inputs[i].ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } + vmap_ax_size = inputs[i].shape(in_axes[i]); + } + } + // Check that all vmapped axes have the same size + for (int i = 0; i < inputs.size(); ++i) { + if (in_axes[i] != -1) { + if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) { + std::ostringstream msg; + msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and " + << vmap_ax_size << "."; + throw std::invalid_argument(msg.str()); + } + } + } + // Run the function on placeholder inputs + // to get the original graph + std::vector s_inputs; + for (int i = 0; i < inputs.size(); ++i) { + if (in_axes[i] != -1) { std::vector shape = inputs[i].shape(); shape.erase(shape.begin() + in_axes[i]); array in(shape, inputs[i].dtype(), nullptr, {}); diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 695ddd65f..72a28654f 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -1,4 +1,4 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import unittest @@ -220,6 +220,50 @@ class TestVmap(mlx_tests.MLXTestCase): ) self.assertTrue(mx.array_equal(out, expected)) + def test_vmap_reduce(self): + a = mx.ones((5, 5), mx.int32) + out = mx.vmap(lambda x: x.sum())(a) + self.assertTrue(mx.array_equal(out, mx.full((5,), 5))) + + out = mx.vmap(lambda x: x.sum(keepdims=True))(a) + self.assertTrue(mx.array_equal(out, mx.full((5, 1), 5))) + + out = mx.vmap(lambda x: x.sum(axis=0))(a) + self.assertTrue(mx.array_equal(out, mx.full((5,), 5))) + + a = mx.ones((5, 3, 2), mx.int32) + out = mx.vmap(lambda x: x.sum(axis=(0, 1)))(a) + self.assertTrue(mx.array_equal(out, mx.full((5,), 6))) + + a = mx.ones((5, 3, 2), mx.int32) + out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(1,))(a) + self.assertTrue(mx.array_equal(out, mx.full((3,), 10))) + + a = mx.ones((5, 3, 2), mx.int32) + out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a) + self.assertTrue(mx.array_equal(out, mx.full((2,), 15))) + + def test_vmap_argreduce(self): + a = mx.array([[1, 2, 3], [2, 3, 1]]) + out = mx.vmap(lambda x: mx.argmin(x))(a) + expected = mx.array([0, 2]) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap(lambda x: mx.argmax(x))(a) + expected = mx.array([2, 1]) + self.assertTrue(mx.array_equal(out, expected)) + + def test_mismatch_input_sizes(self): + a = mx.ones((10, 1)) + b = mx.ones((1, 1, 1, 5)) + + with self.assertRaises(ValueError): + out = mx.vmap(lambda x, y: x + y)(a, b) + + b = mx.ones((10, 5)) + with self.assertRaises(ValueError): + out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b) + if __name__ == "__main__": unittest.main()