Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast
* shapeless broadcast
* fix build + nits
* use broadcast arrays in quantize matmul
* some cleanup / consistency
* mend
* some comments
* add vjp, jvp for broadcast axes
Commit: 1ccaf80575
Parent: ec36bfa317
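The user-visible effect, in brief: broadcasting ops now record graph nodes whose output shapes are recomputed from the inputs at run time, so a single shapeless trace can serve many input shapes. A minimal sketch adapted from the tests added below (assumes an MLX build with this change):

import mlx.core as mx

def fun(a, b):
    return a @ b  # batch dimensions of a and b broadcast against each other

cfun = mx.compile(fun, shapeless=True)
print(cfun(mx.zeros((2, 1, 4, 2)), mx.zeros((3, 2, 5))).shape)  # (2, 3, 4, 5)
print(cfun(mx.zeros((3, 1, 4, 2)), mx.zeros((2, 2, 5))).shape)  # (3, 2, 4, 5), same trace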
@@ -10,20 +10,6 @@
 namespace mlx::core {
 
-namespace {
-
-/** Return true if we are currently performing a function transformation in
- * order to keep the graph when evaluating tracer arrays. */
-bool in_tracing() {
-  return detail::InTracing::in_tracing();
-}
-
-bool retain_graph() {
-  return detail::RetainGraph::retain_graph();
-}
-
-} // namespace
-
 array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
     : array_desc_(std::make_shared<ArrayDesc>(Shape{}, dtype)) {
   auto cval = static_cast<complex64_t>(val);
@@ -119,7 +105,8 @@ void array::eval() {
 }
 
 bool array::is_tracer() const {
-  return (array_desc_->is_tracer && in_tracing()) || retain_graph();
+  return (array_desc_->is_tracer && detail::in_tracing()) ||
+      detail::retain_graph();
 }
 
 void array::set_data(allocator::Buffer buffer, Deleter d) {
@@ -32,6 +32,7 @@ DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
 DEFAULT(Broadcast)
+DEFAULT(BroadcastAxes)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
 DEFAULT(Conjugate)
@@ -42,9 +42,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   return move_or_copy(in, out, strides_, flags, data_size, offset_);
 }
 
-void Broadcast::eval(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  const auto& in = inputs[0];
+void broadcast(const array& in, array& out) {
   if (out.size() == 0) {
     out.set_data(nullptr);
     return;
@@ -61,6 +59,14 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   move_or_copy(in, out, strides, flags, in.data_size());
 }
 
+void Broadcast::eval(const std::vector<array>& inputs, array& out) {
+  broadcast(inputs[0], out);
+}
+
+void BroadcastAxes::eval(const std::vector<array>& inputs, array& out) {
+  broadcast(inputs[0], out);
+}
+
 void Copy::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   move_or_copy(inputs[0], out);
@@ -37,6 +37,7 @@ DEFAULT(ArgSort)
 DEFAULT(AsType)
 DEFAULT(AsStrided)
 DEFAULT(Broadcast)
+DEFAULT(BroadcastAxes)
 DEFAULT(BlockMaskedMM)
 DEFAULT(GatherMM)
 DEFAULT(GatherQMM)
@@ -1,6 +1,5 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <fmt/format.h>
-#include <iostream> //TODO
 #include <sstream>
 
 #include "mlx/backend/common/compiled.h"
@@ -240,6 +240,10 @@ void Broadcast::eval_gpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
 
+void BroadcastAxes::eval_gpu(const std::vector<array>& inputs, array& out) {
+  eval(inputs, out);
+}
+
 void Concatenate::eval_gpu(const std::vector<array>& inputs, array& out) {
   concatenate_gpu(inputs, out, axis_, stream());
 }
@@ -35,6 +35,7 @@ NO_CPU(AsStrided)
 NO_CPU(BitwiseBinary)
 NO_CPU(BlockMaskedMM)
 NO_CPU(Broadcast)
+NO_CPU(BroadcastAxes)
 NO_CPU(Ceil)
 NO_CPU(Cholesky)
 NO_CPU(Concatenate)
@@ -36,6 +36,7 @@ NO_GPU(AsStrided)
 NO_GPU(BitwiseBinary)
 NO_GPU(BlockMaskedMM)
 NO_GPU(Broadcast)
+NO_GPU(BroadcastAxes)
 NO_GPU(Ceil)
 NO_GPU_MULTI(Compiled)
 NO_GPU(Concatenate)
@@ -284,9 +284,10 @@ CompilerCache& compiler_cache() {
 
 std::pair<std::vector<array>, std::vector<array>> compile_trace(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    const std::vector<array>& inputs) {
+    const std::vector<array>& inputs,
+    bool shapeless) {
   // Set the global tracing flag.
-  detail::InTracing in_tracing;
+  detail::InTracing in_tracing{shapeless};
 
   // Run the function on placeholder inputs
   // to get compute graph
@@ -824,7 +825,8 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
     // Set the constants
     entry.constants = std::move(constants);
     // Trace to build the graph
-    std::tie(entry.inputs, entry.outputs) = compile_trace(fun, inputs);
+    std::tie(entry.inputs, entry.outputs) =
+        compile_trace(fun, inputs, shapeless);
 
     // DFS the graph and get a tape, and a map of array id to (parent,
     // position in parent inputs)
@@ -27,7 +27,8 @@ bool compile_available_for_device(const Device& device);
 
 std::pair<std::vector<array>, std::vector<array>> compile_trace(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
-    const std::vector<array>& inputs);
+    const std::vector<array>& inputs,
+    bool shapeless);
 
 using ParentsMap =
     std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
@@ -243,6 +243,7 @@ struct PrimitiveFactory {
           "RightShift"),
       SERIALIZE_PRIMITIVE(BlockMaskedMM),
       SERIALIZE_PRIMITIVE(Broadcast),
+      SERIALIZE_PRIMITIVE(BroadcastAxes),
       SERIALIZE_PRIMITIVE(Ceil),
       SERIALIZE_PRIMITIVE(Concatenate),
       SERIALIZE_PRIMITIVE(Conjugate),
@@ -568,7 +569,8 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
   };
 
   // Trace to build the graph
-  auto [trace_inputs, trace_outputs] = detail::compile_trace(flat_fun, inputs);
+  auto [trace_inputs, trace_outputs] =
+      detail::compile_trace(flat_fun, inputs, ftable->shapeless);
 
   // DFS the graph and get the tape
   auto [tape, parents_map] =
mlx/ops.cpp (261 lines changed)
@@ -14,6 +14,7 @@
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -1399,29 +1400,151 @@ array broadcast_to(
       {a});
 }
 
-std::vector<array>
-broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
-  auto shape = broadcast_shapes(a.shape(), b.shape());
-  return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
+/** Broadcast the input arrays against one another while ignoring the
+ * axes specified in `ignore_axes`. Note, this API is internal only.
+ * The `ignore_axes` should be:
+ * - negative values indicating axes from the end
+ * - sorted in increasing order
+ */
+std::vector<array> broadcast_arrays(
+    const std::vector<array>& inputs,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  if (inputs.size() <= 1) {
+    return inputs;
+  }
+
+  std::vector<array> outputs;
+  auto shape = BroadcastAxes::output_shape(inputs, ignore_axes);
+  auto check_and_get_shape = [&shape, &ignore_axes](const array& in) {
+    auto out_shape = shape;
+    for (int i = 0; i < ignore_axes.size(); ++i) {
+      auto ax = ignore_axes[i];
+      auto pos_ax = in.ndim() + ax;
+      if (pos_ax < 0 || pos_ax > in.ndim() ||
+          (i > 0 && ax <= ignore_axes[i - 1])) {
+        throw std::invalid_argument(
+            "[broadcast_arrays] Received invalid axes to ignore.");
+      }
+      out_shape[out_shape.size() + ax] = in.shape(ax);
+    }
+    return out_shape;
+  };
+
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      auto out_shape = check_and_get_shape(in);
+      if (in.shape() == out_shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            std::move(out_shape),
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), out_shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    auto out_shape = check_and_get_shape(in);
+    if (in.shape() == out_shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          std::move(out_shape),
+          in.dtype(),
+          std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
+          std::move(p_inputs)));
+    }
+  }
+  return outputs;
 }
 
 std::vector<array> broadcast_arrays(
     const std::vector<array>& inputs,
     StreamOrDevice s /* = {} */) {
-  Shape shape{};
-  for (const auto& in : inputs) {
-    shape = broadcast_shapes(shape, in.shape());
+  if (inputs.size() <= 1) {
+    return inputs;
   }
+  auto shape = Broadcast::output_shape(inputs);
   std::vector<array> outputs;
-  for (const auto& in : inputs) {
-    outputs.push_back(broadcast_to(in, shape, s));
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      if (in.shape() == shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            shape,
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    if (in.shape() == shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          shape,
+          in.dtype(),
+          std::make_shared<Broadcast>(to_stream(s), shape),
+          std::move(p_inputs)));
+    }
   }
   return outputs;
 }
 
+std::pair<array, array>
+broadcast_arrays(const array& a, const array& b, StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, s);
+  return {out[0], out[1]};
+}
+
+std::pair<array, array> broadcast_arrays(
+    const array& a,
+    const array& b,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, std::move(ignore_axes), s);
+  return {out[0], out[1]};
+}
+
 array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
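Both overloads now construct Broadcast/BroadcastAxes nodes directly rather than calling broadcast_to per input, and in a dynamic trace each output keeps stop-gradient edges to its sibling inputs so the output shape can be rederived at run time. From the new tests:

import mlx.core as mx

def fun(a, b):
    return mx.broadcast_arrays(a, b)

cfun = mx.compile(fun, shapeless=True)
a, b = cfun(mx.array(0.0), mx.ones((2, 2)))
print(a.shape, b.shape)  # (2, 2) (2, 2)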
@@ -1429,7 +1552,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1440,7 +1563,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
@@ -1451,7 +1574,7 @@ array greater_equal(
     const array& b,
     StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1462,7 +1585,7 @@ array greater_equal(
 
 array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
@@ -1470,7 +1593,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2277,7 +2400,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
 
 array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2291,7 +2414,7 @@ array operator&&(const array& a, const array& b) {
 
 array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2311,7 +2434,7 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
 array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
@@ -2324,7 +2447,7 @@ array operator+(const array& a, const array& b) {
 array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2340,7 +2463,7 @@ array operator-(const array& a, const array& b) {
 array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2355,8 +2478,8 @@ array operator*(const array& a, const array& b) {
 
 array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2380,7 +2503,7 @@ array floor_divide(
     return floor(divide(a, b, s), s);
   }
 
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2388,8 +2511,8 @@ array floor_divide(
 
 array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2407,8 +2530,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   if (issubdtype(dtype, complexfloating)) {
     throw std::invalid_argument("[divmod] Complex type not supported.");
   }
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   return array::make_arrays(
       {inputs[0].shape(), inputs[0].shape()},
       {inputs[0].dtype(), inputs[0].dtype()},
@@ -2419,7 +2542,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2431,7 +2554,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2514,7 +2637,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
 
 array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
@@ -2610,7 +2733,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Make sure out type is floating point
   auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2710,19 +2833,7 @@ array matmul(
   if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
     a = flatten(a, 0, -2, s);
   } else if (in_b.ndim() > 2) {
-    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
-    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
-
-    // Broadcast a
-    inner_shape.push_back(a.shape(-2));
-    inner_shape.push_back(a.shape(-1));
-    a = broadcast_to(a, inner_shape, s);
-
-    // Broadcast b
-    *(inner_shape.end() - 2) = b.shape(-2);
-    *(inner_shape.end() - 1) = b.shape(-1);
-    b = broadcast_to(b, inner_shape, s);
+    std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s);
   }
 
   auto out_shape = a.shape();
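matmul now delegates batch broadcasting to the ignore_axes overload, exempting the two trailing matrix axes. A pure-Python model of that shape rule (the helper name is illustrative, not part of the API):

# Broadcast all but the last two (ignored) axes, as
# broadcast_arrays(a, b, {-2, -1}, s) does for the batch dimensions.
def broadcast_batch_shapes(shape_a, shape_b):
    batch_a, batch_b = shape_a[:-2], shape_b[:-2]
    n = max(len(batch_a), len(batch_b))
    batch_a = (1,) * (n - len(batch_a)) + batch_a
    batch_b = (1,) * (n - len(batch_b)) + batch_b
    out = []
    for da, db in zip(batch_a, batch_b):
        if da != db and 1 not in (da, db):
            raise ValueError("shapes cannot be broadcast")
        out.append(max(da, db))
    return tuple(out)

print(broadcast_batch_shapes((2, 1, 4, 2), (3, 2, 5)))  # (2, 3)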
@@ -3780,29 +3891,6 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  // QuantizedMatmul handles w.ndim == 2 case.
-  if (x.ndim() > 2 && w.ndim() > 2) {
-    Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
-    Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
-
-    // Broadcast x
-    inner_shape.push_back(x.shape(-2));
-    inner_shape.push_back(x.shape(-1));
-    x = broadcast_to(x, inner_shape, s);
-
-    // Broadcast w
-    *(inner_shape.end() - 2) = w.shape(-2);
-    *(inner_shape.end() - 1) = w.shape(-1);
-    w = broadcast_to(w, inner_shape, s);
-
-    *(inner_shape.end() - 1) = scales.shape(-1);
-    scales = broadcast_to(scales, inner_shape, s);
-
-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
-  }
-
   auto dtype = result_type(x, scales, biases);
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -3812,18 +3900,21 @@ array quantized_matmul(
         << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
+  std::vector<array> inputs = {
+      astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)};
 
-  auto out_shape = x.shape();
+  if (x.ndim() > 2 && w.ndim() > 2) {
+    inputs = broadcast_arrays(inputs, {-2, -1}, s);
+  }
+
+  auto out_shape = inputs[0].shape();
   out_shape.back() = w_outer_dims;
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+      std::move(inputs));
 }
 
 std::tuple<array, array, array> quantize(
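quantized_matmul now reuses the same path: x, w, scales, and biases are broadcast together over their leading (batch) dimensions while each keeps its own two trailing axes. A pure-Python sketch with hypothetical shapes (mismatch validation elided):

# Every input ends up with shape batch + its own last two axes.
shapes = {
    "x": (2, 1, 8, 512),
    "w": (4, 128, 512),
    "scales": (4, 128, 8),
    "biases": (4, 128, 8),
}
batch = ()
for shp in shapes.values():
    lead = shp[:-2]
    n = max(len(batch), len(lead))
    batch = tuple(
        max(a, b)
        for a, b in zip((1,) * (n - len(batch)) + batch,
                        (1,) * (n - len(lead)) + lead))
print(batch)  # (2, 4)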
@@ -3866,13 +3957,11 @@ array gather_qmm(
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
   array rhs_indices = indices_or_default(rhs_indices_, w, s);
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
 
   // Compute the full output shape
-  auto out_shape = out_bsx_shape;
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(x.shape(-2));
   out_shape.push_back(w_outer_dims);
 
@@ -4374,13 +4463,10 @@ array gather_mm(
   int N = b.shape(-1);
   int K = a.shape(-1);
 
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
-
-  auto out_shape = out_bsx_shape;
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(M);
   out_shape.push_back(N);
 
@@ -4640,6 +4726,13 @@ array number_of_elements(
       ax = normal_axis;
     }
 
+  if (!detail::in_dynamic_tracing()) {
+    double numel = 1;
+    for (auto ax : axes) {
+      numel *= a.shape(ax);
+    }
+    return array(inverted ? 1.0 / numel : numel, dtype);
+  }
   return stop_gradient(array(
       Shape{},
       dtype,
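Outside a dynamic trace the element count now folds to a constant immediately; inside one it stays a graph node re-evaluated per input shape. That is what keeps size-dependent reductions correct under shapeless compile. A small sketch, assuming mean routes through number_of_elements as it does in MLX:

import mlx.core as mx

cmean = mx.compile(lambda x: x.mean(), shapeless=True)
print(cmean(mx.ones((2, 2))).item())  # 1.0, divides by 4
print(cmean(mx.ones((3, 5))).item())  # 1.0, the same trace divides by 15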
@@ -4673,7 +4766,7 @@ array bitwise_impl(
     throw std::runtime_error(msg.str());
   }
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& out_shape = inputs[0].shape();
   return array(
       out_shape,
@@ -686,51 +686,51 @@ std::vector<array> BitwiseBinary::vjp(
   return jvp(primals, cotangents, argnums);
 }
 
-std::vector<array> Broadcast::vjp(
-    const std::vector<array>& primals,
-    const std::vector<array>& cotangents,
-    const std::vector<int>& argnums,
-    const std::vector<array>&) {
-  assert(argnums.size() == 1);
-
+std::vector<array>
+broadcast_vjp(const array& primal, const array& cotan, const Stream& s) {
   // Reduce cotangents to the shape of the primal
-  auto& shape = primals[0].shape();
-  auto& cotan = cotangents[0];
+  auto& shape = primal.shape();
   int diff = cotan.ndim() - shape.size();
-  std::vector<int> reduce_axes;
-  for (int i = 0; i < cotan.ndim(); ++i) {
-    if (i < diff) {
-      reduce_axes.push_back(i);
-    } else if (shape[i - diff] != cotan.shape(i)) {
+  std::vector<int> squeeze_axes(diff);
+  std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
+  auto reduce_axes = squeeze_axes;
+  for (int i = diff; i < cotan.ndim(); ++i) {
+    if (shape[i - diff] != cotan.shape(i)) {
       reduce_axes.push_back(i);
     }
  }
-  return {reshape(sum(cotan, reduce_axes, true, stream()), shape, stream())};
+  return {squeeze(sum(cotan, reduce_axes, true, s), squeeze_axes, s)};
+}
+
+std::vector<array> Broadcast::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
 }
 
 std::vector<array> Broadcast::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  assert(argnums.size() == 1);
-  return {broadcast_to(tangents[0], shape_, stream())};
+  return {array(
+      shape_,
+      tangents[0].dtype(),
+      std::make_shared<Broadcast>(stream(), shape_),
+      tangents)};
 }
 
 std::pair<std::vector<array>, std::vector<int>> Broadcast::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  assert(inputs.size() == 1);
-  assert(axes.size() == 1);
   auto ax = axes[0];
-  auto in = inputs[0];
+  auto& in = inputs[0];
   if (ax >= 0) {
-    auto in_shape = in.shape();
     int diff = shape_.size() - in.ndim() + 1;
     assert(diff >= 0);
-    in_shape.insert(in_shape.begin(), diff, 1);
+    shape_.insert(shape_.begin() + ax + diff, in.shape(ax));
     ax += diff;
-    shape_.insert(shape_.begin() + ax, in_shape[ax]);
-    in = reshape(in, in_shape, stream());
   }
   return {{broadcast_to(in, shape_, stream())}, {ax}};
 }
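broadcast_vjp replaces the old reshape with an explicit split: leading axes introduced by broadcasting are summed and then squeezed away, while stretched size-1 axes are summed with keepdims. A pure-Python model of the axis bookkeeping:

def broadcast_vjp_axes(primal_shape, cotan_shape):
    # Axes prepended by broadcasting are squeezed; axes that were size 1 in
    # the primal but expanded in the cotangent are summed (keepdims=True).
    diff = len(cotan_shape) - len(primal_shape)
    squeeze_axes = list(range(diff))
    reduce_axes = list(squeeze_axes)
    for i in range(diff, len(cotan_shape)):
        if primal_shape[i - diff] != cotan_shape[i]:
            reduce_axes.append(i)
    return reduce_axes, squeeze_axes

# Gradient of broadcasting (3, 1) -> (2, 3, 4): sum axes 0 and 2, squeeze axis 0.
print(broadcast_vjp_axes((3, 1), (2, 3, 4)))  # ([0, 2], [0])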
@@ -740,11 +740,76 @@ bool Broadcast::is_equivalent(const Primitive& other) const {
   return shape_ == b_other.shape_;
 }
 
-std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
-  if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
-    throw std::invalid_argument("[Broadcast] Unable to infer broadcast shape");
+Shape Broadcast::output_shape(const std::vector<array>& inputs) {
+  auto shape = inputs[0].shape();
+  for (int i = 1; i < inputs.size(); ++i) {
+    shape = broadcast_shapes(shape, inputs[i].shape());
   }
-  return {shape_};
+  return shape;
+}
+
+std::vector<Shape> Broadcast::output_shapes(const std::vector<array>& inputs) {
+  if (inputs.size() < 2) {
+    if (broadcast_shapes(inputs[0].shape(), shape_) != shape_) {
+      throw std::invalid_argument(
+          "[Broadcast] Unable to infer broadcast shape");
+    }
+    return {shape_};
+  }
+  return {output_shape(inputs)};
+};
+
+std::vector<array> BroadcastAxes::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>&,
+    const std::vector<array>&) {
+  return broadcast_vjp(primals[0], cotangents[0], stream());
+}
+
+std::vector<array> BroadcastAxes::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  return {array(
+      output_shape(primals, ignore_axes_),
+      tangents[0].dtype(),
+      std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
+      tangents)};
+}
+
+std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
+}
+
+bool BroadcastAxes::is_equivalent(const Primitive& other) const {
+  const auto& b_other = static_cast<const BroadcastAxes&>(other);
+  return ignore_axes_ == b_other.ignore_axes_;
+}
+
+Shape BroadcastAxes::output_shape(
+    const std::vector<array>& inputs,
+    const std::vector<int>& ignore_axes) {
+  auto shape = Shape{};
+  for (auto& in : inputs) {
+    auto in_shape = in.shape();
+    for (auto it = ignore_axes.rbegin(); it != ignore_axes.rend(); ++it) {
+      in_shape.erase(in_shape.begin() + in.ndim() + *it);
+    }
+    shape = broadcast_shapes(shape, in_shape);
+  }
+  int dims = ignore_axes.size() + shape.size();
+  for (auto ax : ignore_axes) {
+    shape.insert(shape.begin() + dims + ax, inputs[0].shape(ax));
+  }
+  return shape;
+}
+
+std::vector<Shape> BroadcastAxes::output_shapes(
+    const std::vector<array>& inputs) {
+  return {output_shape(inputs, ignore_axes_)};
 }
 
 std::vector<array> Ceil::vjp(
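BroadcastAxes::output_shape broadcasts every input with its ignored axes deleted, then re-inserts those axes with sizes taken from the first input; ignore_axes must be negative and ascending, as the broadcast_arrays comment above requires. A pure-Python rendering with a worked matmul-style example:

def broadcast_axes_output_shape(shapes, ignore_axes):
    out = ()
    for shp in shapes:
        kept = list(shp)
        for ax in reversed(ignore_axes):  # delete from the back first
            del kept[len(shp) + ax]
        n = max(len(out), len(kept))
        a = (1,) * (n - len(out)) + tuple(out)
        b = (1,) * (n - len(kept)) + tuple(kept)
        out = tuple(max(x, y) for x, y in zip(a, b))
    out = list(out)
    dims = len(ignore_axes) + len(out)
    for ax in ignore_axes:  # re-insert, sizes from the first input
        out.insert(dims + ax, shapes[0][ax])
    return tuple(out)

print(broadcast_axes_output_shape([(2, 1, 4, 2), (3, 2, 5)], [-2, -1]))
# (2, 3, 4, 2): batch dims broadcast, trailing axes come from the first input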
@@ -3066,14 +3131,9 @@ std::vector<array> Reduce::vjp(
     const std::vector<array>& outputs) {
   auto in = primals[0];
 
-  auto shape = in.shape();
-  for (auto ax : axes_) {
-    shape[ax] = 1;
-  }
   auto& cotan = cotangents[0];
   if (reduce_type_ == Reduce::Sum) {
-    return {
-        broadcast_to(reshape(cotan, shape, stream()), in.shape(), stream())};
+    return {broadcast_arrays({cotan, in}, stream())[0]};
   } else if (reduce_type_ == Reduce::Prod) {
     auto s = stream();
     auto prod_grad_single_axis =
@@ -3129,7 +3189,7 @@ std::vector<array> Reduce::vjp(
 
       return {grad};
     } else {
-      return {prod_grad_single_axis(in, reshape(cotan, shape, s), axes_[0])};
+      return {prod_grad_single_axis(in, cotan, axes_[0])};
     }
 
   } else if (reduce_type_ == Reduce::Min || reduce_type_ == Reduce::Max) {
@@ -3139,9 +3199,7 @@ std::vector<array> Reduce::vjp(
     }
     auto mask = equal(in, out, stream());
     auto normalizer = sum(mask, axes_, true, stream());
-    auto cotan_reshape = reshape(cotan, shape, stream());
-    cotan_reshape = divide(cotan_reshape, normalizer, stream());
-    return {multiply(cotan_reshape, mask, stream())};
+    return {multiply(divide(cotan, normalizer, stream()), mask, stream())};
   }
 
   else {
@@ -547,6 +547,31 @@ class GatherMM : public UnaryPrimitive {
   void eval(const std::vector<array>& inputs, array& out);
 };
 
+class BroadcastAxes : public UnaryPrimitive {
+ public:
+  explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {})
+      : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
+
+  void eval_cpu(const std::vector<array>& inputs, array& out) override;
+  void eval_gpu(const std::vector<array>& inputs, array& out) override;
+
+  DEFINE_VMAP()
+  DEFINE_GRADS()
+  DEFINE_PRINT(BroadcastAxes)
+  bool is_equivalent(const Primitive& other) const override;
+  static Shape output_shape(
+      const std::vector<array>& inputs,
+      const std::vector<int>& ignore_axes);
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
+  auto state() const {
+    return ignore_axes_;
+  }
+
+ private:
+  void eval(const std::vector<array>& inputs, array& out);
+  std::vector<int> ignore_axes_;
+};
+
 class Broadcast : public UnaryPrimitive {
  public:
   explicit Broadcast(Stream stream, const Shape& shape)
@@ -558,13 +583,13 @@ class Broadcast : public UnaryPrimitive {
   DEFINE_VMAP()
   DEFINE_GRADS()
   DEFINE_PRINT(Broadcast)
+  static Shape output_shape(const std::vector<array>& inputs);
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   bool is_equivalent(const Primitive& other) const override;
   std::vector<int> state() const {
     return shape_;
   };
 
-  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
-
  private:
   Shape shape_;
 
@@ -31,13 +31,13 @@ class Synchronizer : public Primitive {
   DEFINE_PRINT(Synchronize);
 };
 
-// Initialize the static tracing counter from transforms_impl.h .
+// Initialize the static tracing members from transforms_impl.h
 //
-// This is used to implement the in_tracing() function the returns true if we
+// These are used to implement the in_tracing() function that returns true if we
 // are currently under a function transformation and the retain_graph()
 // function which returns true if we are forced to retain the graph during
 // evaluation.
-int detail::InTracing::tracing_counter{0};
+std::vector<bool> detail::InTracing::trace_stack{};
 int detail::RetainGraph::tracing_counter{0};
 
 array eval_impl(std::vector<array> outputs, bool async) {
@@ -434,7 +434,6 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
       }
     }
   }
-
   std::vector<array> vjps;
   for (auto& primal : primals_) {
     if (auto cotan_it = cotan_map.find(primal.id());
@@ -629,7 +628,7 @@ ValueAndGradFn value_and_grad(
     for (auto arg : args) {
      ginputs.push_back(inputs[arg]);
     }
-    // Set the incoming gradient to int32, vjp will cast it to the output type
+    // Set the incoming gradient to float32, vjp will cast it to the output type
     auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
     return std::make_pair(outputs, grads);
   };
@@ -20,19 +20,23 @@ std::vector<array> vmap_replace(
 // of the codebase that we are during tracing so evals should not throw away
 // the graph.
 struct InTracing {
-  InTracing() {
-    tracing_counter++;
+  explicit InTracing(bool dynamic = false) {
+    trace_stack.push_back(dynamic);
   }
   ~InTracing() {
-    tracing_counter--;
+    trace_stack.pop_back();
   }
 
   static bool in_tracing() {
-    return tracing_counter > 0;
+    return !trace_stack.empty();
+  }
+  static bool in_dynamic_tracing() {
+    // compile is always and only the outer-most transform
+    return in_tracing() && trace_stack.front();
   }
 
  private:
-  static int tracing_counter;
+  static std::vector<bool> trace_stack;
 };
 
 struct RetainGraph {
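The counter becomes a stack of flags: every transform pushes one on entry and pops it on exit, and only the outermost frame (compile or export) decides whether the trace is dynamic. A Python model of the RAII behavior, with a context manager standing in for the constructor/destructor pair:

from contextlib import contextmanager

_trace_stack = []

@contextmanager
def in_tracing(dynamic=False):
    _trace_stack.append(dynamic)  # constructor push
    try:
        yield
    finally:
        _trace_stack.pop()  # destructor pop

def in_dynamic_tracing():
    # compile is always and only the outer-most transform
    return bool(_trace_stack) and _trace_stack[0]

with in_tracing(dynamic=True):       # shapeless compile trace
    with in_tracing():               # nested grad trace
        print(in_dynamic_tracing())  # True: the outermost frame is dynamic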
@@ -51,4 +55,20 @@ struct RetainGraph {
   static int tracing_counter;
 };
 
+/** Return true if we are currently performing a function transformation in
+ * order to keep the graph when evaluating tracer arrays. */
+inline bool in_tracing() {
+  return detail::InTracing::in_tracing();
+}
+
+/** Return true if we are in a dynamic (shapeless) trace used for compiling or
+ * exporting graphs with dynamic shapes. */
+inline bool in_dynamic_tracing() {
+  return detail::InTracing::in_dynamic_tracing();
+}
+
+inline bool retain_graph() {
+  return detail::RetainGraph::retain_graph();
+}
+
 } // namespace mlx::core::detail
|
@ -91,8 +91,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
|
|||||||
const auto& small = ndim1 > ndim2 ? s2 : s1;
|
const auto& small = ndim1 > ndim2 ? s2 : s1;
|
||||||
Shape out_shape(ndim);
|
Shape out_shape(ndim);
|
||||||
for (int i = ndim - 1; i >= diff; --i) {
|
for (int i = ndim - 1; i >= diff; --i) {
|
||||||
int a = big[i];
|
auto a = big[i];
|
||||||
int b = small[i - diff];
|
auto b = small[i - diff];
|
||||||
if (b == a) {
|
if (b == a) {
|
||||||
out_shape[i] = a;
|
out_shape[i] = a;
|
||||||
} else if (a == 1 || b == 1) {
|
} else if (a == 1 || b == 1) {
|
||||||
@@ -100,7 +100,8 @@ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
       out_shape[i] = a * b;
     } else {
       std::ostringstream msg;
-      msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast.";
+      msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
+          << " cannot be broadcast.";
       throw std::invalid_argument(msg.str());
     }
   }
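For reference, a pure-Python equivalent of the broadcast_shapes loop above: align the shorter shape to the right, then merge dimension by dimension.

def broadcast_shapes(s1, s2):
    big, small = (s1, s2) if len(s1) >= len(s2) else (s2, s1)
    diff = len(big) - len(small)
    out = list(big[:diff])
    for a, b in zip(big[diff:], small):
        if a == b or b == 1:
            out.append(a)
        elif a == 1:
            out.append(b)
        else:
            raise ValueError(
                f"[broadcast_shapes] Shapes {s1} and {s2} cannot be broadcast.")
    return tuple(out)

print(broadcast_shapes((3, 1, 2), (4, 1)))  # (3, 4, 2)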
|
@ -2799,6 +2799,27 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The output array with the new shape.
|
array: The output array with the new shape.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"broadcast_arrays",
|
||||||
|
[](const nb::args& args, mx::StreamOrDevice s) {
|
||||||
|
return broadcast_arrays(nb::cast<std::vector<mx::array>>(args), s);
|
||||||
|
},
|
||||||
|
nb::arg(),
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def broadcast_arrays(*arrays: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, ...]"),
|
||||||
|
R"pbdoc(
|
||||||
|
Broadcast arrays against one another.
|
||||||
|
|
||||||
|
The broadcasting semantics are the same as Numpy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*arrays (array): The input arrays.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple(array): The output arrays with the broadcasted shape.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"softmax",
|
"softmax",
|
||||||
[](const mx::array& a,
|
[](const mx::array& a,
|
||||||
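Usage of the new binding, matching the test_ops case added below; shapes broadcast while dtypes are untouched:

import mlx.core as mx

a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
print(a.shape, b.shape)  # (3, 4, 2) (3, 4, 2)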
@@ -3853,8 +3874,8 @@ void init_ops(nb::module_& m) {
 
       Args:
           file (file, str): Path to file to which the arrays are saved.
-          args (arrays): Arrays to be saved.
-          kwargs (arrays): Arrays to be saved. Each array will be saved
+          *args (arrays): Arrays to be saved.
+          **kwargs (arrays): Arrays to be saved. Each array will be saved
               with the associated keyword as the output file name.
   )pbdoc");
   m.def(
|
@ -849,6 +849,79 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
compiled_fun(x)
|
compiled_fun(x)
|
||||||
|
|
||||||
|
def test_compile_shapeless_with_broadcast(self):
|
||||||
|
a = mx.array(0.0)
|
||||||
|
b = mx.ones((2, 2))
|
||||||
|
|
||||||
|
def fun(a):
|
||||||
|
return mx.broadcast_to(a, b.shape)
|
||||||
|
|
||||||
|
cfun = mx.compile(fun, shapeless=True)
|
||||||
|
# Works on the first shape
|
||||||
|
cfun(a)
|
||||||
|
|
||||||
|
# Fails on a different shape
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
cfun(mx.array(0.0).reshape(1, 1, 1))
|
||||||
|
|
||||||
|
def fun(a, b):
|
||||||
|
return mx.broadcast_arrays(a, b)
|
||||||
|
|
||||||
|
cfun = mx.compile(fun, shapeless=True)
|
||||||
|
a, b = cfun(a, b)
|
||||||
|
self.assertEqual(a.shape, (2, 2))
|
||||||
|
self.assertEqual(b.shape, (2, 2))
|
||||||
|
|
||||||
|
# Batched matmul
|
||||||
|
a = mx.zeros((2, 1, 4, 2))
|
||||||
|
b = mx.zeros((3, 2, 5))
|
||||||
|
|
||||||
|
def fun(a, b):
|
||||||
|
return a @ b
|
||||||
|
|
||||||
|
cfun = mx.compile(fun, shapeless=True)
|
||||||
|
out = cfun(a, b)
|
||||||
|
self.assertEqual(out.shape, (2, 3, 4, 5))
|
||||||
|
|
||||||
|
# Shapeless compile should be preserved over vjp, jvp, vmap
|
||||||
|
def fun(args):
|
||||||
|
return sum(args).sum()
|
||||||
|
|
||||||
|
a = mx.array(0.0)
|
||||||
|
b = mx.ones((2, 2))
|
||||||
|
|
||||||
|
cfun = mx.compile(mx.grad(fun), shapeless=True)
|
||||||
|
out = cfun((a, b))
|
||||||
|
|
||||||
|
self.assertEqual(out[0].shape, ())
|
||||||
|
self.assertEqual(out[1].shape, (2, 2))
|
||||||
|
|
||||||
|
out = cfun((b, a))
|
||||||
|
|
||||||
|
self.assertEqual(out[0].shape, (2, 2))
|
||||||
|
self.assertEqual(out[1].shape, ())
|
||||||
|
|
||||||
|
# Shapeless compile should be preserved over vjp, jvp, vmap
|
||||||
|
def fun(args):
|
||||||
|
return (args[0] @ args[1]).sum()
|
||||||
|
|
||||||
|
a = mx.zeros((2, 1, 4, 2))
|
||||||
|
b = mx.zeros((3, 2, 5))
|
||||||
|
|
||||||
|
cfun = mx.compile(mx.grad(fun), shapeless=True)
|
||||||
|
out = cfun((a, b))
|
||||||
|
|
||||||
|
self.assertEqual(out[0].shape, (2, 1, 4, 2))
|
||||||
|
self.assertEqual(out[1].shape, (3, 2, 5))
|
||||||
|
|
||||||
|
a = mx.zeros((3, 1, 4, 2))
|
||||||
|
b = mx.zeros((2, 2, 5))
|
||||||
|
|
||||||
|
out = cfun((a, b))
|
||||||
|
|
||||||
|
self.assertEqual(out[0].shape, (3, 1, 4, 2))
|
||||||
|
self.assertEqual(out[1].shape, (2, 2, 5))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@@ -2782,6 +2782,19 @@ class TestOps(mlx_tests.MLXTestCase):
         expected[1:, 2:, 3:] = update
         self.assertTrue(mx.array_equal(expected, out))
 
+    def test_broadcast_arrays(self):
+        a = mx.array(1)
+        b = mx.array(1.0)
+        a, b = mx.broadcast_arrays(a, b)
+        self.assertEqual(a.shape, ())
+        self.assertEqual(a.dtype, mx.int32)
+        self.assertEqual(b.shape, ())
+        self.assertEqual(b.dtype, mx.float32)
+
+        a, b = mx.broadcast_arrays(mx.zeros((3, 1, 2)), mx.zeros((4, 1)))
+        self.assertEqual(a.shape, (3, 4, 2))
+        self.assertEqual(b.shape, (3, 4, 2))
+
 
 if __name__ == "__main__":
     unittest.main()