From 76c919b4ecf0cccaa1cfef214d12be0ad71485cc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 13 Mar 2024 10:34:14 -0700 Subject: [PATCH] NumberOfElements for shapeless compile and vmap fixes (#802) --- mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/primitives.cpp | 56 +++++++ mlx/backend/metal/primitives.cpp | 4 + mlx/backend/no_metal/primitives.cpp | 1 + mlx/compile.cpp | 2 +- mlx/ops.cpp | 48 ++++-- mlx/ops.h | 11 ++ mlx/primitives.cpp | 176 +++++++++++++++------- mlx/primitives.h | 31 ++++ mlx/transforms.cpp | 1 + python/tests/test_compile.py | 18 +++ python/tests/test_vmap.py | 11 ++ 13 files changed, 289 insertions(+), 72 deletions(-) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 9697942bf..25545a4c1 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -38,6 +38,7 @@ DEFAULT(Copy) DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(Depends) DEFAULT_MULTI(DivMod) +DEFAULT(NumberOfElements) DEFAULT(Equal) DEFAULT(Erf) DEFAULT(ErfInv) diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index e3eb7e0dc..b63414408 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -51,6 +51,7 @@ DEFAULT(Cosh) DEFAULT_MULTI(CustomVJP) DEFAULT_MULTI(Depends) DEFAULT(Divide) +DEFAULT(NumberOfElements) DEFAULT(Remainder) DEFAULT(Equal) DEFAULT(Erf) diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp index 6612a01a8..0bdc09ce9 100644 --- a/mlx/backend/common/primitives.cpp +++ b/mlx/backend/common/primitives.cpp @@ -251,6 +251,62 @@ void Depends::eval( } } +void NumberOfElements::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + double numel = 1; + for (auto ax : axes_) { + numel *= inputs[0].shape(ax); + } + + if (inverted_) { + numel = 1.0 / numel; + } + + switch (out.dtype()) { + case bool_: + *out.data() = static_cast(numel); + break; + case uint8: + *out.data() = static_cast(numel); + break; + case uint16: + *out.data() = static_cast(numel); + break; + case uint32: + *out.data() = static_cast(numel); + break; + case uint64: + *out.data() = static_cast(numel); + break; + case int8: + *out.data() = static_cast(numel); + break; + case int16: + *out.data() = static_cast(numel); + break; + case int32: + *out.data() = static_cast(numel); + break; + case int64: + *out.data() = static_cast(numel); + break; + case float16: + *out.data() = static_cast(numel); + break; + case float32: + *out.data() = static_cast(numel); + break; + case bfloat16: + *out.data() = static_cast(numel); + break; + case complex64: + *out.data() = static_cast(numel); + break; + } +} + void Erf::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); const auto& in = inputs[0]; diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 83c20db6f..f0c1d7f3f 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -696,6 +696,10 @@ void Minimum::eval_gpu(const std::vector& inputs, array& out) { binary_op(inputs, out, "min"); } +void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + void Floor::eval_gpu(const std::vector& inputs, array& out) { unary_op(inputs, out, "floor"); } diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index ec93317f7..5e4dbfe0a 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -43,6 +43,7 @@ NO_GPU_MULTI(CustomVJP) NO_GPU_MULTI(Depends) NO_GPU(Divide) NO_GPU_MULTI(DivMod) +NO_GPU(NumberOfElements) NO_GPU(Remainder) NO_GPU(Equal) NO_GPU(Erf) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 5c3fbf438..ed776361a 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -74,7 +74,7 @@ bool allows_shapeless(const Primitive& p) { is_noop(p) || is_reduction(p) || typeid(p) == typeid(Softmax) || typeid(p) == typeid(Sort) || typeid(p) == typeid(ArgSort) || typeid(p) == typeid(ArgPartition) || typeid(p) == typeid(Partition) || - typeid(p) == typeid(Select); + typeid(p) == typeid(Select) || typeid(p) == typeid(NumberOfElements); } Compiled::Compiled( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c30a6468a..f86f55f58 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -46,14 +46,6 @@ std::pair, std::vector> compute_reduce_shape( return {out_shape, sorted_axes}; } -int compute_number_of_elements(const array& a, const std::vector& axes) { - int nelements = 1; - for (auto axis : axes) { - nelements *= a.shape(axis); - } - return nelements; -} - Dtype at_least_float(const Dtype& d) { return is_floating_point(d) ? d : promote_types(d, float32); } @@ -1356,9 +1348,9 @@ array mean( throw std::invalid_argument(msg.str()); } } - auto nelements = compute_number_of_elements(a, axes); auto dtype = at_least_float(a.dtype()); - return multiply(sum(a, axes, keepdims, s), array(1.0 / nelements, dtype), s); + auto normalizer = number_of_elements(a, axes, true, dtype, s); + return multiply(sum(a, axes, keepdims, s), normalizer, s); } array mean( @@ -1391,9 +1383,12 @@ array var( auto v = subtract(a2, mu2, s); if (ddof != 0) { - auto nelements = compute_number_of_elements(a, axes); - auto factor = nelements / static_cast(std::max(nelements - ddof, 0)); - v = multiply(v, array(factor, dtype), s); + auto nelements = number_of_elements(a, axes, false, dtype, s); + auto factor = divide( + nelements, + maximum(subtract(nelements, array(ddof, dtype), s), array(0, dtype), s), + s); + v = multiply(v, factor, s); } return v; @@ -1770,7 +1765,7 @@ array logsumexp( const std::vector& axes, bool keepdims /* = false */, StreamOrDevice s /* = {}*/) { - auto maxval = stop_gradient(max(a, axes, true, s)); + auto maxval = stop_gradient(max(a, axes, true, s), s); auto out = log(sum(exp(subtract(a, maxval, s), s), axes, keepdims, s), s); out = add(out, reshape(maxval, out.shape(), s), s); if (!keepdims) { @@ -3600,4 +3595,29 @@ std::vector atleast_3d( return out; } +array number_of_elements( + const array& a, + std::vector axes, + bool inverted, + Dtype dtype /* = int32 */, + StreamOrDevice s /* = {} */) { + for (auto& ax : axes) { + int normal_axis = (ax + a.ndim()) % a.ndim(); + if (normal_axis >= a.ndim() || normal_axis < 0) { + std::ostringstream msg; + msg << "[number_of_elements] Can't get the shape for axis " << ax + << " from an array with " << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + ax = normal_axis; + } + + return stop_gradient(array( + std::vector{}, + dtype, + std::make_unique( + to_stream(s), std::move(axes), inverted, dtype), + {a})); +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index e5aa17c52..263eef35c 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1173,4 +1173,15 @@ std::vector atleast_3d( const std::vector& a, StreamOrDevice s = {}); +/** + * Extract the number of elements along some axes as a scalar array. Used to + * allow shape dependent shapeless compilation (pun intended). + */ +array number_of_elements( + const array& a, + std::vector axes, + bool inverted, + Dtype dtype = int32, + StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index dc1676516..d729ea9ed 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -23,6 +23,10 @@ std::tuple vmap_binary_op( assert(inputs.size() == 2); assert(axes.size() == 2); + if (axes[0] == -1 && axes[1] == -1) { + return {inputs[0], inputs[1], -1}; + } + auto a = inputs[0]; auto b = inputs[1]; int ndim = std::max(a.ndim() + (axes[0] == -1), b.ndim() + (axes[1] == -1)); @@ -55,6 +59,10 @@ std::tuple vmap_ternary_op( assert(inputs.size() == 3); assert(axes.size() == 3); + if (axes[0] == -1 && axes[1] == -1 && axes[2] == -1) { + return {inputs[0], inputs[1], inputs[2], -1}; + } + auto a = inputs[0]; auto b = inputs[1]; auto c = inputs[2]; @@ -403,8 +411,8 @@ std::pair, std::vector> ArgPartition::vmap( assert(inputs.size() == 1); assert(axes.size() == 1); - return { - {argpartition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; + int axis_left = axes[0] >= 0 && axes[0] <= axis_; + return {{argpartition(inputs[0], axis_ + axis_left, stream())}, axes}; } bool ArgPartition::is_equivalent(const Primitive& other) const { @@ -420,7 +428,7 @@ bool ArgReduce::is_equivalent(const Primitive& other) const { std::pair, std::vector> ArgReduce::vmap( const std::vector& inputs, const std::vector& axes) { - int reduce_ax = axis_ + (axis_ >= axes[0]); + int reduce_ax = axis_ + (axes[0] >= 0 && axis_ >= axes[0]); auto& in = inputs[0]; std::vector out; if (reduce_type_ == ArgReduce::ArgMin) { @@ -437,7 +445,8 @@ std::pair, std::vector> ArgSort::vmap( assert(inputs.size() == 1); assert(axes.size() == 1); - return {{argsort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; + int axis_left = axes[0] >= 0 && axes[0] <= axis_; + return {{argsort(inputs[0], axis_ + axis_left, stream())}, axes}; } std::vector> ArgReduce::output_shapes( @@ -563,13 +572,16 @@ std::pair, std::vector> Broadcast::vmap( assert(inputs.size() == 1); assert(axes.size() == 1); auto ax = axes[0]; - auto in_shape = inputs[0].shape(); - int diff = shape_.size() - inputs[0].ndim() + 1; - assert(diff >= 0); - in_shape.insert(in_shape.begin(), diff, 1); - ax += diff; - shape_.insert(shape_.begin() + ax, in_shape[ax]); - auto in = reshape(inputs[0], in_shape, stream()); + auto in = inputs[0]; + if (ax >= 0) { + auto in_shape = in.shape(); + int diff = shape_.size() - in.ndim() + 1; + assert(diff >= 0); + in_shape.insert(in_shape.begin(), diff, 1); + ax += diff; + shape_.insert(shape_.begin() + ax, in_shape[ax]); + in = reshape(in, in_shape, stream()); + } return {{broadcast_to(in, shape_, stream())}, {ax}}; } @@ -653,24 +665,30 @@ std::pair, std::vector> Concatenate::vmap( const std::vector& inputs, const std::vector& axes) { std::vector t_inputs; + int out_ax = -1; // Find the first vmapped input int i = 0; for (; i < axes.size(); i++) { t_inputs.push_back(inputs[i]); if (axes[i] >= 0) { + out_ax = axes[i]; break; } } - auto out_ax = axes[i++]; - // Move vmap axes to the same spot. - for (; i < axes.size(); ++i) { - if (out_ax != axes[i] && axes[i] >= 0) { - t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream())); - } else { - t_inputs.push_back(inputs[i]); + if (out_ax >= 0) { + // Advance to the next input + i++; + + // Move vmap axes to the same spot. + for (; i < axes.size(); ++i) { + if (out_ax != axes[i] && axes[i] >= 0) { + t_inputs.push_back(moveaxis(inputs[i], axes[i], out_ax, stream())); + } else { + t_inputs.push_back(inputs[i]); + } } } - auto axis = axis_ + (axis_ >= out_ax); + auto axis = axis_ + (out_ax >= 0 && axis_ >= out_ax); return {{concatenate(t_inputs, axis, stream())}, {out_ax}}; } @@ -1210,13 +1228,15 @@ std::pair, std::vector> FFT::vmap( int ax = axes[0]; auto fft_axes = axes_; auto out_shape = in.shape(); - for (auto& fft_ax : fft_axes) { - if (fft_ax >= ax) { - fft_ax++; - } - if (real_) { - auto n = out_shape[fft_ax]; - out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1; + if (ax >= 0) { + for (auto& fft_ax : fft_axes) { + if (fft_ax >= ax) { + fft_ax++; + } + if (real_) { + auto n = out_shape[fft_ax]; + out_shape[fft_ax] = inverse_ ? 2 * (n - 1) : n / 2 + 1; + } } } return { @@ -2064,7 +2084,8 @@ std::pair, std::vector> Partition::vmap( assert(inputs.size() == 1); assert(axes.size() == 1); - return {{partition(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; + int axis_left = axes[0] >= 0 && axes[0] <= axis_; + return {{partition(inputs[0], axis_ + axis_left, stream())}, axes}; } bool Partition::is_equivalent(const Primitive& other) const { @@ -2185,7 +2206,9 @@ std::pair, std::vector> RandomBits::vmap( } auto shape = shape_; - shape.insert(shape.begin() + kax, key.shape()[kax]); + if (kax >= 0) { + shape.insert(shape.begin() + kax, key.shape()[kax]); + } auto get_dtype = [width = width_]() { switch (width) { @@ -2217,15 +2240,19 @@ std::pair, std::vector> Reshape::vmap( // Transpose the input so that the vmap dim is first. auto& in = inputs[0]; auto ax = axes[0]; - std::vector reorder(in.ndim()); - std::iota(reorder.begin(), reorder.end(), 0); - reorder.erase(reorder.begin() + ax); - reorder.insert(reorder.begin(), ax); - // Insert the vmap dim into the shape at the beginning. - auto out = transpose(in, reorder, stream()); - shape_.insert(shape_.begin(), in.shape()[ax]); - // Reshape the transposed input to the new shape. - return {{reshape(out, shape_, stream())}, {0}}; + if (ax >= 0) { + std::vector reorder(in.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + reorder.erase(reorder.begin() + ax); + reorder.insert(reorder.begin(), ax); + // Insert the vmap dim into the shape at the beginning. + auto out = transpose(in, reorder, stream()); + shape_.insert(shape_.begin(), in.shape()[ax]); + // Reshape the transposed input to the new shape. + return {{reshape(out, shape_, stream())}, {0}}; + } else { + return {{reshape(in, shape_, stream())}, {ax}}; + } } std::vector Reshape::vjp( @@ -2349,9 +2376,11 @@ std::pair, std::vector> Reduce::vmap( const std::vector& axes) { auto ax = axes[0]; auto reduce_axes = axes_; - for (auto& rax : reduce_axes) { - if (rax >= ax) { - rax++; + if (ax >= 0) { + for (auto& rax : reduce_axes) { + if (rax >= ax) { + rax++; + } } } auto& in = inputs[0]; @@ -2424,16 +2453,13 @@ std::pair, std::vector> Scan::vmap( auto& in = inputs[0]; auto out_dtype = (in.dtype() == bool_ && reduce_type_ == Scan::Sum) ? int32 : in.dtype(); + int axis_left = axes[0] >= 0 && axes[0] <= axis_; return { {array( in.shape(), out_dtype, std::make_unique( - stream(), - reduce_type_, - axis_ + (axes[0] <= axis_), - reverse_, - inclusive_), + stream(), reduce_type_, axis_ + axis_left, reverse_, inclusive_), {in})}, axes}; } @@ -2698,9 +2724,11 @@ std::pair, std::vector> Slice::vmap( auto strides = strides_; auto ax = axes[0]; auto& input = inputs[0]; - start.insert(start.begin() + ax, 0); - stop.insert(stop.begin() + ax, input.shape(ax)); - strides.insert(strides.begin() + ax, 1); + if (ax >= 0) { + start.insert(start.begin() + ax, 0); + stop.insert(stop.begin() + ax, input.shape(ax)); + strides.insert(strides.begin() + ax, 1); + } return {{slice(input, start, stop, strides, stream())}, {ax}}; } @@ -2796,7 +2824,7 @@ std::pair, std::vector> Softmax::vmap( // We are vectorizing over an axis other than the last one so keep the // softmax axis unchanged - if (axes[0] < inputs[0].ndim() - 1) { + if (axes[0] >= 0 && axes[0] < inputs[0].ndim() - 1) { softmax_axes.push_back(-1); } else { softmax_axes.push_back(-2); @@ -2837,7 +2865,8 @@ std::pair, std::vector> Sort::vmap( assert(inputs.size() == 1); assert(axes.size() == 1); - return {{sort(inputs[0], axis_ + (axes[0] <= axis_), stream())}, axes}; + int axis_left = axes[0] >= 0 && axes[0] <= axis_; + return {{sort(inputs[0], axis_ + axis_left, stream())}, axes}; } std::vector Sort::vjp( @@ -2867,8 +2896,8 @@ bool Sort::is_equivalent(const Primitive& other) const { std::pair, std::vector> Split::vmap( const std::vector& inputs, const std::vector& axes) { - return { - {split(inputs[0], indices_, axis_ + (axes[0] <= axis_), stream())}, axes}; + int axis_left = axes[0] >= 0 && axes[0] <= axis_; + return {{split(inputs[0], indices_, axis_ + axis_left, stream())}, axes}; } std::vector Split::vjp( @@ -2971,7 +3000,7 @@ bool Sqrt::is_equivalent(const Primitive& other) const { std::pair, std::vector> StopGradient::vmap( const std::vector& inputs, const std::vector& axes) { - return {inputs, axes}; + return {{stop_gradient(inputs[0], stream())}, axes}; }; std::vector Subtract::vjp( @@ -3093,12 +3122,14 @@ std::pair, std::vector> Transpose::vmap( assert(inputs.size() == 1); assert(axes.size() == 1); auto vdim = axes[0]; - for (auto& dim : axes_) { - if (dim >= vdim) { - dim++; + if (vdim >= 0) { + for (auto& dim : axes_) { + if (dim >= vdim) { + dim++; + } } + axes_.insert(axes_.begin() + vdim, vdim); } - axes_.insert(axes_.begin() + vdim, vdim); return {{transpose(inputs[0], axes_, stream())}, {vdim}}; } @@ -3107,4 +3138,35 @@ bool Transpose::is_equivalent(const Primitive& other) const { return axes_ == t_other.axes_; } +std::pair, std::vector> NumberOfElements::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + std::vector new_axes = axes_; + auto vdim = axes[0]; + if (vdim >= 0) { + for (auto& dim : new_axes) { + if (dim >= vdim) { + dim++; + } + } + } + + array out = array( + std::vector{}, + dtype_, + std::make_unique(stream(), new_axes, inverted_, dtype_), + inputs); + + return {{out}, {-1}}; +} + +bool NumberOfElements::is_equivalent(const Primitive& other) const { + const NumberOfElements& n_other = static_cast(other); + return axes_ == n_other.axes_ && inverted_ == n_other.inverted_ && + dtype_ == n_other.dtype_; +} + } // namespace mlx::core diff --git a/mlx/primitives.h b/mlx/primitives.h index aea99eda9..3b79231de 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1229,6 +1229,37 @@ class NotEqual : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class NumberOfElements : public UnaryPrimitive { + public: + explicit NumberOfElements( + Stream stream, + std::vector axes, + bool inverted, + Dtype dtype) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + inverted_(inverted), + dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_PRINT(NumberOfElements) + bool is_equivalent(const Primitive& other) const override; + std::vector> output_shapes( + const std::vector& inputs) override { + return {{}}; + } + + private: + std::vector axes_; + bool inverted_; + Dtype dtype_; + + void eval(const std::vector& inputs, array& out); +}; + class Pad : public UnaryPrimitive { public: explicit Pad( diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index e89d50375..695159a57 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -653,6 +653,7 @@ std::vector vmap_replace( v_axes.push_back(-1); } } + auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes); // For each primitive's outputs add its id, the vout id and the vax auto outputs = a.outputs(); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index beba9ca95..cfdd334cb 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(expected[0], out[0])) self.assertTrue(mx.allclose(expected[1], out[1])) + def test_shapeless_mean(self): + def mean(x): + return mx.mean(x, keepdims=True) + + cmean = mx.compile(mean, shapeless=True) + + x = mx.ones(2) + out = cmean(x) + self.assertTrue(mx.allclose(out, mean(x))) + + x = mx.ones(4) + out = cmean(x) + self.assertTrue(mx.allclose(out, mean(x))) + + x = mx.ones(7) + out = cmean(x) + self.assertTrue(mx.allclose(out, mean(x))) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 72a28654f..99e34c21b 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase): expected = mx.array([2, 1]) self.assertTrue(mx.array_equal(out, expected)) + def test_vmap_mean(self): + a = mx.arange(8).reshape(2, 4) + out = mx.vmap(mx.mean)(a) + expected = mx.mean(a, axis=1) + self.assertTrue(mx.allclose(out, expected)) + + a = mx.arange(16).reshape(2, 2, 4) + out = mx.vmap(mx.vmap(mx.mean))(a) + expected = mx.mean(a, axis=2) + self.assertTrue(mx.allclose(out, expected)) + def test_mismatch_input_sizes(self): a = mx.ones((10, 1)) b = mx.ones((1, 1, 1, 5))