Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-10 03:06:39 +08:00
Dynamic broadcasting for shapeless compile/export (#1722)

* working towards dynamic broadcast
* shapeless broadcast
* fix build + nits
* use broadcast arrays in quantized matmul
* some cleanup / consistency
* mend
* some comments
* add vjp, jvp for broadcast axes
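Note: shapeless compilation previously baked broadcast output shapes into the trace, so a graph compiled for one set of input shapes could fail or silently misbehave on another. With this change, a shapeless trace records broadcasts as dynamic `Broadcast`/`BroadcastAxes` primitives whose output shapes are re-derived from the actual inputs. A minimal usage sketch, mirroring the tests at the bottom of this diff:

```python
import mlx.core as mx

def fun(a, b):
    # Batch dimensions broadcast; the matrix dimensions are left alone.
    return a @ b

cfun = mx.compile(fun, shapeless=True)
out = cfun(mx.zeros((2, 1, 4, 2)), mx.zeros((3, 2, 5)))
print(out.shape)  # (2, 3, 4, 5)
```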
parent ec36bfa317
commit 1ccaf80575
@@ -10,20 +10,6 @@
 namespace mlx::core {
 
-namespace {
-
-/** Return true if we are currently performing a function transformation in
- * order to keep the graph when evaluating tracer arrays. */
-bool in_tracing() {
-  return detail::InTracing::in_tracing();
-}
-
-bool retain_graph() {
-  return detail::RetainGraph::retain_graph();
-}
-
-} // namespace
-
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
     : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
@@ -119,7 +105,8 @@ void array::eval() {
 }
 
 bool array::is_tracer() const {
-  return (array_desc_->is_tracer && in_tracing()) || retain_graph();
+  return (array_desc_->is_tracer && detail::in_tracing()) ||
+      detail::retain_graph();
 }
 
 void array::set_data(allocator::Buffer buffer, Deleter d) {
@@ -32,6 +32,7 @@ DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
 DEFAULT(Broadcast)
+DEFAULT(BroadcastAxes)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
 DEFAULT(Conjugate)
@@ -42,9 +42,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return move_or_copy(in, out, strides_, flags, data_size, offset_);
 }
 
-void Broadcast::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
+void broadcast(const array& in, array& out) {
   if (out.size() == 0) {
     out.set_data(nullptr);
     return;
@@ -61,6 +59,14 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   move_or_copy(in, out, strides, flags, in.data_size());
 }
 
+void Broadcast::eval(const std::vector<array>& inputs, array& out) {
+  broadcast(inputs[0], out);
+}
+
+void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
+  broadcast(inputs[0], out);
+}
+
 void Copy::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   move_or_copy(inputs[0], out);
@@ -37,6 +37,7 @@ DEFAULT(ArgSort)
 DEFAULT(AsType)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
+DEFAULT(BroadcastAxes)
 DEFAULT(BlockMaskedMM)
 DEFAULT(GatherMM)
 DEFAULT(GatherQMM)
@@ -1,6 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <fmt/format.h>
-#include <iostream> //TODO
 #include <sstream>
 
 #include "mlx/backend/common/compiled.h"
@@ -240,6 +240,10 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
+void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
+  eval(inputs, out);
+}
+
 void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
   concatenate_gpu(inputs, out, axis_, stream());
 }
@@ -35,6 +35,7 @@ NO_CPU(AsStrided)
 NO_CPU(BitwiseBinary)
 NO_CPU(BlockMaskedMM)
 NO_CPU(Broadcast)
+NO_CPU(BroadcastAxes)
 NO_CPU(Ceil)
 NO_CPU(Cholesky)
 NO_CPU(Concatenate)
@@ -36,6 +36,7 @@ NO_GPU(AsStrided)
 NO_GPU(BitwiseBinary)
 NO_GPU(BlockMaskedMM)
 NO_GPU(Broadcast)
+NO_GPU(BroadcastAxes)
 NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Concatenate)
@@ -284,9 +284,10 @@ CompilerCache& compiler_cache() {
 
 std::pair<std::vector<array>, std::vector<array>> compile_trace(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    const std::vector<array>& inputs) {
+    const std::vector<array>& inputs,
+    bool shapeless) {
   // Set the global tracing flag.
-  detail::InTracing in_tracing;
+  detail::InTracing in_tracing{shapeless};
 
   // Run the function on placeholder inputs
   // to get compute graph
@@ -824,7 +825,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
   // Set the constants
   entry.constants = std::move(constants);
   // Trace to build the graph
-  std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
+  std::tie(entry.inputs, entry.outputs) =
+      compile_trace(fun, inputs, shapeless);
 
   // DFS the graph and get a tape, and a map of array id to (parent,
   // position in parent inputs)
@@ -27,7 +27,8 @@ bool compile_available_for_device(const Device& device);
 
 std::pair<std::vector<array>, std::vector<array>> compile_trace(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    const std::vector<array>& inputs);
+    const std::vector<array>& inputs,
+    bool shapeless);
 
 using ParentsMap =
     std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
@@ -243,6 +243,7 @@ struct PrimitiveFactory {
         "RightShift"),
     SERIALIZE_PRIMITIVE(BlockMaskedMM),
     SERIALIZE_PRIMITIVE(Broadcast),
+    SERIALIZE_PRIMITIVE(BroadcastAxes),
     SERIALIZE_PRIMITIVE(Ceil),
     SERIALIZE_PRIMITIVE(Concatenate),
     SERIALIZE_PRIMITIVE(Conjugate),
@@ -568,7 +569,8 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
   };
 
   // Trace to build the graph
-  auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
+  auto [trace_inputs, trace_outputs] =
+      detail::compile_trace(flat_fun, inputs, ftable->shapeless);
 
   // DFS the graph and get the tape
   auto [tape, parents_map] =
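Note: the exporter threads the same flag through `ftable->shapeless`, so exported graphs can also defer broadcast shapes. A plausible round trip, assuming the Python `mx.export_function`/`mx.import_function` API and its `shapeless` keyword:

```python
import mlx.core as mx

def fun(a, b):
    return a + b  # the implicit broadcast stays dynamic in the exported graph

mx.export_function("fun.mlxfn", fun, mx.zeros((2, 2)), mx.zeros((2, 2)), shapeless=True)
imported = mx.import_function("fun.mlxfn")
(out,) = imported(mx.zeros((4, 2)), mx.zeros((1, 2)))
print(out.shape)  # (4, 2)
```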
mlx/ops.cpp (261 changed lines)
@@ -14,6 +14,7 @@
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -1399,29 +1400,151 @@ array broadcast_to(
       {a});
 }
 
-std::vector<array>
-broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
-  auto shape = broadcast_shapes(a.shape(), b.shape());
-  return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
-}
+/** Broadcast the input arrays against one another while ignoring the
+ * axes specified in `ignore_axes`. Note, this API is internal only.
+ * The `ignore_axes` should be:
+ * - negative values indicating axes from the end
+ * - sorted in increasing order
+ */
+std::vector<array> broadcast_arrays(
+    const std::vector<array>& inputs,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  if (inputs.size() <= 1) {
+    return inputs;
+  }
+
+  std::vector<array> outputs;
+  auto shape = BroadcastAxes::output_shape(inputs, ignore_axes);
+  auto check_and_get_shape = [&shape, &ignore_axes](const array& in) {
+    auto out_shape = shape;
+    for (int i = 0; i < ignore_axes.size(); ++i) {
+      auto ax = ignore_axes[i];
+      auto pos_ax = in.ndim() + ax;
+      if (pos_ax < 0 || pos_ax > in.ndim() ||
+          (i > 0 && ax <= ignore_axes[i - 1])) {
+        throw std::invalid_argument(
+            "[broadcast_arrays] Received invalid axes to ignore.");
+      }
+      out_shape[out_shape.size() + ax] = in.shape(ax);
+    }
+    return out_shape;
+  };
+
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      auto out_shape = check_and_get_shape(in);
+      if (in.shape() == out_shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            std::move(out_shape),
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), out_shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    auto out_shape = check_and_get_shape(in);
+    if (in.shape() == out_shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          std::move(out_shape),
+          in.dtype(),
+          std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
+          std::move(p_inputs)));
+    }
+  }
+  return outputs;
+}
 
 std::vector<array> broadcast_arrays(
     const std::vector<array>& inputs,
     StreamOrDevice s /* = {} */) {
-  Shape shape{};
-  for (const auto& in : inputs) {
-    shape = broadcast_shapes(shape, in.shape());
+  if (inputs.size() <= 1) {
+    return inputs;
   }
+  auto shape = Broadcast::output_shape(inputs);
   std::vector<array> outputs;
-  for (const auto& in : inputs) {
-    outputs.push_back(broadcast_to(in, shape, s));
+
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      if (in.shape() == shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            shape,
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    if (in.shape() == shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          shape,
+          in.dtype(),
+          std::make_shared<Broadcast>(to_stream(s), shape),
+          std::move(p_inputs)));
+    }
+  }
   return outputs;
 }
+
+std::pair<array, array>
+broadcast_arrays(const array& a, const array& b, StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, s);
+  return {out[0], out[1]};
+}
+
+std::pair<array, array> broadcast_arrays(
+    const array& a,
+    const array& b,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, std::move(ignore_axes), s);
+  return {out[0], out[1]};
+}
 
 array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
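Note: to make the `ignore_axes` rule concrete, here is a small pure-Python sketch of the same shape computation (`broadcast_shapes` and `broadcast_ignoring` are illustrative names, not part of the API): drop the ignored trailing axes from every input, broadcast what remains, then splice the first input's ignored sizes back in.

```python
from itertools import zip_longest

def broadcast_shapes(s1, s2):
    # NumPy-style pairwise broadcast of two shapes
    out = []
    for a, b in zip_longest(reversed(s1), reversed(s2), fillvalue=1):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"shapes {s1} and {s2} cannot be broadcast")
        out.append(max(a, b))
    return tuple(reversed(out))

def broadcast_ignoring(shapes, ignore_axes):
    # ignore_axes: negative and sorted increasing, e.g. (-2, -1) for matmul
    shape = ()
    for s in shapes:
        kept = list(s)
        nd = len(kept)
        for ax in reversed(ignore_axes):  # delete from the back first
            del kept[nd + ax]
        shape = broadcast_shapes(shape, tuple(kept))
    out = list(shape)
    dims = len(ignore_axes) + len(out)
    for ax in ignore_axes:
        out.insert(dims + ax, shapes[0][ax])
    return tuple(out)

# Batched matmul broadcasts batch dims but ignores the matrix dims:
print(broadcast_ignoring([(2, 1, 4, 2), (3, 2, 5)], (-2, -1)))  # (2, 3, 4, 2)
```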
@@ -1429,7 +1552,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1440,7 +1563,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
@@ -1451,7 +1574,7 @@ array greater_equal(
     const array& b,
     StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1462,7 +1585,7 @@ array greater_equal(
 
 array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
@@ -1470,7 +1593,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2277,7 +2400,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
 
 array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2291,7 +2414,7 @@ array operator&&(const array& a, const array& b) {
 
 array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2311,7 +2434,7 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
 array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
@@ -2324,7 +2447,7 @@ array operator+(const array& a, const array& b) {
 array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2340,7 +2463,7 @@ array operator-(const array& a, const array& b) {
 array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2355,8 +2478,8 @@ array operator*(const array& a, const array& b) {
 
 array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2380,7 +2503,7 @@ array floor_divide(
     return floor(divide(a, b, s), s);
   }
 
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2388,8 +2511,8 @@ array floor_divide(
 
 array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2407,8 +2530,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   if (issubdtype(dtype, complexfloating)) {
    throw std::invalid_argument("[divmod] Complex type not supported.");
   }
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   return array::make_arrays(
       {inputs[0].shape(), inputs[0].shape()},
       {inputs[0].dtype(), inputs[0].dtype()},
@@ -2419,7 +2542,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2431,7 +2554,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2514,7 +2637,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
 
 array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
@@ -2610,7 +2733,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Make sure out type is floating point
   auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2710,19 +2833,7 @@ array matmul(
   if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
     a = flatten(a, 0, -2, s);
   } else if (in_b.ndim() > 2) {
-    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
-    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
-
-    // Broadcast a
-    inner_shape.push_back(a.shape(-2));
-    inner_shape.push_back(a.shape(-1));
-    a = broadcast_to(a, inner_shape, s);
-
-    // Broadcast b
-    *(inner_shape.end() - 2) = b.shape(-2);
-    *(inner_shape.end() - 1) = b.shape(-1);
-    b = broadcast_to(b, inner_shape, s);
+    std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s);
   }
 
   auto out_shape = a.shape();
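Note: `broadcast_arrays(a, b, {-2, -1}, s)` replaces the hand-rolled batch broadcasting in `matmul`; the observable behavior is unchanged. From Python, mirroring the test below:

```python
import mlx.core as mx

out = mx.zeros((2, 1, 4, 2)) @ mx.zeros((3, 2, 5))
print(out.shape)  # (2, 3, 4, 5): batch dims (2, 1) and (3,) broadcast to (2, 3)
```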
@@ -3780,29 +3891,6 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  // QuantizedMatmul handles w.ndim == 2 case.
-  if (x.ndim() > 2 && w.ndim() > 2) {
-    Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
-    Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
-
-    // Broadcast x
-    inner_shape.push_back(x.shape(-2));
-    inner_shape.push_back(x.shape(-1));
-    x = broadcast_to(x, inner_shape, s);
-
-    // Broadcast w
-    *(inner_shape.end() - 2) = w.shape(-2);
-    *(inner_shape.end() - 1) = w.shape(-1);
-    w = broadcast_to(w, inner_shape, s);
-
-    *(inner_shape.end() - 1) = scales.shape(-1);
-    scales = broadcast_to(scales, inner_shape, s);
-
-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
-  }
-
   auto dtype = result_type(x, scales, biases);
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -3812,18 +3900,21 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
+  std::vector<array> inputs = {
+      astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)};
 
-  auto out_shape = x.shape();
+  if (x.ndim() > 2 && w.ndim() > 2) {
+    inputs = broadcast_arrays(inputs, {-2, -1}, s);
+  }
+
+  auto out_shape = inputs[0].shape();
   out_shape.back() = w_outer_dims;
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+      std::move(inputs));
 }
 
 std::tuple<array, array, array> quantize(
@@ -3866,13 +3957,11 @@ array gather_qmm(
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
   array rhs_indices = indices_or_default(rhs_indices_, w, s);
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
 
   // Compute the full output shape
-  auto out_shape = out_bsx_shape;
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(x.shape(-2));
   out_shape.push_back(w_outer_dims);
@@ -4374,13 +4463,10 @@ array gather_mm(
   int N = b.shape(-1);
   int K = a.shape(-1);
 
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
-
-  auto out_shape = out_bsx_shape;
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
+
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(M);
   out_shape.push_back(N);
@@ -4640,6 +4726,13 @@ array number_of_elements(
     ax = normal_axis;
   }
 
+  if (!detail::in_dynamic_tracing()) {
+    double numel = 1;
+    for (auto ax : axes) {
+      numel *= a.shape(ax);
+    }
+    return array(inverted ? 1.0 / numel : numel, dtype);
+  }
   return stop_gradient(array(
       Shape{},
       dtype,
@@ -4673,7 +4766,7 @@ array bitwise_impl(
     throw std::runtime_error(msg.str());
   }
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& out_shape = inputs[0].shape();
   return array(
       out_shape,
@@ -686,51 +686,51 @@ std::vector<array> BitwiseBinary::vjp(
   return jvp(primals, cotangents, argnums);
 }
 
-std::vector<array> Broadcast::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
-    const std::vector<int>& argnums,
-    const std::vector<array>&) {
-  assert(argnums.size() == 1);
-
+std::vector<array>
+broadcast_vjp(const array& primal, const array& cotan, const Stream& s) {
   // Reduce cotangents to the shape of the primal
-  auto& shape = primals[0].shape();
-  auto& cotan = cotangents[0];
+  auto& shape = primal.shape();
   int diff = cotan.ndim() - shape.size();
-  std::vector<int> reduce_axes;
-  for (int i = 0; i < cotan.ndim(); ++i) {
-    if (i < diff) {
-      reduce_axes.push_back(i);
-    } else if (shape[i - diff] != cotan.shape(i)) {
+  std::vector<int> squeeze_axes(diff);
+  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
+  auto reduce_axes = squeeze_axes;
+  for (int i = diff; i < cotan.ndim(); ++i) {
+    if (shape[i - diff] != cotan.shape(i)) {
       reduce_axes.push_back(i);
     }
  }
-  return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};
+  return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)};
+}
+
+std::vector<array> Broadcast::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
 }
 
 std::vector<array> Broadcast::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
   assert(argnums.size() == 1);
-  return {broadcast_to(tangents[0], shape_, stream())};
+  return {array(
+      shape_,
+      tangents[0].dtype(),
+      std::make_shared<Broadcast>(stream(), shape_),
+      tangents)};
 }
 
 std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
   auto ax = axes[0];
-  auto in = inputs[0];
+  auto& in = inputs[0];
   if (ax >= 0) {
     auto in_shape = in.shape();
     int diff = shape_.size() - in.ndim() + 1;
     assert(diff >= 0);
     in_shape.insert(in_shape.begin(), diff, 1);
-    shape_.insert(shape_.begin() + ax + diff, in.shape(ax));
     ax += diff;
+    shape_.insert(shape_.begin() + ax, in_shape[ax]);
-    in = reshape(in, in_shape, stream());
   }
   return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
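Note: `broadcast_vjp` sums the cotangent over the broadcast dimensions and squeezes away the prepended ones, so the gradient lands back on the primal's shape. This is observable from Python:

```python
import mlx.core as mx

x = mx.ones((3,))
cotan = mx.ones((2, 3))
# VJP through a broadcast: each primal entry was replicated twice,
# so its gradient is the sum of the two cotangent entries.
_, vjps = mx.vjp(lambda x: mx.broadcast_to(x, (2, 3)), [x], [cotan])
print(vjps[0].shape)  # (3,) with every entry equal to 2.0
```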
@@ -740,11 +740,76 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
   return shape_ == b_other.shape_;
 }
 
-std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
-  if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
-    throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape");
-  }
-  return {shape_};
-}
+Shape Broadcast::output_shape(const std::vector<array>& inputs) {
+  auto shape = inputs[0].shape();
+  for (int i = 1; i < inputs.size(); ++i) {
+    shape = broadcast_shapes(shape, inputs[i].shape());
+  }
+  return shape;
+}
+
+std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
+  if (inputs.size() < 2) {
+    if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
+      throw std::invalid_argument(
+          "[Broadcast] Unable to infer broadcast shape");
+    }
+    return {shape_};
+  }
+  return {output_shape(inputs)};
+};
+
+std::vector<array> BroadcastAxes::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
+}
+
+std::vector<array> BroadcastAxes::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  return {array(
+      output_shape(primals, ignore_axes_),
+      tangents[0].dtype(),
+      std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
+      tangents)};
+}
+
+std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
+}
+
+bool BroadcastAxes::is_equivalent(const Primitive& other) const {
+  const auto& b_other = static_cast<const BroadcastAxes&>(other);
+  return ignore_axes_ == b_other.ignore_axes_;
+}
+
+Shape BroadcastAxes::output_shape(
+    const std::vector<array>& inputs,
+    const std::vector<int>& ignore_axes) {
+  auto shape = Shape{};
+  for (auto& in : inputs) {
+    auto in_shape = in.shape();
+    for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {
+      in_shape.erase(in_shape.begin() + in.ndim() + *it);
+    }
+    shape = broadcast_shapes(shape, in_shape);
+  }
+  int dims = ignore_axes.size() + shape.size();
+  for (auto ax : ignore_axes) {
+    shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax));
+  }
+  return shape;
+}
+
+std::vector<Shape> BroadcastAxes::output_shapes(
+    const std::vector<array>& inputs) {
+  return {output_shape(inputs, ignore_axes_)};
+}
 
 std::vector<array> Ceil::vjp(
@@ -3066,14 +3131,9 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
 
-  auto shape = in.shape();
-  for (auto ax : axes_) {
-    shape[ax] = 1;
-  }
   auto& cotan = cotangents[0];
   if (reduce_type_ == Reduce::Sum) {
-    return {
-        broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())};
+    return {broadcast_arrays({cotan, in}, stream())[0]};
   } else if (reduce_type_ == Reduce::Prod) {
     auto s = stream();
     auto prod_grad_single_axis =
@@ -3129,7 +3189,7 @@ std::vector<array> Reduce::vjp(
 
       return {grad};
     } else {
-      return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])};
+      return {prod_grad_single_axis(in, cotan, axes_[0])};
     }
 
   } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
@@ -3139,9 +3199,7 @@ std::vector<array> Reduce::vjp(
     }
     auto mask = equal(in, out, stream());
     auto normalizer = sum(mask, axes_, true, stream());
-    auto cotan_reshape = reshape(cotan, shape, stream());
-    cotan_reshape = divide(cotan_reshape, normalizer, stream());
-    return {multiply(cotan_reshape, mask, stream())};
+    return {multiply(divide(cotan, normalizer, stream()), mask, stream())};
   }
 
   else {
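Note: the Sum branch now routes through `broadcast_arrays({cotan, in})`, which keeps the cotangent expansion dynamic instead of baking in a static `broadcast_to`. The user-visible gradient is unchanged:

```python
import mlx.core as mx

x = mx.ones((2, 3))
cotan = mx.array([1.0, 2.0])  # matches the output shape of sum(axis=1)
_, vjps = mx.vjp(lambda x: x.sum(axis=1), [x], [cotan])
print(vjps[0])  # [[1, 1, 1], [2, 2, 2]]: each row scaled by its cotangent
```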
@@ -547,6 +547,31 @@ class GatherMM : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class BroadcastAxes : public UnaryPrimitive {
+ public:
+  explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
+      : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(BroadcastAxes)
+  bool is_equivalent(const Primitive& other) const override;
+  static Shape output_shape(
+      const std::vector<array>& inputs,
+      const std::vector<int>& ignore_axes);
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  auto state() const {
+    return ignore_axes_;
+  }
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+  std::vector<int> ignore_axes_;
+};
+
 class Broadcast : public UnaryPrimitive {
  public:
   explicit Broadcast(Stream stream, const Shape& shape)
@@ -558,13 +583,13 @@ class Broadcast : public UnaryPrimitive {
   DEFINE_VMAP()
   DEFINE_GRADS()
   DEFINE_PRINT(Broadcast)
+  static Shape output_shape(const std::vector<array>& inputs);
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   bool is_equivalent(const Primitive& other) const override;
   std::vector<int> state() const {
     return shape_;
   };
 
-  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
-
  private:
   Shape shape_;
@@ -31,13 +31,13 @@ class Synchronizer : public Primitive {
   DEFINE_PRINT(Synchronize);
 };
 
-// Initialize the static tracing counter from transforms_impl.h .
+// Initialize the static tracing members from transforms_impl.h
 //
-// This is used to implement the in_tracing() function the returns true if we
+// These are used to implement the in_tracing() function the returns true if we
 // are currently under a function transformation and the retain_graph()
 // function which returns true if we are forced to retain the graph during
 // evaluation.
-int detail::InTracing::tracing_counter{0};
+std::vector<bool> detail::InTracing::trace_stack{};
 int detail::RetainGraph::tracing_counter{0};
 
 array eval_impl(std::vector<array> outputs, bool async) {
@@ -434,7 +434,6 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
       }
     }
   }
-
   std::vector<array> vjps;
   for (auto& primal : primals_) {
     if (auto cotan_it = cotan_map.find(primal.id());
@@ -629,7 +628,7 @@ ValueAndGradFn value_and_grad(
     for (auto arg : args) {
       ginputs.push_back(inputs[arg]);
     }
-    // Set the incoming gradient to int32, vjp will cast it to the output type
+    // Set the incoming gradient to float32, vjp will cast it to the output type
     auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
     return std::make_pair(outputs, grads);
   };
@@ -20,19 +20,23 @@ std::vector<array> vmap_replace(
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.
 struct InTracing {
-  InTracing() {
-    tracing_counter++;
+  explicit InTracing(bool dynamic = false) {
+    trace_stack.push_back(dynamic);
   }
   ~InTracing() {
-    tracing_counter--;
+    trace_stack.pop_back();
   }
 
   static bool in_tracing() {
-    return tracing_counter > 0;
+    return !trace_stack.empty();
   }
+  static bool in_dynamic_tracing() {
+    // compile is always and only the outer-most transform
+    return in_tracing() && trace_stack.front();
+  }
 
  private:
-  static int tracing_counter;
+  static std::vector<bool> trace_stack;
 };
 
 struct RetainGraph {
@@ -51,4 +55,20 @@ struct RetainGraph {
   static int tracing_counter;
 };
 
+/** Return true if we are currently performing a function transformation in
+ * order to keep the graph when evaluating tracer arrays. */
+inline bool in_tracing() {
+  return detail::InTracing::in_tracing();
+}
+
+/** Return true if we are in a dynamic (shapeless) trace used for compiling or
+ * exporting graphs with dynamic shapes. */
+inline bool in_dynamic_tracing() {
+  return detail::InTracing::in_dynamic_tracing();
+}
+
+inline bool retain_graph() {
+  return detail::RetainGraph::retain_graph();
+}
+
 } // namespace mlx::core::detail
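Note: the counter becomes a stack of flags so that nested transforms (grad, vmap) inherit the dynamic-ness of the outermost compile/export trace; `in_dynamic_tracing` only consults the bottom of the stack. A minimal Python sketch of the same semantics (hypothetical names, not the mlx API):

```python
trace_stack = []

class InTracing:
    # RAII-style scope: push a flag on entry, pop on exit (mirrors the C++)
    def __init__(self, dynamic=False):
        self.dynamic = dynamic
    def __enter__(self):
        trace_stack.append(self.dynamic)
    def __exit__(self, *exc):
        trace_stack.pop()

def in_dynamic_tracing():
    # compile is always and only the outermost transform
    return bool(trace_stack) and trace_stack[0]

with InTracing(dynamic=True):      # outermost: a shapeless compile
    with InTracing():              # nested transform, e.g. grad
        assert in_dynamic_tracing()
```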
@@ -91,8 +91,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
   const auto& small = ndim1 > ndim2 ? s2 : s1;
   Shape out_shape(ndim);
   for (int i = ndim - 1; i >= diff; --i) {
-    int a = big[i];
-    int b = small[i - diff];
+    auto a = big[i];
+    auto b = small[i - diff];
     if (b == a) {
       out_shape[i] = a;
     } else if (a == 1 || b == 1) {
@@ -100,7 +100,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
       out_shape[i] = a * b;
     } else {
       std::ostringstream msg;
-      msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
+      msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
+          << " cannot be broadcast.";
       throw std::invalid_argument(msg.str());
     }
   }
@@ -2799,6 +2799,27 @@ void init_ops(nb::module_& m) {
       Returns:
           array: The output array with the new shape.
   )pbdoc");
+  m.def(
+      "broadcast_arrays",
+      [](const nb::args& args, mx::StreamOrDevice s) {
+        return broadcast_arrays(nb::cast<std::vector<mx::array>>(args), s);
+      },
+      nb::arg(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def broadcast_arrays(*arrays: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"),
+      R"pbdoc(
+        Broadcast arrays against one another.
+
+        The broadcasting semantics are the same as Numpy.
+
+        Args:
+            *arrays (array): The input arrays.
+
+        Returns:
+            tuple(array): The output arrays with the broadcasted shape.
+      )pbdoc");
   m.def(
       "softmax",
       [](const mx::array& a,
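Note: usage of the new binding, taken from the tests further down:

```python
import mlx.core as mx

a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
print(a.shape, b.shape)  # (3, 4, 2) (3, 4, 2)
```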
@@ -3853,8 +3874,8 @@ void init_ops(nb::module_& m) {
 
         Args:
             file (file, str): Path to file to which the arrays are saved.
-            args (arrays): Arrays to be saved.
-            kwargs (arrays): Arrays to be saved. Each array will be saved
+            *args (arrays): Arrays to be saved.
+            **kwargs (arrays): Arrays to be saved. Each array will be saved
                 with the associated keyword as the output file name.
   )pbdoc");
   m.def(
@@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
         with self.assertRaises(ValueError):
             compiled_fun(x)
 
+    def test_compile_shapeless_with_broadcast(self):
+        a = mx.array(0.0)
+        b = mx.ones((2, 2))
+
+        def fun(a):
+            return mx.broadcast_to(a, b.shape)
+
+        cfun = mx.compile(fun, shapeless=True)
+        # Works on the first shape
+        cfun(a)
+
+        # Fails on a different shape
+        with self.assertRaises(ValueError):
+            cfun(mx.array(0.0).reshape(1, 1, 1))
+
+        def fun(a, b):
+            return mx.broadcast_arrays(a, b)
+
+        cfun = mx.compile(fun, shapeless=True)
+        a, b = cfun(a, b)
+        self.assertEqual(a.shape, (2, 2))
+        self.assertEqual(b.shape, (2, 2))
+
+        # Batched matmul
+        a = mx.zeros((2, 1, 4, 2))
+        b = mx.zeros((3, 2, 5))
+
+        def fun(a, b):
+            return a @ b
+
+        cfun = mx.compile(fun, shapeless=True)
+        out = cfun(a, b)
+        self.assertEqual(out.shape, (2, 3, 4, 5))
+
+        # Shapeless compile should be preserved over vjp, jvp, vmap
+        def fun(args):
+            return sum(args).sum()
+
+        a = mx.array(0.0)
+        b = mx.ones((2, 2))
+
+        cfun = mx.compile(mx.grad(fun), shapeless=True)
+        out = cfun((a, b))
+
+        self.assertEqual(out[0].shape, ())
+        self.assertEqual(out[1].shape, (2, 2))
+
+        out = cfun((b, a))
+
+        self.assertEqual(out[0].shape, (2, 2))
+        self.assertEqual(out[1].shape, ())
+
+        # Shapeless compile should be preserved over vjp, jvp, vmap
+        def fun(args):
+            return (args[0] @ args[1]).sum()
+
+        a = mx.zeros((2, 1, 4, 2))
+        b = mx.zeros((3, 2, 5))
+
+        cfun = mx.compile(mx.grad(fun), shapeless=True)
+        out = cfun((a, b))
+
+        self.assertEqual(out[0].shape, (2, 1, 4, 2))
+        self.assertEqual(out[1].shape, (3, 2, 5))
+
+        a = mx.zeros((3, 1, 4, 2))
+        b = mx.zeros((2, 2, 5))
+
+        out = cfun((a, b))
+
+        self.assertEqual(out[0].shape, (3, 1, 4, 2))
+        self.assertEqual(out[1].shape, (2, 2, 5))
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase):
         expected[1:, 2:, 3:] = update
         self.assertTrue(mx.array_equal(expected, out))
 
+    def test_broadcast_arrays(self):
+        a = mx.array(1)
+        b = mx.array(1.0)
+        a, b = mx.broadcast_arrays(a, b)
+        self.assertEqual(a.shape, ())
+        self.assertEqual(a.dtype, mx.int32)
+        self.assertEqual(b.shape, ())
+        self.assertEqual(b.dtype, mx.float32)
+
+        a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
+        self.assertEqual(a.shape, (3, 4, 2))
+        self.assertEqual(b.shape, (3, 4, 2))
+
 
 if __name__ == "__main__":
     unittest.main()