Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast

* shapeless broadcast

* fix build + nits

* use broadcast arrays in quantize matmul

* some cleanup / consistency

* mend

* some comments

* add vjp, jvp for broadcast axes

Awni Hannun authored on 2025-01-09 11:04:24 -08:00; committed by GitHub.
Parent: ec36bfa317 · Commit: 1ccaf80575
20 changed files with 471 additions and 163 deletions
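
For context: with shapeless compile/export, a traced graph is reused for inputs of new shapes, so broadcasts can no longer be baked in as fixed-shape ops and must re-derive their output shapes at run time. A minimal sketch of the scenario this enables (not part of this commit; it assumes the C++ `compile(fun, shapeless)` overload from "mlx/compile.h"):

    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      auto fun = [](const std::vector<mx::array>& inputs) {
        // add() broadcasts its operands; under shapeless tracing the
        // broadcast becomes a node that recomputes its output shape.
        return std::vector<mx::array>{mx::add(inputs[0], inputs[1])};
      };
      auto cfun = mx::compile(fun, /* shapeless = */ true);

      // Traced once with shapes (2, 1) and (1, 3) -> output (2, 3).
      auto out1 = cfun({mx::ones({2, 1}), mx::ones({1, 3})})[0];

      // Same compiled graph, new shapes (4, 1) and (1, 5) -> output (4, 5).
      auto out2 = cfun({mx::ones({4, 1}), mx::ones({1, 5})})[0];
    }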

mlx/ops.cpp

@@ -14,6 +14,7 @@
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -1399,29 +1400,151 @@ array broadcast_to(
       {a});
 }
 
-std::vector<array>
-broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
-  auto shape = broadcast_shapes(a.shape(), b.shape());
-  return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
+/** Broadcast the input arrays against one another while ignoring the
+ * axes specified in `ignore_axes`. Note, this API is internal only.
+ * The `ignore_axes` should be:
+ * - negative values indicating axes from the end
+ * - sorted in increasing order
+ */
+std::vector<array> broadcast_arrays(
+    const std::vector<array>& inputs,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  if (inputs.size() <= 1) {
+    return inputs;
+  }
+  std::vector<array> outputs;
+  auto shape = BroadcastAxes::output_shape(inputs, ignore_axes);
+  auto check_and_get_shape = [&shape, &ignore_axes](const array& in) {
+    auto out_shape = shape;
+    for (int i = 0; i < ignore_axes.size(); ++i) {
+      auto ax = ignore_axes[i];
+      auto pos_ax = in.ndim() + ax;
+      if (pos_ax < 0 || pos_ax > in.ndim() ||
+          (i > 0 && ax <= ignore_axes[i - 1])) {
+        throw std::invalid_argument(
+            "[broadcast_arrays] Received invalid axes to ignore.");
+      }
+      out_shape[out_shape.size() + ax] = in.shape(ax);
+    }
+    return out_shape;
+  };
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      auto out_shape = check_and_get_shape(in);
+      if (in.shape() == out_shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            std::move(out_shape),
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), out_shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    auto out_shape = check_and_get_shape(in);
+    if (in.shape() == out_shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          std::move(out_shape),
+          in.dtype(),
+          std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
+          std::move(p_inputs)));
+    }
+  }
+  return outputs;
 }
 
 std::vector<array> broadcast_arrays(
     const std::vector<array>& inputs,
     StreamOrDevice s /* = {} */) {
-  Shape shape{};
-  for (const auto& in : inputs) {
-    shape = broadcast_shapes(shape, in.shape());
+  if (inputs.size() <= 1) {
+    return inputs;
   }
+  auto shape = Broadcast::output_shape(inputs);
   std::vector<array> outputs;
-  for (const auto& in : inputs) {
-    outputs.push_back(broadcast_to(in, shape, s));
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      if (in.shape() == shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            shape,
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), shape),
+            {in}));
+      }
+    }
+    return outputs;
   }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    if (in.shape() == shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          shape,
+          in.dtype(),
+          std::make_shared<Broadcast>(to_stream(s), shape),
+          std::move(p_inputs)));
+    }
+  }
   return outputs;
 }
 
+std::pair<array, array>
+broadcast_arrays(const array& a, const array& b, StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, s);
+  return {out[0], out[1]};
+}
+
+std::pair<array, array> broadcast_arrays(
+    const array& a,
+    const array& b,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, std::move(ignore_axes), s);
+  return {out[0], out[1]};
+}
+
 array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
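
Two notes on the new functions above. First, when `detail::in_dynamic_tracing()` is true, each broadcast output is built from all of the inputs, with the extra inputs wrapped in `stop_gradient`; this lets the `Broadcast`/`BroadcastAxes` primitive re-derive its output shape from the runtime shapes of every input, while vjp/jvp still flow only to the array actually being broadcast. Second, the `ignore_axes` variant broadcasts only the non-ignored leading axes. A usage sketch (not from this diff; the ignore-axes overload is internal per its doc comment, and the shapes are illustrative):

    #include "mlx/mlx.h"

    namespace mx = mlx::core;

    int main() {
      auto x = mx::zeros({8, 1, 4, 5}); // batch (8, 1), matrix (4, 5)
      auto y = mx::zeros({2, 5, 6});    // batch (2),    matrix (5, 6)

      // Ignoring the last two axes, only the batch axes are broadcast:
      // x -> (8, 2, 4, 5) and y -> (8, 2, 5, 6); each array keeps its own
      // trailing (ignored) dimensions.
      auto [xb, yb] = mx::broadcast_arrays(x, y, {-2, -1}, {});
    }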
@@ -1429,7 +1552,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1440,7 +1563,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
@@ -1451,7 +1574,7 @@ array greater_equal(
     const array& b,
     StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1462,7 +1585,7 @@ array greater_equal(
 
 array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
@@ -1470,7 +1593,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2277,7 +2400,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
 
 array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2291,7 +2414,7 @@ array operator&&(const array& a, const array& b) {
 
 array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2311,7 +2434,7 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
 array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
@@ -2324,7 +2447,7 @@ array operator+(const array& a, const array& b) {
 array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2340,7 +2463,7 @@ array operator-(const array& a, const array& b) {
 array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2355,8 +2478,8 @@ array operator*(const array& a, const array& b) {
 
 array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2380,7 +2503,7 @@ array floor_divide(
     return floor(divide(a, b, s), s);
   }
 
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2388,8 +2511,8 @@ array floor_divide(
 
 array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2407,8 +2530,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   if (issubdtype(dtype, complexfloating)) {
     throw std::invalid_argument("[divmod] Complex type not supported.");
   }
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   return array::make_arrays(
       {inputs[0].shape(), inputs[0].shape()},
       {inputs[0].dtype(), inputs[0].dtype()},
@@ -2419,7 +2542,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2431,7 +2554,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2514,7 +2637,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
 
 array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
@@ -2610,7 +2733,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Make sure out type is floating point
   auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2710,19 +2833,7 @@ array matmul(
   if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
     a = flatten(a, 0, -2, s);
   } else if (in_b.ndim() > 2) {
-    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
-    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
-
-    // Broadcast a
-    inner_shape.push_back(a.shape(-2));
-    inner_shape.push_back(a.shape(-1));
-    a = broadcast_to(a, inner_shape, s);
-
-    // Broadcast b
-    *(inner_shape.end() - 2) = b.shape(-2);
-    *(inner_shape.end() - 1) = b.shape(-1);
-    b = broadcast_to(b, inner_shape, s);
+    std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s);
   }
 
   auto out_shape = a.shape();
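
The replacement is behavior-preserving for batched matmul — batch axes broadcast, the two matrix axes are ignored — but it now routes through a primitive that also works under shapeless tracing. An illustrative case (hypothetical shapes, reusing the `mx` alias from the sketches above):

    auto a = mx::ones({7, 1, 3, 4}); // batch (7, 1), matrix (3, 4)
    auto b = mx::ones({5, 4, 2});    // batch (5),    matrix (4, 2)
    auto c = mx::matmul(a, b);       // batch broadcasts to (7, 5) -> shape (7, 5, 3, 2)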
@@ -3780,29 +3891,6 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  // QuantizedMatmul handles w.ndim == 2 case.
-  if (x.ndim() > 2 && w.ndim() > 2) {
-    Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
-    Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
-
-    // Broadcast x
-    inner_shape.push_back(x.shape(-2));
-    inner_shape.push_back(x.shape(-1));
-    x = broadcast_to(x, inner_shape, s);
-
-    // Broadcast w
-    *(inner_shape.end() - 2) = w.shape(-2);
-    *(inner_shape.end() - 1) = w.shape(-1);
-    w = broadcast_to(w, inner_shape, s);
-
-    *(inner_shape.end() - 1) = scales.shape(-1);
-    scales = broadcast_to(scales, inner_shape, s);
-
-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
-  }
-
   auto dtype = result_type(x, scales, biases);
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -3812,18 +3900,21 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
 
-  auto out_shape = x.shape();
+  std::vector<array> inputs = {
+      astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)};
+  if (x.ndim() > 2 && w.ndim() > 2) {
+    inputs = broadcast_arrays(inputs, {-2, -1}, s);
+  }
+
+  auto out_shape = inputs[0].shape();
   out_shape.back() = w_outer_dims;
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+      std::move(inputs));
 }
 
 std::tuple<array, array, array> quantize(
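
A plain broadcast cannot work here because the four inputs agree only on their batch axes. A hypothetical shape walkthrough (assuming 4-bit weights, group_size = 64, and a transposed w), showing why the last two axes must be ignored:

    // x:      (B, M, K)       activations
    // w:      (B, N, K / 8)   4-bit weights packed eight per uint32
    // scales: (B, N, K / 64)  one scale per group of 64 values
    // biases: (B, N, K / 64)  one bias per group of 64 values
    //
    // broadcast_arrays(inputs, {-2, -1}, s) broadcasts only the leading
    // batch axes; every input keeps its own last two dimensions.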
@@ -3866,13 +3957,11 @@ array gather_qmm(
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
   array rhs_indices = indices_or_default(rhs_indices_, w, s);
 
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
 
   // Compute the full output shape
-  auto out_shape = out_bsx_shape;
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(x.shape(-2));
   out_shape.push_back(w_outer_dims);
@@ -4374,13 +4463,10 @@ array gather_mm(
   int M = a.shape(-2);
   int N = b.shape(-1);
   int K = a.shape(-1);
 
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
 
-  auto out_shape = out_bsx_shape;
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(M);
   out_shape.push_back(N);
 
@@ -4640,6 +4726,13 @@ array number_of_elements(
     ax = normal_axis;
   }
 
+  if (!detail::in_dynamic_tracing()) {
+    double numel = 1;
+    for (auto ax : axes) {
+      numel *= a.shape(ax);
+    }
+    return array(inverted ? 1.0 / numel : numel, dtype);
+  }
   return stop_gradient(array(
       Shape{},
       dtype,
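
The early return folds the element count to a plain constant whenever shapes are static; only under dynamic (shapeless) tracing does it remain a symbolic NumberOfElements node. This matters for ops like `mean`, which divide by the element count — a sketch (reusing the `mx` alias from above):

    auto x = mx::ones({4, 5});
    // With static shapes the divisor folds to the constant 1/20 at trace
    // time; under shapeless tracing it stays symbolic, so replaying the
    // exported graph on a (2, 3) input divides by 6 instead.
    auto m = mx::mean(x);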
@@ -4673,7 +4766,7 @@ array bitwise_impl(
     throw std::runtime_error(msg.str());
   }
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& out_shape = inputs[0].shape();
   return array(
       out_shape,