Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes
Awni Hannun 2025-01-09 11:04:24 -08:00 committed by GitHub
parent ec36bfa317
commit 1ccaf80575
20 changed files with 471 additions and 163 deletions
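In short, broadcasting is now recorded symbolically so that shapeless compilation and export work when output shapes depend on the runtime inputs. A minimal sketch of the user-facing behavior, using the Python API exercised by the new tests below (mx.compile(..., shapeless=True) and the new mx.broadcast_arrays):

```python
import mlx.core as mx

def fun(a, b):
    # The broadcast is recorded dynamically rather than baked in as fixed shapes.
    return mx.broadcast_arrays(a, b)

cfun = mx.compile(fun, shapeless=True)
a, b = cfun(mx.array(0.0), mx.ones((2, 2)))
print(a.shape, b.shape)  # (2, 2) (2, 2)
```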

View File

@ -10,20 +10,6 @@
namespace mlx::core {
namespace {
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
return detail::InTracing::in_tracing();
}
bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace
array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
: array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
auto cval = static_cast<complex64_t>(val);
@ -119,7 +105,8 @@ void array::eval() {
}
bool array::is_tracer() const {
return (array_desc_->is_tracer && in_tracing()) || retain_graph();
return (array_desc_->is_tracer && detail::in_tracing()) ||
detail::retain_graph();
}
void array::set_data(allocator::Buffer buffer, Deleter d) {

View File

@ -32,6 +32,7 @@ DEFAULT(ArgSort)
DEFAULT(AsStrided)
DEFAULT(BlockMaskedMM)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(Ceil)
DEFAULT(Concatenate)
DEFAULT(Conjugate)

View File

@ -42,9 +42,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
return move_or_copy(in, out, strides_, flags, data_size, offset_);
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
void broadcast(const array& in, array& out) {
if (out.size() == 0) {
out.set_data(nullptr);
return;
@ -61,6 +59,14 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
move_or_copy(in, out, strides, flags, in.data_size());
}
void Broadcast::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}
void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
broadcast(inputs[0], out);
}
void Copy::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
move_or_copy(inputs[0], out);

View File

@ -37,6 +37,7 @@ DEFAULT(ArgSort)
DEFAULT(AsType)
DEFAULT(AsStrided)
DEFAULT(Broadcast)
DEFAULT(BroadcastAxes)
DEFAULT(BlockMaskedMM)
DEFAULT(GatherMM)
DEFAULT(GatherQMM)

View File

@ -1,6 +1,5 @@
// Copyright © 2023-2024 Apple Inc.
#include <fmt/format.h>
#include <iostream> //TODO
#include <sstream>
#include "mlx/backend/common/compiled.h"

View File

@ -240,6 +240,10 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
eval(inputs, out);
}
void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
concatenate_gpu(inputs, out, axis_, stream());
}

View File

@ -35,6 +35,7 @@ NO_CPU(AsStrided)
NO_CPU(BitwiseBinary)
NO_CPU(BlockMaskedMM)
NO_CPU(Broadcast)
NO_CPU(BroadcastAxes)
NO_CPU(Ceil)
NO_CPU(Cholesky)
NO_CPU(Concatenate)

View File

@ -36,6 +36,7 @@ NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU(Broadcast)
NO_GPU(BroadcastAxes)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)
NO_GPU(Concatenate)

View File

@ -284,9 +284,10 @@ CompilerCache& compiler_cache() {
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs) {
const std::vector<array>& inputs,
bool shapeless) {
// Set the global tracing flag.
detail::InTracing in_tracing;
detail::InTracing in_tracing{shapeless};
// Run the function on placeholder inputs
// to get compute graph
@ -824,7 +825,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
// Set the constants
entry.constants = std::move(constants);
// Trace to build the graph
std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
std::tie(entry.inputs, entry.outputs) =
compile_trace(fun, inputs, shapeless);
// DFS the graph and get a tape, and a map of array id to (parent,
// position in parent inputs)

View File

@ -27,7 +27,8 @@ bool compile_available_for_device(const Device& device);
std::pair<std::vector<array>, std::vector<array>> compile_trace(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& inputs);
const std::vector<array>& inputs,
bool shapeless);
using ParentsMap =
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;

View File

@ -243,6 +243,7 @@ struct PrimitiveFactory {
"RightShift"),
SERIALIZE_PRIMITIVE(BlockMaskedMM),
SERIALIZE_PRIMITIVE(Broadcast),
SERIALIZE_PRIMITIVE(BroadcastAxes),
SERIALIZE_PRIMITIVE(Ceil),
SERIALIZE_PRIMITIVE(Concatenate),
SERIALIZE_PRIMITIVE(Conjugate),
@ -568,7 +569,8 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
};
// Trace to build the graph
auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
auto [trace_inputs, trace_outputs] =
detail::compile_trace(flat_fun, inputs, ftable->shapeless);
// DFS the graph and get the tape
auto [tape, parents_map] =

View File

@ -14,6 +14,7 @@
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
namespace mlx::core {
@ -1399,29 +1400,151 @@ array broadcast_to(
{a});
}
std::vector<array>
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto shape = broadcast_shapes(a.shape(), b.shape());
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
/** Broadcast the input arrays against one another while ignoring the
* axes specified in `ignore_axes`. Note: this API is internal only.
* The `ignore_axes` should be:
* - negative values indicating axes from the end
* - sorted in increasing order
*/
std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
std::vector<int> ignore_axes,
StreamOrDevice s) {
if (inputs.size() <= 1) {
return inputs;
}
std::vector<array> outputs;
auto shape = BroadcastAxes::output_shape(inputs, ignore_axes);
auto check_and_get_shape = [&shape, &ignore_axes](const array& in) {
auto out_shape = shape;
for (int i = 0; i < ignore_axes.size(); ++i) {
auto ax = ignore_axes[i];
auto pos_ax = in.ndim() + ax;
if (pos_ax < 0 || pos_ax > in.ndim() ||
(i > 0 && ax <= ignore_axes[i - 1])) {
throw std::invalid_argument(
"[broadcast_arrays] Received invalid axes to ignore.");
}
out_shape[out_shape.size() + ax] = in.shape(ax);
}
return out_shape;
};
if (!detail::in_dynamic_tracing()) {
for (auto& in : inputs) {
auto out_shape = check_and_get_shape(in);
if (in.shape() == out_shape) {
outputs.push_back(in);
} else {
outputs.push_back(array(
std::move(out_shape),
in.dtype(),
std::make_shared<Broadcast>(to_stream(s), out_shape),
{in}));
}
}
return outputs;
}
std::vector<array> stop_grad_inputs;
for (auto& in : inputs) {
stop_grad_inputs.push_back(stop_gradient(in, s));
}
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
auto out_shape = check_and_get_shape(in);
if (in.shape() == out_shape) {
outputs.push_back(in);
} else {
// The broadcast array goes first, followed by the other stop-gradient inputs
std::vector<array> p_inputs = {in};
for (int j = 0; j < inputs.size(); ++j) {
if (j == i) {
continue;
}
p_inputs.push_back(stop_grad_inputs[j]);
}
outputs.push_back(array(
std::move(out_shape),
in.dtype(),
std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
std::move(p_inputs)));
}
}
return outputs;
}
std::vector<array> broadcast_arrays(
const std::vector<array>& inputs,
StreamOrDevice s /* = {} */) {
Shape shape{};
for (const auto& in : inputs) {
shape = broadcast_shapes(shape, in.shape());
if (inputs.size() <= 1) {
return inputs;
}
auto shape = Broadcast::output_shape(inputs);
std::vector<array> outputs;
for (const auto& in : inputs) {
outputs.push_back(broadcast_to(in, shape, s));
if (!detail::in_dynamic_tracing()) {
for (auto& in : inputs) {
if (in.shape() == shape) {
outputs.push_back(in);
} else {
outputs.push_back(array(
shape,
in.dtype(),
std::make_shared<Broadcast>(to_stream(s), shape),
{in}));
}
}
return outputs;
}
std::vector<array> stop_grad_inputs;
for (auto& in : inputs) {
stop_grad_inputs.push_back(stop_gradient(in, s));
}
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
if (in.shape() == shape) {
outputs.push_back(in);
} else {
// The broadcast array goes first, followed by the other stop-gradient inputs
std::vector<array> p_inputs = {in};
for (int j = 0; j < inputs.size(); ++j) {
if (j == i) {
continue;
}
p_inputs.push_back(stop_grad_inputs[j]);
}
outputs.push_back(array(
shape,
in.dtype(),
std::make_shared<Broadcast>(to_stream(s), shape),
std::move(p_inputs)));
}
}
return outputs;
}
std::pair<array, array>
broadcast_arrays(const array& a, const array& b, StreamOrDevice s) {
auto out = broadcast_arrays({a, b}, s);
return {out[0], out[1]};
}
std::pair<array, array> broadcast_arrays(
const array& a,
const array& b,
std::vector<int> ignore_axes,
StreamOrDevice s) {
auto out = broadcast_arrays({a, b}, std::move(ignore_axes), s);
return {out[0], out[1]};
}
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
@ -1429,7 +1552,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -1440,7 +1563,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
@ -1451,7 +1574,7 @@ array greater_equal(
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -1462,7 +1585,7 @@ array greater_equal(
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
@ -1470,7 +1593,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2277,7 +2400,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Broadcast arrays to a common shape
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2291,7 +2414,7 @@ array operator&&(const array& a, const array& b) {
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Broadcast arrays to a common shape
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2311,7 +2434,7 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
@ -2324,7 +2447,7 @@ array operator+(const array& a, const array& b) {
array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2340,7 +2463,7 @@ array operator-(const array& a, const array& b) {
array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2355,8 +2478,8 @@ array operator*(const array& a, const array& b) {
array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs =
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto& shape = inputs[0].shape();
return array(
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@ -2380,7 +2503,7 @@ array floor_divide(
return floor(divide(a, b, s), s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@ -2388,8 +2511,8 @@ array floor_divide(
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2407,8 +2530,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (issubdtype(dtype, complexfloating)) {
throw std::invalid_argument("[divmod] Complex type not supported.");
}
auto inputs =
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
return array::make_arrays(
{inputs[0].shape(), inputs[0].shape()},
{inputs[0].dtype(), inputs[0].dtype()},
@ -2419,7 +2542,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2431,7 +2554,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2514,7 +2637,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
@ -2610,7 +2733,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Make sure out type is floating point
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& shape = inputs[0].shape();
return array(
shape,
@ -2710,19 +2833,7 @@ array matmul(
if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
a = flatten(a, 0, -2, s);
} else if (in_b.ndim() > 2) {
Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
// Broadcast a
inner_shape.push_back(a.shape(-2));
inner_shape.push_back(a.shape(-1));
a = broadcast_to(a, inner_shape, s);
// Broadcast b
*(inner_shape.end() - 2) = b.shape(-2);
*(inner_shape.end() - 1) = b.shape(-1);
b = broadcast_to(b, inner_shape, s);
std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s);
}
auto out_shape = a.shape();
@ -3780,29 +3891,6 @@ array quantized_matmul(
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
// QuantizedMatmul handles w.ndim == 2 case.
if (x.ndim() > 2 && w.ndim() > 2) {
Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
// Broadcast x
inner_shape.push_back(x.shape(-2));
inner_shape.push_back(x.shape(-1));
x = broadcast_to(x, inner_shape, s);
// Broadcast w
*(inner_shape.end() - 2) = w.shape(-2);
*(inner_shape.end() - 1) = w.shape(-1);
w = broadcast_to(w, inner_shape, s);
*(inner_shape.end() - 1) = scales.shape(-1);
scales = broadcast_to(scales, inner_shape, s);
*(inner_shape.end() - 1) = biases.shape(-1);
biases = broadcast_to(biases, inner_shape, s);
}
auto dtype = result_type(x, scales, biases);
if (!issubdtype(dtype, floating)) {
std::ostringstream msg;
@ -3812,18 +3900,21 @@ array quantized_matmul(
<< " and biases.dtype() == " << biases.dtype();
throw std::invalid_argument(msg.str());
}
std::vector<array> inputs = {
astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)};
auto out_shape = x.shape();
if (x.ndim() > 2 && w.ndim() > 2) {
inputs = broadcast_arrays(inputs, {-2, -1}, s);
}
auto out_shape = inputs[0].shape();
out_shape.back() = w_outer_dims;
return array(
std::move(out_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
{astype(x, dtype, s),
w,
astype(scales, dtype, s),
astype(biases, dtype, s)});
std::move(inputs));
}
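Both matmul and quantized_matmul now route their batch-dimension broadcasting through the internal broadcast_arrays overload that takes ignore_axes, so the trailing matrix axes are left alone while the leading dimensions are broadcast. A usage-level sketch of the resulting behavior under a shapeless compile, mirroring the compile test added below:

```python
import mlx.core as mx

def fun(a, b):
    return a @ b  # batch dims are broadcast; the last two axes are ignored

cfun = mx.compile(fun, shapeless=True)
a = mx.zeros((2, 1, 4, 2))
b = mx.zeros((3, 2, 5))
print(cfun(a, b).shape)  # (2, 3, 4, 5): batch shapes (2, 1) and (3,) broadcast to (2, 3)
```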
std::tuple<array, array, array> quantize(
@ -3866,13 +3957,11 @@ array gather_qmm(
// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);
array rhs_indices = indices_or_default(rhs_indices_, w, s);
auto out_bsx_shape =
broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
std::tie(lhs_indices, rhs_indices) =
broadcast_arrays(lhs_indices, rhs_indices, s);
// Compute the full output shape
auto out_shape = out_bsx_shape;
auto out_shape = lhs_indices.shape();
out_shape.push_back(x.shape(-2));
out_shape.push_back(w_outer_dims);
@ -4374,13 +4463,10 @@ array gather_mm(
int N = b.shape(-1);
int K = a.shape(-1);
auto out_bsx_shape =
broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
std::tie(lhs_indices, rhs_indices) =
broadcast_arrays(lhs_indices, rhs_indices, s);
lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
auto out_shape = out_bsx_shape;
auto out_shape = lhs_indices.shape();
out_shape.push_back(M);
out_shape.push_back(N);
@ -4640,6 +4726,13 @@ array number_of_elements(
ax = normal_axis;
}
if (!detail::in_dynamic_tracing()) {
double numel = 1;
for (auto ax : axes) {
numel *= a.shape(ax);
}
return array(inverted ? 1.0 / numel : numel, dtype);
}
return stop_gradient(array(
Shape{},
dtype,
@ -4673,7 +4766,7 @@ array bitwise_impl(
throw std::runtime_error(msg.str());
}
auto inputs =
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
auto& out_shape = inputs[0].shape();
return array(
out_shape,

View File

@ -686,51 +686,51 @@ std::vector<array> BitwiseBinary::vjp(
return jvp(primals, cotangents, argnums);
}
std::vector<array> Broadcast::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>&) {
assert(argnums.size() == 1);
std::vector<array>
broadcast_vjp(const array& primal, const array& cotan, const Stream& s) {
// Reduce cotangents to the shape of the primal
auto& shape = primals[0].shape();
auto& cotan = cotangents[0];
auto& shape = primal.shape();
int diff = cotan.ndim() - shape.size();
std::vector<int> reduce_axes;
for (int i = 0; i < cotan.ndim(); ++i) {
if (i < diff) {
reduce_axes.push_back(i);
} else if (shape[i - diff] != cotan.shape(i)) {
std::vector<int> squeeze_axes(diff);
std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
auto reduce_axes = squeeze_axes;
for (int i = diff; i < cotan.ndim(); ++i) {
if (shape[i - diff] != cotan.shape(i)) {
reduce_axes.push_back(i);
}
}
return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};
return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)};
}
std::vector<array> Broadcast::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return broadcast_vjp(primals[0], cotangents[0], stream());
}
std::vector<array> Broadcast::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
assert(argnums.size() == 1);
return {broadcast_to(tangents[0], shape_, stream())};
return {array(
shape_,
tangents[0].dtype(),
std::make_shared<Broadcast>(stream(), shape_),
tangents)};
}
std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
assert(inputs.size() == 1);
assert(axes.size() == 1);
auto ax = axes[0];
auto in = inputs[0];
auto& in = inputs[0];
if (ax >= 0) {
auto in_shape = in.shape();
int diff = shape_.size() - in.ndim() + 1;
assert(diff >= 0);
in_shape.insert(in_shape.begin(), diff, 1);
shape_.insert(shape_.begin() + ax + diff, in.shape(ax));
ax += diff;
shape_.insert(shape_.begin() + ax, in_shape[ax]);
in = reshape(in, in_shape, stream());
}
return {{broadcast_to(in, shape_, stream())}, {ax}};
}
@ -740,11 +740,76 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
return shape_ == b_other.shape_;
}
std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape");
Shape Broadcast::output_shape(const std::vector<array>& inputs) {
auto shape = inputs[0].shape();
for (int i = 1; i < inputs.size(); ++i) {
shape = broadcast_shapes(shape, inputs[i].shape());
}
return {shape_};
return shape;
}
std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
if (inputs.size() < 2) {
if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
throw std::invalid_argument(
"[Broadcast] Unable to infer broadcast shape");
}
return {shape_};
}
return {output_shape(inputs)};
};
std::vector<array> BroadcastAxes::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>&,
const std::vector<array>&) {
return broadcast_vjp(primals[0], cotangents[0], stream());
}
std::vector<array> BroadcastAxes::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
return {array(
output_shape(primals, ignore_axes_),
tangents[0].dtype(),
std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
tangents)};
}
std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
}
bool BroadcastAxes::is_equivalent(const Primitive& other) const {
const auto& b_other = static_cast<const BroadcastAxes&>(other);
return ignore_axes_ == b_other.ignore_axes_;
}
Shape BroadcastAxes::output_shape(
const std::vector<array>& inputs,
const std::vector<int>& ignore_axes) {
auto shape = Shape{};
for (auto& in : inputs) {
auto in_shape = in.shape();
for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {
in_shape.erase(in_shape.begin() + in.ndim() + *it);
}
shape = broadcast_shapes(shape, in_shape);
}
int dims = ignore_axes.size() + shape.size();
for (auto ax : ignore_axes) {
shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax));
}
return shape;
}
std::vector<Shape> BroadcastAxes::output_shapes(
const std::vector<array>& inputs) {
return {output_shape(inputs, ignore_axes_)};
}
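The broadcast_vjp helper above sums the cotangent over the broadcast axes and squeezes the prepended ones, so the gradient of a broadcast input comes back in the primal's shape. A small Python sketch of that behavior via mx.grad:

```python
import mlx.core as mx

def fun(a, b):
    # a (a scalar) is broadcast against b inside the addition
    return (a + b).sum()

ga = mx.grad(fun)(mx.array(1.0), mx.ones((2, 3)))  # gradient w.r.t. the first argument
print(ga.shape, ga.item())  # () 6.0: the cotangent is summed over the broadcast axes
```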
std::vector<array> Ceil::vjp(
@ -3066,14 +3131,9 @@ std::vector<array> Reduce::vjp(
const std::vector<array>& outputs) {
auto in = primals[0];
auto shape = in.shape();
for (auto ax : axes_) {
shape[ax] = 1;
}
auto& cotan = cotangents[0];
if (reduce_type_ == Reduce::Sum) {
return {
broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())};
return {broadcast_arrays({cotan, in}, stream())[0]};
} else if (reduce_type_ == Reduce::Prod) {
auto s = stream();
auto prod_grad_single_axis =
@ -3129,7 +3189,7 @@ std::vector<array> Reduce::vjp(
return {grad};
} else {
return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])};
return {prod_grad_single_axis(in, cotan, axes_[0])};
}
} else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
@ -3139,9 +3199,7 @@ std::vector<array> Reduce::vjp(
}
auto mask = equal(in, out, stream());
auto normalizer = sum(mask, axes_, true, stream());
auto cotan_reshape = reshape(cotan, shape, stream());
cotan_reshape = divide(cotan_reshape, normalizer, stream());
return {multiply(cotan_reshape, mask, stream())};
return {multiply(divide(cotan, normalizer, stream()), mask, stream())};
}
else {

View File

@ -547,6 +547,31 @@ class GatherMM : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& out);
};
class BroadcastAxes : public UnaryPrimitive {
public:
explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
: UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(BroadcastAxes)
bool is_equivalent(const Primitive& other) const override;
static Shape output_shape(
const std::vector<array>& inputs,
const std::vector<int>& ignore_axes);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return ignore_axes_;
}
private:
void eval(const std::vector<array>& inputs, array& out);
std::vector<int> ignore_axes_;
};
class Broadcast : public UnaryPrimitive {
public:
explicit Broadcast(Stream stream, const Shape& shape)
@ -558,13 +583,13 @@ class Broadcast : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Broadcast)
static Shape output_shape(const std::vector<array>& inputs);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
return shape_;
};
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
private:
Shape shape_;

View File

@ -31,13 +31,13 @@ class Synchronizer : public Primitive {
DEFINE_PRINT(Synchronize);
};
// Initialize the static tracing counter from transforms_impl.h .
// Initialize the static tracing members from transforms_impl.h
//
// This is used to implement the in_tracing() function the returns true if we
// These are used to implement the in_tracing() function that returns true if we
// are currently under a function transformation and the retain_graph()
// function which returns true if we are forced to retain the graph during
// evaluation.
int detail::InTracing::tracing_counter{0};
std::vector<bool> detail::InTracing::trace_stack{};
int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) {
@ -434,7 +434,6 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
}
}
}
std::vector<array> vjps;
for (auto& primal : primals_) {
if (auto cotan_it = cotan_map.find(primal.id());
@ -629,7 +628,7 @@ ValueAndGradFn value_and_grad(
for (auto arg : args) {
ginputs.push_back(inputs[arg]);
}
// Set the incoming gradient to int32, vjp will cast it to the output type
// Set the incoming gradient to float32, vjp will cast it to the output type
auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
return std::make_pair(outputs, grads);
};

View File

@ -20,19 +20,23 @@ std::vector<array> vmap_replace(
// of the codebase that we are during tracing so evals should not throw away
// the graph.
struct InTracing {
InTracing() {
tracing_counter++;
explicit InTracing(bool dynamic = false) {
trace_stack.push_back(dynamic);
}
~InTracing() {
tracing_counter--;
trace_stack.pop_back();
}
static bool in_tracing() {
return tracing_counter > 0;
return !trace_stack.empty();
}
static bool in_dynamic_tracing() {
// compile is always and only the outer-most transform
return in_tracing() && trace_stack.front();
}
private:
static int tracing_counter;
static std::vector<bool> trace_stack;
};
struct RetainGraph {
@ -51,4 +55,20 @@ struct RetainGraph {
static int tracing_counter;
};
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
inline bool in_tracing() {
return detail::InTracing::in_tracing();
}
/** Return true if we are in a dynamic (shapeless) trace used for compiling or
* exporting graphs with dynamic shapes. */
inline bool in_dynamic_tracing() {
return detail::InTracing::in_dynamic_tracing();
}
inline bool retain_graph() {
return detail::RetainGraph::retain_graph();
}
} // namespace mlx::core::detail
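Because compile is always and only the outermost transform, the dynamic flag it pushes sits at the front of trace_stack, so inner transforms such as grad inherit the shapeless trace. At the user level (mirroring the compile test added below):

```python
import mlx.core as mx

def loss(args):
    return sum(args).sum()

# grad traces inside the shapeless compile, so its trace is dynamic too
cfun = mx.compile(mx.grad(loss), shapeless=True)
g = cfun((mx.array(0.0), mx.ones((2, 2))))
print(g[0].shape, g[1].shape)  # () (2, 2)
```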

View File

@ -91,8 +91,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
const auto& small = ndim1 > ndim2 ? s2 : s1;
Shape out_shape(ndim);
for (int i = ndim - 1; i >= diff; --i) {
int a = big[i];
int b = small[i - diff];
auto a = big[i];
auto b = small[i - diff];
if (b == a) {
out_shape[i] = a;
} else if (a == 1 || b == 1) {
@ -100,7 +100,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
out_shape[i] = a * b;
} else {
std::ostringstream msg;
msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
<< " cannot be broadcast.";
throw std::invalid_argument(msg.str());
}
}
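The [broadcast_shapes] error surfaces in Python as a ValueError when shapes are incompatible; a quick sketch (exact message text aside):

```python
import mlx.core as mx

try:
    mx.broadcast_arrays(mx.zeros((3, 2)), mx.zeros((4, 2)))
except ValueError as e:
    print(e)  # mentions that (3, 2) and (4, 2) cannot be broadcast
```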

View File

@ -2799,6 +2799,27 @@ void init_ops(nb::module_& m) {
Returns:
array: The output array with the new shape.
)pbdoc");
m.def(
"broadcast_arrays",
[](const nb::args& args, mx::StreamOrDevice s) {
return broadcast_arrays(nb::cast<std::vector<mx::array>>(args), s);
},
nb::arg(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def broadcast_arrays(*arrays: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"),
R"pbdoc(
Broadcast arrays against one another.
The broadcasting semantics are the same as NumPy's.
Args:
*arrays (array): The input arrays.
Returns:
tuple(array): The output arrays with the broadcasted shape.
)pbdoc");
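A usage sketch for the new binding, mirroring the op test added below; shapes are broadcast together while dtypes are left unchanged:

```python
import mlx.core as mx

a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
print(a.shape, b.shape)  # (3, 4, 2) (3, 4, 2)

x, y = mx.broadcast_arrays(mx.array(1), mx.array(1.0))
print(x.dtype, y.dtype)  # int32 and float32: no type promotion is performed
```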
m.def(
"softmax",
[](const mx::array& a,
@ -3853,8 +3874,8 @@ void init_ops(nb::module_& m) {
Args:
file (file, str): Path to file to which the arrays are saved.
args (arrays): Arrays to be saved.
kwargs (arrays): Arrays to be saved. Each array will be saved
*args (arrays): Arrays to be saved.
**kwargs (arrays): Arrays to be saved. Each array will be saved
with the associated keyword as the output file name.
)pbdoc");
m.def(

View File

@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
compiled_fun(x)
def test_compile_shapeless_with_broadcast(self):
a = mx.array(0.0)
b = mx.ones((2, 2))
def fun(a):
return mx.broadcast_to(a, b.shape)
cfun = mx.compile(fun, shapeless=True)
# Works on the first shape
cfun(a)
# Fails on a different shape
with self.assertRaises(ValueError):
cfun(mx.array(0.0).reshape(1, 1, 1))
def fun(a, b):
return mx.broadcast_arrays(a, b)
cfun = mx.compile(fun, shapeless=True)
a, b = cfun(a, b)
self.assertEqual(a.shape, (2, 2))
self.assertEqual(b.shape, (2, 2))
# Batched matmul
a = mx.zeros((2, 1, 4, 2))
b = mx.zeros((3, 2, 5))
def fun(a, b):
return a @ b
cfun = mx.compile(fun, shapeless=True)
out = cfun(a, b)
self.assertEqual(out.shape, (2, 3, 4, 5))
# Shapeless compile should be preserved over vjp, jvp, vmap
def fun(args):
return sum(args).sum()
a = mx.array(0.0)
b = mx.ones((2, 2))
cfun = mx.compile(mx.grad(fun), shapeless=True)
out = cfun((a, b))
self.assertEqual(out[0].shape, ())
self.assertEqual(out[1].shape, (2, 2))
out = cfun((b, a))
self.assertEqual(out[0].shape, (2, 2))
self.assertEqual(out[1].shape, ())
# Shapeless compile should be preserved over vjp, jvp, vmap
def fun(args):
return (args[0] @ args[1]).sum()
a = mx.zeros((2, 1, 4, 2))
b = mx.zeros((3, 2, 5))
cfun = mx.compile(mx.grad(fun), shapeless=True)
out = cfun((a, b))
self.assertEqual(out[0].shape, (2, 1, 4, 2))
self.assertEqual(out[1].shape, (3, 2, 5))
a = mx.zeros((3, 1, 4, 2))
b = mx.zeros((2, 2, 5))
out = cfun((a, b))
self.assertEqual(out[0].shape, (3, 1, 4, 2))
self.assertEqual(out[1].shape, (2, 2, 5))
if __name__ == "__main__":
unittest.main()

View File

@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase):
expected[1:, 2:, 3:] = update
self.assertTrue(mx.array_equal(expected, out))
def test_broadcast_arrays(self):
a = mx.array(1)
b = mx.array(1.0)
a, b = mx.broadcast_arrays(a, b)
self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.int32)
self.assertEqual(b.shape, ())
self.assertEqual(b.dtype, mx.float32)
a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
self.assertEqual(a.shape, (3, 4, 2))
self.assertEqual(b.shape, (3, 4, 2))
if __name__ == "__main__":
unittest.main()