From 1ccaf80575e33183237d7c3af5c2ead4437e037d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 9 Jan 2025 11:04:24 -0800 Subject: [PATCH] Dynamic broadcasting for shapeless compile/export (#1722) * working towards dynamic broadcast * shapeless broadcast * fix build + nits * use broadcast arrays in quantize matmul * some cleanup / consistency * mend * some comments * add vjp, jvp for broadcast axes --- mlx/array.cpp | 17 +- mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/common.cpp | 12 +- mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/metal/compiled.cpp | 1 - mlx/backend/metal/primitives.cpp | 4 + mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_metal/primitives.cpp | 1 + mlx/compile.cpp | 8 +- mlx/compile_impl.h | 3 +- mlx/export.cpp | 4 +- mlx/ops.cpp | 261 +++++++++++++++------- mlx/primitives.cpp | 134 +++++++---- mlx/primitives.h | 29 ++- mlx/transforms.cpp | 9 +- mlx/transforms_impl.h | 30 ++- mlx/utils.cpp | 7 +- python/src/ops.cpp | 25 ++- python/tests/test_compile.py | 73 ++++++ python/tests/test_ops.py | 13 ++ 20 files changed, 471 insertions(+), 163 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index 331914518f..c2edb4940d 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -10,20 +10,6 @@ namespace mlx::core { -namespace { - -/** Return true if we are currently performing a function transformation in - * order to keep the graph when evaluating tracer arrays. */ -bool in_tracing() { - return detail::InTracing::in_tracing(); -} - -bool retain_graph() { - return detail::RetainGraph::retain_graph(); -} - -} // namespace - array::array(const std::complex& val, Dtype dtype /* = complex64 */) : array_desc_(std::make_shared(Shape{}, dtype)) { auto cval = static_cast(val); @@ -119,7 +105,8 @@ void array::eval() { } bool array::is_tracer() const { - return (array_desc_->is_tracer && in_tracing()) || retain_graph(); + return (array_desc_->is_tracer && detail::in_tracing()) || + detail::retain_graph(); } void array::set_data(allocator::Buffer buffer, Deleter d) { diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 46246812f1..c1c1e61ee3 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -32,6 +32,7 @@ DEFAULT(ArgSort) DEFAULT(AsStrided) DEFAULT(BlockMaskedMM) DEFAULT(Broadcast) +DEFAULT(BroadcastAxes) DEFAULT(Ceil) DEFAULT(Concatenate) DEFAULT(Conjugate) diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index e960e6ec67..310a02a499 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -42,9 +42,7 @@ void AsStrided::eval(const std::vector& inputs, array& out) { return move_or_copy(in, out, strides_, flags, data_size, offset_); } -void Broadcast::eval(const std::vector& inputs, array& out) { - assert(inputs.size() == 1); - const auto& in = inputs[0]; +void broadcast(const array& in, array& out) { if (out.size() == 0) { out.set_data(nullptr); return; @@ -61,6 +59,14 @@ void Broadcast::eval(const std::vector& inputs, array& out) { move_or_copy(in, out, strides, flags, in.data_size()); } +void Broadcast::eval(const std::vector& inputs, array& out) { + broadcast(inputs[0], out); +} + +void BroadcastAxes::eval(const std::vector& inputs, array& out) { + broadcast(inputs[0], out); +} + void Copy::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); move_or_copy(inputs[0], out); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 
5d359aeadc..f7548a12ff 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -37,6 +37,7 @@ DEFAULT(ArgSort) DEFAULT(AsType) DEFAULT(AsStrided) DEFAULT(Broadcast) +DEFAULT(BroadcastAxes) DEFAULT(BlockMaskedMM) DEFAULT(GatherMM) DEFAULT(GatherQMM) diff --git a/mlx/backend/metal/compiled.cpp b/mlx/backend/metal/compiled.cpp index b75044f1aa..e8e5e77e8e 100644 --- a/mlx/backend/metal/compiled.cpp +++ b/mlx/backend/metal/compiled.cpp @@ -1,6 +1,5 @@ // Copyright © 2023-2024 Apple Inc. #include -#include //TODO #include #include "mlx/backend/common/compiled.h" diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index f796af8053..e12d67d5f3 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -240,6 +240,10 @@ void Broadcast::eval_gpu(const std::vector& inputs, array& out) { eval(inputs, out); } +void BroadcastAxes::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + void Concatenate::eval_gpu(const std::vector& inputs, array& out) { concatenate_gpu(inputs, out, axis_, stream()); } diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 751a234186..acc5d7560f 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -35,6 +35,7 @@ NO_CPU(AsStrided) NO_CPU(BitwiseBinary) NO_CPU(BlockMaskedMM) NO_CPU(Broadcast) +NO_CPU(BroadcastAxes) NO_CPU(Ceil) NO_CPU(Cholesky) NO_CPU(Concatenate) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index b864fc7f29..a4999aa9f0 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -36,6 +36,7 @@ NO_GPU(AsStrided) NO_GPU(BitwiseBinary) NO_GPU(BlockMaskedMM) NO_GPU(Broadcast) +NO_GPU(BroadcastAxes) NO_GPU(Ceil) NO_GPU_MULTI(Compiled) NO_GPU(Concatenate) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 924dc80d67..f0ca6a679f 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -284,9 +284,10 @@ CompilerCache& compiler_cache() { std::pair, std::vector> compile_trace( const std::function(const std::vector&)>& fun, - const std::vector& inputs) { + const std::vector& inputs, + bool shapeless) { // Set the global tracing flag. 
- detail::InTracing in_tracing; + detail::InTracing in_tracing{shapeless}; // Run the function on placeholder inputs // to get compute graph @@ -824,7 +825,8 @@ std::function(const std::vector&)> compile( // Set the constants entry.constants = std::move(constants); // Trace to build the graph - std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs); + std::tie(entry.inputs, entry.outputs) = + compile_trace(fun, inputs, shapeless); // DFS the graph and get a tape, and a map of array id to (parent, // position in parent inputs) diff --git a/mlx/compile_impl.h b/mlx/compile_impl.h index 56bb02003e..7e9235f363 100644 --- a/mlx/compile_impl.h +++ b/mlx/compile_impl.h @@ -27,7 +27,8 @@ bool compile_available_for_device(const Device& device); std::pair, std::vector> compile_trace( const std::function(const std::vector&)>& fun, - const std::vector& inputs); + const std::vector& inputs, + bool shapeless); using ParentsMap = std::unordered_map>>; diff --git a/mlx/export.cpp b/mlx/export.cpp index 3ab6213cf1..935c871423 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -243,6 +243,7 @@ struct PrimitiveFactory { "RightShift"), SERIALIZE_PRIMITIVE(BlockMaskedMM), SERIALIZE_PRIMITIVE(Broadcast), + SERIALIZE_PRIMITIVE(BroadcastAxes), SERIALIZE_PRIMITIVE(Ceil), SERIALIZE_PRIMITIVE(Concatenate), SERIALIZE_PRIMITIVE(Conjugate), @@ -568,7 +569,8 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { }; // Trace to build the graph - auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs); + auto [trace_inputs, trace_outputs] = + detail::compile_trace(flat_fun, inputs, ftable->shapeless); // DFS the graph and get the tape auto [tape, parents_map] = diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4362b1d085..f1ca2e945e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -14,6 +14,7 @@ #include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" +#include "mlx/transforms_impl.h" #include "mlx/utils.h" namespace mlx::core { @@ -1399,29 +1400,151 @@ array broadcast_to( {a}); } -std::vector -broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) { - auto shape = broadcast_shapes(a.shape(), b.shape()); - return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)}; +/** Broadcast the input arrays against one another while ignoring the + * axes specified in `ignore_axes`. Note, this API is internal only. 
+ * The `ignore_axes` should be: + * - negative values indicating axes from the end + * - sorted in increasing order + */ +std::vector broadcast_arrays( + const std::vector& inputs, + std::vector ignore_axes, + StreamOrDevice s) { + if (inputs.size() <= 1) { + return inputs; + } + + std::vector outputs; + auto shape = BroadcastAxes::output_shape(inputs, ignore_axes); + auto check_and_get_shape = [&shape, &ignore_axes](const array& in) { + auto out_shape = shape; + for (int i = 0; i < ignore_axes.size(); ++i) { + auto ax = ignore_axes[i]; + auto pos_ax = in.ndim() + ax; + if (pos_ax < 0 || pos_ax > in.ndim() || + (i > 0 && ax <= ignore_axes[i - 1])) { + throw std::invalid_argument( + "[broadcast_arrays] Received invalid axes to ignore."); + } + out_shape[out_shape.size() + ax] = in.shape(ax); + } + return out_shape; + }; + + if (!detail::in_dynamic_tracing()) { + for (auto& in : inputs) { + auto out_shape = check_and_get_shape(in); + if (in.shape() == out_shape) { + outputs.push_back(in); + } else { + outputs.push_back(array( + std::move(out_shape), + in.dtype(), + std::make_shared(to_stream(s), out_shape), + {in})); + } + } + return outputs; + } + + std::vector stop_grad_inputs; + for (auto& in : inputs) { + stop_grad_inputs.push_back(stop_gradient(in, s)); + } + + for (int i = 0; i < inputs.size(); ++i) { + auto& in = inputs[i]; + auto out_shape = check_and_get_shape(in); + if (in.shape() == out_shape) { + outputs.push_back(in); + } else { + // broadcasted array goes first followed by other stopgrad inputs + std::vector p_inputs = {in}; + for (int j = 0; j < inputs.size(); ++j) { + if (j == i) { + continue; + } + p_inputs.push_back(stop_grad_inputs[j]); + } + outputs.push_back(array( + std::move(out_shape), + in.dtype(), + std::make_shared(to_stream(s), ignore_axes), + std::move(p_inputs))); + } + } + return outputs; } std::vector broadcast_arrays( const std::vector& inputs, StreamOrDevice s /* = {} */) { - Shape shape{}; - for (const auto& in : inputs) { - shape = broadcast_shapes(shape, in.shape()); + if (inputs.size() <= 1) { + return inputs; } + auto shape = Broadcast::output_shape(inputs); std::vector outputs; - for (const auto& in : inputs) { - outputs.push_back(broadcast_to(in, shape, s)); + + if (!detail::in_dynamic_tracing()) { + for (auto& in : inputs) { + if (in.shape() == shape) { + outputs.push_back(in); + } else { + outputs.push_back(array( + shape, + in.dtype(), + std::make_shared(to_stream(s), shape), + {in})); + } + } + return outputs; + } + + std::vector stop_grad_inputs; + for (auto& in : inputs) { + stop_grad_inputs.push_back(stop_gradient(in, s)); + } + for (int i = 0; i < inputs.size(); ++i) { + auto& in = inputs[i]; + if (in.shape() == shape) { + outputs.push_back(in); + } else { + // broadcasted array goes first followed by other stopgrad inputs + std::vector p_inputs = {in}; + for (int j = 0; j < inputs.size(); ++j) { + if (j == i) { + continue; + } + p_inputs.push_back(stop_grad_inputs[j]); + } + outputs.push_back(array( + shape, + in.dtype(), + std::make_shared(to_stream(s), shape), + std::move(p_inputs))); + } } return outputs; } +std::pair +broadcast_arrays(const array& a, const array& b, StreamOrDevice s) { + auto out = broadcast_arrays({a, b}, s); + return {out[0], out[1]}; +} + +std::pair broadcast_arrays( + const array& a, + const array& b, + std::vector ignore_axes, + StreamOrDevice s) { + auto out = broadcast_arrays({a, b}, std::move(ignore_axes), s); + return {out[0], out[1]}; +} + array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) 
{ auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); @@ -1429,7 +1552,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -1440,7 +1563,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); @@ -1451,7 +1574,7 @@ array greater_equal( const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -1462,7 +1585,7 @@ array greater_equal( array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); @@ -1470,7 +1593,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2277,7 +2400,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) { array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Broadcast arrays to a common shape - auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s); + auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2291,7 +2414,7 @@ array operator&&(const array& a, const array& b) { array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Broadcast arrays to a common shape - auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s); + auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2311,7 +2434,7 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) { array add(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - 
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& shape = inputs[0].shape(); return array( shape, out_type, std::make_shared(to_stream(s)), std::move(inputs)); @@ -2324,7 +2447,7 @@ array operator+(const array& a, const array& b) { array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2340,7 +2463,7 @@ array operator-(const array& a, const array& b) { array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2355,8 +2478,8 @@ array operator*(const array& a, const array& b) { array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(promote_types(a.dtype(), b.dtype())); - auto inputs = - broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s); + auto inputs = broadcast_arrays( + {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); auto& shape = inputs[0].shape(); return array( shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); @@ -2380,7 +2503,7 @@ array floor_divide( return floor(divide(a, b, s), s); } - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); @@ -2388,8 +2511,8 @@ array floor_divide( array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = - broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s); + auto inputs = broadcast_arrays( + {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2407,8 +2530,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) { if (issubdtype(dtype, complexfloating)) { throw std::invalid_argument("[divmod] Complex type not supported."); } - auto inputs = - broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s); + auto inputs = broadcast_arrays( + {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); return array::make_arrays( {inputs[0].shape(), inputs[0].shape()}, {inputs[0].dtype(), inputs[0].dtype()}, @@ -2419,7 +2542,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) { array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2431,7 +2554,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) { array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - 
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2514,7 +2637,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) { array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(promote_types(a.dtype(), b.dtype())); - auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); auto& shape = inputs[0].shape(); return array( shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); @@ -2610,7 +2733,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Make sure out type is floating point auto out_type = at_least_float(promote_types(a.dtype(), b.dtype())); auto inputs = - broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& shape = inputs[0].shape(); return array( shape, @@ -2710,19 +2833,7 @@ array matmul( if (in_a.ndim() > 2 && in_b.ndim() <= 2) { a = flatten(a, 0, -2, s); } else if (in_b.ndim() > 2) { - Shape bsx_a(a.shape().begin(), a.shape().end() - 2); - Shape bsx_b(b.shape().begin(), b.shape().end() - 2); - auto inner_shape = broadcast_shapes(bsx_a, bsx_b); - - // Broadcast a - inner_shape.push_back(a.shape(-2)); - inner_shape.push_back(a.shape(-1)); - a = broadcast_to(a, inner_shape, s); - - // Broadcast b - *(inner_shape.end() - 2) = b.shape(-2); - *(inner_shape.end() - 1) = b.shape(-1); - b = broadcast_to(b, inner_shape, s); + std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s); } auto out_shape = a.shape(); @@ -3780,29 +3891,6 @@ array quantized_matmul( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); - // QuantizedMatmul handles w.ndim == 2 case. 
- if (x.ndim() > 2 && w.ndim() > 2) { - Shape bsx_x(x.shape().begin(), x.shape().end() - 2); - Shape bsx_w(w.shape().begin(), w.shape().end() - 2); - auto inner_shape = broadcast_shapes(bsx_x, bsx_w); - - // Broadcast x - inner_shape.push_back(x.shape(-2)); - inner_shape.push_back(x.shape(-1)); - x = broadcast_to(x, inner_shape, s); - - // Broadcast w - *(inner_shape.end() - 2) = w.shape(-2); - *(inner_shape.end() - 1) = w.shape(-1); - w = broadcast_to(w, inner_shape, s); - - *(inner_shape.end() - 1) = scales.shape(-1); - scales = broadcast_to(scales, inner_shape, s); - - *(inner_shape.end() - 1) = biases.shape(-1); - biases = broadcast_to(biases, inner_shape, s); - } - auto dtype = result_type(x, scales, biases); if (!issubdtype(dtype, floating)) { std::ostringstream msg; @@ -3812,18 +3900,21 @@ array quantized_matmul( << " and biases.dtype() == " << biases.dtype(); throw std::invalid_argument(msg.str()); } + std::vector inputs = { + astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)}; - auto out_shape = x.shape(); + if (x.ndim() > 2 && w.ndim() > 2) { + inputs = broadcast_arrays(inputs, {-2, -1}, s); + } + + auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; return array( std::move(out_shape), dtype, std::make_shared( to_stream(s), group_size, bits, transpose), - {astype(x, dtype, s), - w, - astype(scales, dtype, s), - astype(biases, dtype, s)}); + std::move(inputs)); } std::tuple quantize( @@ -3866,13 +3957,11 @@ array gather_qmm( // Extract indices and broadcast them array lhs_indices = indices_or_default(lhs_indices_, x, s); array rhs_indices = indices_or_default(rhs_indices_, w, s); - auto out_bsx_shape = - broadcast_shapes(lhs_indices.shape(), rhs_indices.shape()); - lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s); - rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s); + std::tie(lhs_indices, rhs_indices) = + broadcast_arrays(lhs_indices, rhs_indices, s); // Compute the full output shape - auto out_shape = out_bsx_shape; + auto out_shape = lhs_indices.shape(); out_shape.push_back(x.shape(-2)); out_shape.push_back(w_outer_dims); @@ -4374,13 +4463,10 @@ array gather_mm( int N = b.shape(-1); int K = a.shape(-1); - auto out_bsx_shape = - broadcast_shapes(lhs_indices.shape(), rhs_indices.shape()); + std::tie(lhs_indices, rhs_indices) = + broadcast_arrays(lhs_indices, rhs_indices, s); - lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s); - rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s); - - auto out_shape = out_bsx_shape; + auto out_shape = lhs_indices.shape(); out_shape.push_back(M); out_shape.push_back(N); @@ -4640,6 +4726,13 @@ array number_of_elements( ax = normal_axis; } + if (!detail::in_dynamic_tracing()) { + double numel = 1; + for (auto ax : axes) { + numel *= a.shape(ax); + } + return array(inverted ? 
1.0 / numel : numel, dtype); + } return stop_gradient(array( Shape{}, dtype, @@ -4673,7 +4766,7 @@ array bitwise_impl( throw std::runtime_error(msg.str()); } auto inputs = - broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); auto& out_shape = inputs[0].shape(); return array( out_shape, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 5b7bf238ba..807215bb4b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -686,51 +686,51 @@ std::vector BitwiseBinary::vjp( return jvp(primals, cotangents, argnums); } -std::vector Broadcast::vjp( - const std::vector& primals, - const std::vector& cotangents, - const std::vector& argnums, - const std::vector&) { - assert(argnums.size() == 1); - +std::vector +broadcast_vjp(const array& primal, const array& cotan, const Stream& s) { // Reduce cotangents to the shape of the primal - auto& shape = primals[0].shape(); - auto& cotan = cotangents[0]; + auto& shape = primal.shape(); int diff = cotan.ndim() - shape.size(); - std::vector reduce_axes; - for (int i = 0; i < cotan.ndim(); ++i) { - if (i < diff) { - reduce_axes.push_back(i); - } else if (shape[i - diff] != cotan.shape(i)) { + std::vector squeeze_axes(diff); + std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0); + auto reduce_axes = squeeze_axes; + for (int i = diff; i < cotan.ndim(); ++i) { + if (shape[i - diff] != cotan.shape(i)) { reduce_axes.push_back(i); } } - return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())}; + return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)}; +} + +std::vector Broadcast::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector&, + const std::vector&) { + return broadcast_vjp(primals[0], cotangents[0], stream()); } std::vector Broadcast::jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) { - assert(argnums.size() == 1); - return {broadcast_to(tangents[0], shape_, stream())}; + return {array( + shape_, + tangents[0].dtype(), + std::make_shared(stream(), shape_), + tangents)}; } std::pair, std::vector> Broadcast::vmap( const std::vector& inputs, const std::vector& axes) { - assert(inputs.size() == 1); - assert(axes.size() == 1); auto ax = axes[0]; - auto in = inputs[0]; + auto& in = inputs[0]; if (ax >= 0) { - auto in_shape = in.shape(); int diff = shape_.size() - in.ndim() + 1; assert(diff >= 0); - in_shape.insert(in_shape.begin(), diff, 1); + shape_.insert(shape_.begin() + ax + diff, in.shape(ax)); ax += diff; - shape_.insert(shape_.begin() + ax, in_shape[ax]); - in = reshape(in, in_shape, stream()); } return {{broadcast_to(in, shape_, stream())}, {ax}}; } @@ -740,11 +740,76 @@ bool Broadcast::is_equivalent(const Primitive& other) const { return shape_ == b_other.shape_; } -std::vector Broadcast::output_shapes(const std::vector& inputs) { - if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) { - throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape"); +Shape Broadcast::output_shape(const std::vector& inputs) { + auto shape = inputs[0].shape(); + for (int i = 1; i < inputs.size(); ++i) { + shape = broadcast_shapes(shape, inputs[i].shape()); } - return {shape_}; + return shape; +} + +std::vector Broadcast::output_shapes(const std::vector& inputs) { + if (inputs.size() < 2) { + if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) { + throw std::invalid_argument( + "[Broadcast] Unable to infer broadcast shape"); + } + return {shape_}; 
+ } + return {output_shape(inputs)}; +}; + +std::vector BroadcastAxes::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector&, + const std::vector&) { + return broadcast_vjp(primals[0], cotangents[0], stream()); +} + +std::vector BroadcastAxes::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + return {array( + output_shape(primals, ignore_axes_), + tangents[0].dtype(), + std::make_shared(stream(), ignore_axes_), + tangents)}; +} + +std::pair, std::vector> BroadcastAxes::vmap( + const std::vector& inputs, + const std::vector& axes) { + throw std::invalid_argument("[BroadcastAxes] VMAP NYI"); +} + +bool BroadcastAxes::is_equivalent(const Primitive& other) const { + const auto& b_other = static_cast(other); + return ignore_axes_ == b_other.ignore_axes_; +} + +Shape BroadcastAxes::output_shape( + const std::vector& inputs, + const std::vector& ignore_axes) { + auto shape = Shape{}; + for (auto& in : inputs) { + auto in_shape = in.shape(); + for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) { + in_shape.erase(in_shape.begin() + in.ndim() + *it); + } + shape = broadcast_shapes(shape, in_shape); + } + int dims = ignore_axes.size() + shape.size(); + for (auto ax : ignore_axes) { + shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax)); + } + return shape; +} + +std::vector BroadcastAxes::output_shapes( + const std::vector& inputs) { + return {output_shape(inputs, ignore_axes_)}; } std::vector Ceil::vjp( @@ -3066,14 +3131,9 @@ std::vector Reduce::vjp( const std::vector& outputs) { auto in = primals[0]; - auto shape = in.shape(); - for (auto ax : axes_) { - shape[ax] = 1; - } auto& cotan = cotangents[0]; if (reduce_type_ == Reduce::Sum) { - return { - broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())}; + return {broadcast_arrays({cotan, in}, stream())[0]}; } else if (reduce_type_ == Reduce::Prod) { auto s = stream(); auto prod_grad_single_axis = @@ -3129,7 +3189,7 @@ std::vector Reduce::vjp( return {grad}; } else { - return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])}; + return {prod_grad_single_axis(in, cotan, axes_[0])}; } } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) { @@ -3139,9 +3199,7 @@ std::vector Reduce::vjp( } auto mask = equal(in, out, stream()); auto normalizer = sum(mask, axes_, true, stream()); - auto cotan_reshape = reshape(cotan, shape, stream()); - cotan_reshape = divide(cotan_reshape, normalizer, stream()); - return {multiply(cotan_reshape, mask, stream())}; + return {multiply(divide(cotan, normalizer, stream()), mask, stream())}; } else { diff --git a/mlx/primitives.h b/mlx/primitives.h index 97f0c20215..6a66d5f5ae 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -547,6 +547,31 @@ class GatherMM : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class BroadcastAxes : public UnaryPrimitive { + public: + explicit BroadcastAxes(Stream stream, std::vector ignore_axes = {}) + : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_PRINT(BroadcastAxes) + bool is_equivalent(const Primitive& other) const override; + static Shape output_shape( + const std::vector& inputs, + const std::vector& ignore_axes); + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return 
ignore_axes_; + } + + private: + void eval(const std::vector& inputs, array& out); + std::vector ignore_axes_; +}; + class Broadcast : public UnaryPrimitive { public: explicit Broadcast(Stream stream, const Shape& shape) @@ -558,13 +583,13 @@ class Broadcast : public UnaryPrimitive { DEFINE_VMAP() DEFINE_GRADS() DEFINE_PRINT(Broadcast) + static Shape output_shape(const std::vector& inputs); + std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; std::vector state() const { return shape_; }; - std::vector output_shapes(const std::vector& inputs) override; - private: Shape shape_; diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 3c3066e93b..3ac0f8c685 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -31,13 +31,13 @@ class Synchronizer : public Primitive { DEFINE_PRINT(Synchronize); }; -// Initialize the static tracing counter from transforms_impl.h . +// Initialize the static tracing members from transforms_impl.h // -// This is used to implement the in_tracing() function the returns true if we +// These are used to implement the in_tracing() function the returns true if we // are currently under a function transformation and the retain_graph() // function which returns true if we are forced to retain the graph during // evaluation. -int detail::InTracing::tracing_counter{0}; +std::vector detail::InTracing::trace_stack{}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { @@ -434,7 +434,6 @@ std::pair, std::vector> vjp( } } } - std::vector vjps; for (auto& primal : primals_) { if (auto cotan_it = cotan_map.find(primal.id()); @@ -629,7 +628,7 @@ ValueAndGradFn value_and_grad( for (auto arg : args) { ginputs.push_back(inputs[arg]); } - // Set the incoming gradient to int32, vjp will cast it to the output type + // Set the incoming gradient to float32, vjp will cast it to the output type auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)}); return std::make_pair(outputs, grads); }; diff --git a/mlx/transforms_impl.h b/mlx/transforms_impl.h index c83f6795d5..073957e6b4 100644 --- a/mlx/transforms_impl.h +++ b/mlx/transforms_impl.h @@ -20,19 +20,23 @@ std::vector vmap_replace( // of the codebase that we are during tracing so evals should not throw away // the graph. struct InTracing { - InTracing() { - tracing_counter++; + explicit InTracing(bool dynamic = false) { + trace_stack.push_back(dynamic); } ~InTracing() { - tracing_counter--; + trace_stack.pop_back(); } static bool in_tracing() { - return tracing_counter > 0; + return !trace_stack.empty(); + } + static bool in_dynamic_tracing() { + // compile is always and only the outer-most transform + return in_tracing() && trace_stack.front(); } private: - static int tracing_counter; + static std::vector trace_stack; }; struct RetainGraph { @@ -51,4 +55,20 @@ struct RetainGraph { static int tracing_counter; }; +/** Return true if we are currently performing a function transformation in + * order to keep the graph when evaluating tracer arrays. */ +inline bool in_tracing() { + return detail::InTracing::in_tracing(); +} + +/** Return true if we are in a dynamic (shapeless) trace used for compiling or + * exporting graphs with dynamic shapes. 
*/ +inline bool in_dynamic_tracing() { + return detail::InTracing::in_dynamic_tracing(); +} + +inline bool retain_graph() { + return detail::RetainGraph::retain_graph(); +} + } // namespace mlx::core::detail diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 13fa70994c..2848af2a3d 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -91,8 +91,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) { const auto& small = ndim1 > ndim2 ? s2 : s1; Shape out_shape(ndim); for (int i = ndim - 1; i >= diff; --i) { - int a = big[i]; - int b = small[i - diff]; + auto a = big[i]; + auto b = small[i - diff]; if (b == a) { out_shape[i] = a; } else if (a == 1 || b == 1) { @@ -100,7 +100,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) { out_shape[i] = a * b; } else { std::ostringstream msg; - msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast."; + msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2 + << " cannot be broadcast."; throw std::invalid_argument(msg.str()); } } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 0147cd709f..04979ea55b 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2799,6 +2799,27 @@ void init_ops(nb::module_& m) { Returns: array: The output array with the new shape. )pbdoc"); + m.def( + "broadcast_arrays", + [](const nb::args& args, mx::StreamOrDevice s) { + return broadcast_arrays(nb::cast>(args), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def broadcast_arrays(*arrays: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"), + R"pbdoc( + Broadcast arrays against one another. + + The broadcasting semantics are the same as Numpy. + + Args: + *arrays (array): The input arrays. + + Returns: + tuple(array): The output arrays with the broadcasted shape. + )pbdoc"); m.def( "softmax", [](const mx::array& a, @@ -3853,8 +3874,8 @@ void init_ops(nb::module_& m) { Args: file (file, str): Path to file to which the arrays are saved. - args (arrays): Arrays to be saved. - kwargs (arrays): Arrays to be saved. Each array will be saved + *args (arrays): Arrays to be saved. + **kwargs (arrays): Arrays to be saved. Each array will be saved with the associated keyword as the output file name. 
)pbdoc"); m.def( diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 85a6036d74..1dbd052c18 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase): with self.assertRaises(ValueError): compiled_fun(x) + def test_compile_shapeless_with_broadcast(self): + a = mx.array(0.0) + b = mx.ones((2, 2)) + + def fun(a): + return mx.broadcast_to(a, b.shape) + + cfun = mx.compile(fun, shapeless=True) + # Works on the first shape + cfun(a) + + # Fails on a different shape + with self.assertRaises(ValueError): + cfun(mx.array(0.0).reshape(1, 1, 1)) + + def fun(a, b): + return mx.broadcast_arrays(a, b) + + cfun = mx.compile(fun, shapeless=True) + a, b = cfun(a, b) + self.assertEqual(a.shape, (2, 2)) + self.assertEqual(b.shape, (2, 2)) + + # Batched matmul + a = mx.zeros((2, 1, 4, 2)) + b = mx.zeros((3, 2, 5)) + + def fun(a, b): + return a @ b + + cfun = mx.compile(fun, shapeless=True) + out = cfun(a, b) + self.assertEqual(out.shape, (2, 3, 4, 5)) + + # Shapeless compile should be preserved over vjp, jvp, vmap + def fun(args): + return sum(args).sum() + + a = mx.array(0.0) + b = mx.ones((2, 2)) + + cfun = mx.compile(mx.grad(fun), shapeless=True) + out = cfun((a, b)) + + self.assertEqual(out[0].shape, ()) + self.assertEqual(out[1].shape, (2, 2)) + + out = cfun((b, a)) + + self.assertEqual(out[0].shape, (2, 2)) + self.assertEqual(out[1].shape, ()) + + # Shapeless compile should be preserved over vjp, jvp, vmap + def fun(args): + return (args[0] @ args[1]).sum() + + a = mx.zeros((2, 1, 4, 2)) + b = mx.zeros((3, 2, 5)) + + cfun = mx.compile(mx.grad(fun), shapeless=True) + out = cfun((a, b)) + + self.assertEqual(out[0].shape, (2, 1, 4, 2)) + self.assertEqual(out[1].shape, (3, 2, 5)) + + a = mx.zeros((3, 1, 4, 2)) + b = mx.zeros((2, 2, 5)) + + out = cfun((a, b)) + + self.assertEqual(out[0].shape, (3, 1, 4, 2)) + self.assertEqual(out[1].shape, (2, 2, 5)) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 9e72a83c75..899899c4e0 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase): expected[1:, 2:, 3:] = update self.assertTrue(mx.array_equal(expected, out)) + def test_broadcast_arrays(self): + a = mx.array(1) + b = mx.array(1.0) + a, b = mx.broadcast_arrays(a, b) + self.assertEqual(a.shape, ()) + self.assertEqual(a.dtype, mx.int32) + self.assertEqual(b.shape, ()) + self.assertEqual(b.dtype, mx.float32) + + a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1))) + self.assertEqual(a.shape, (3, 4, 2)) + self.assertEqual(b.shape, (3, 4, 2)) + if __name__ == "__main__": unittest.main()