mirror of https://github.com/ml-explore/mlx.git
Dynamic broadcasting for shapeless compile/export (#1722)
* working towards dynamic broadcast
* shapeless broadcast
* fix build + nits
* use broadcast arrays in quantize matmul
* some cleanup / consistency
* mend
* some comments
* add vjp, jvp for broadcast axes
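What this buys in practice: broadcasting ops no longer bake the broadcast shape into the trace, so a function compiled (or exported) as shapeless can be replayed on inputs of different shapes. A minimal sketch of the user-visible effect, assuming the C++ `compile(fun, shapeless)` entry point; shapes are illustrative:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {
    // add() broadcasts its inputs; with a shapeless trace the broadcast is
    // recorded dynamically and re-resolved from the runtime shapes.
    return {add(inputs[0], inputs[1])};
  };
  auto cfun = compile(fun, /* shapeless = */ true);

  eval(cfun({ones({2, 2}), ones({2})}));  // traced once at (2, 2) + (2,)
  eval(cfun({ones({4, 4}), ones({4})}));  // same trace, new shapes
  return 0;
}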
mlx/ops.cpp (261 changed lines)
@@ -14,6 +14,7 @@
 #include "mlx/ops.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
+#include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
 namespace mlx::core {
@@ -1399,29 +1400,151 @@ array broadcast_to(
       {a});
 }
 
-std::vector<array>
-broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
-  auto shape = broadcast_shapes(a.shape(), b.shape());
-  return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
-}
+/** Broadcast the input arrays against one another while ignoring the
+ * axes specified in `ignore_axes`. Note, this API is internal only.
+ * The `ignore_axes` should be:
+ * - negative values indicating axes from the end
+ * - sorted in increasing order
+ */
+std::vector<array> broadcast_arrays(
+    const std::vector<array>& inputs,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  if (inputs.size() <= 1) {
+    return inputs;
+  }
+
+  std::vector<array> outputs;
+  auto shape = BroadcastAxes::output_shape(inputs, ignore_axes);
+  auto check_and_get_shape = [&shape, &ignore_axes](const array& in) {
+    auto out_shape = shape;
+    for (int i = 0; i < ignore_axes.size(); ++i) {
+      auto ax = ignore_axes[i];
+      auto pos_ax = in.ndim() + ax;
+      if (pos_ax < 0 || pos_ax > in.ndim() ||
+          (i > 0 && ax <= ignore_axes[i - 1])) {
+        throw std::invalid_argument(
+            "[broadcast_arrays] Received invalid axes to ignore.");
+      }
+      out_shape[out_shape.size() + ax] = in.shape(ax);
+    }
+    return out_shape;
+  };
+
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      auto out_shape = check_and_get_shape(in);
+      if (in.shape() == out_shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            std::move(out_shape),
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), out_shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    auto out_shape = check_and_get_shape(in);
+    if (in.shape() == out_shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          std::move(out_shape),
+          in.dtype(),
+          std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
+          std::move(p_inputs)));
+    }
+  }
+  return outputs;
+}
 
 std::vector<array> broadcast_arrays(
     const std::vector<array>& inputs,
     StreamOrDevice s /* = {} */) {
-  Shape shape{};
-  for (const auto& in : inputs) {
-    shape = broadcast_shapes(shape, in.shape());
-  }
-  std::vector<array> outputs;
-  for (const auto& in : inputs) {
-    outputs.push_back(broadcast_to(in, shape, s));
-  }
+  if (inputs.size() <= 1) {
+    return inputs;
+  }
+  auto shape = Broadcast::output_shape(inputs);
+  std::vector<array> outputs;
+
+  if (!detail::in_dynamic_tracing()) {
+    for (auto& in : inputs) {
+      if (in.shape() == shape) {
+        outputs.push_back(in);
+      } else {
+        outputs.push_back(array(
+            shape,
+            in.dtype(),
+            std::make_shared<Broadcast>(to_stream(s), shape),
+            {in}));
+      }
+    }
+    return outputs;
+  }
+
+  std::vector<array> stop_grad_inputs;
+  for (auto& in : inputs) {
+    stop_grad_inputs.push_back(stop_gradient(in, s));
+  }
+  for (int i = 0; i < inputs.size(); ++i) {
+    auto& in = inputs[i];
+    if (in.shape() == shape) {
+      outputs.push_back(in);
+    } else {
+      // broadcasted array goes first followed by other stopgrad inputs
+      std::vector<array> p_inputs = {in};
+      for (int j = 0; j < inputs.size(); ++j) {
+        if (j == i) {
+          continue;
+        }
+        p_inputs.push_back(stop_grad_inputs[j]);
+      }
+      outputs.push_back(array(
+          shape,
+          in.dtype(),
+          std::make_shared<Broadcast>(to_stream(s), shape),
+          std::move(p_inputs)));
+    }
+  }
   return outputs;
 }
+
+std::pair<array, array>
+broadcast_arrays(const array& a, const array& b, StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, s);
+  return {out[0], out[1]};
+}
+
+std::pair<array, array> broadcast_arrays(
+    const array& a,
+    const array& b,
+    std::vector<int> ignore_axes,
+    StreamOrDevice s) {
+  auto out = broadcast_arrays({a, b}, std::move(ignore_axes), s);
+  return {out[0], out[1]};
+}
 
 array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
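A note on the two paths in the new `broadcast_arrays` overloads above: outside dynamic tracing the output shape is known, so each input gets a plain `Broadcast` with the shape baked in. Under dynamic tracing, the broadcast shape must stay recomputable from the runtime inputs, so each broadcasted array also takes the other inputs as shape-only dependencies (wrapped in `stop_gradient` so no spurious gradients flow), and `BroadcastAxes` records which trailing axes are excluded. A hedged sketch of the `ignore_axes` semantics (this overload is internal API; shapes are illustrative):

#include <tuple>

#include "mlx/mlx.h"

using namespace mlx::core;

void example() {
  array x = zeros({3, 1, 4, 5});  // batch dims (3, 1), matrix dims (4, 5)
  array w = zeros({1, 7, 5, 2});  // batch dims (1, 7), matrix dims (5, 2)

  // Broadcast only the batch dimensions; the last two axes are ignored:
  // x -> (3, 7, 4, 5), w -> (3, 7, 5, 2).
  std::tie(x, w) = broadcast_arrays(x, w, {-2, -1}, {});
}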
@@ -1429,7 +1552,7 @@ array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1440,7 +1563,7 @@ array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
@@ -1451,7 +1574,7 @@ array greater_equal(
     const array& b,
     StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -1462,7 +1585,7 @@ array greater_equal(
 
 array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
@@ -1470,7 +1593,7 @@ array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 
 array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2277,7 +2400,7 @@ array logical_not(const array& a, StreamOrDevice s /* = {} */) {
 
 array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2291,7 +2414,7 @@ array operator&&(const array& a, const array& b) {
 
 array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2311,7 +2434,7 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
 array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
@@ -2324,7 +2447,7 @@ array operator+(const array& a, const array& b) {
 array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2340,7 +2463,7 @@ array operator-(const array& a, const array& b) {
 array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2355,8 +2478,8 @@ array operator*(const array& a, const array& b) {
 
 array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2380,7 +2503,7 @@ array floor_divide(
     return floor(divide(a, b, s), s);
   }
 
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
@@ -2388,8 +2511,8 @@ array floor_divide(
 
 array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2407,8 +2530,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   if (issubdtype(dtype, complexfloating)) {
     throw std::invalid_argument("[divmod] Complex type not supported.");
   }
-  auto inputs =
-      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto inputs = broadcast_arrays(
+      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
   return array::make_arrays(
       {inputs[0].shape(), inputs[0].shape()},
       {inputs[0].dtype(), inputs[0].dtype()},
@@ -2419,7 +2542,7 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2431,7 +2554,7 @@ array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2514,7 +2637,7 @@ array arctan(const array& a, StreamOrDevice s /* = {} */) {
 
 array arctan2(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape, dtype, std::make_shared<ArcTan2>(to_stream(s)), std::move(inputs));
@@ -2610,7 +2733,7 @@ array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Make sure out type is floating point
   auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& shape = inputs[0].shape();
   return array(
       shape,
@@ -2710,19 +2833,7 @@ array matmul(
   if (in_a.ndim() > 2 && in_b.ndim() <= 2) {
     a = flatten(a, 0, -2, s);
   } else if (in_b.ndim() > 2) {
-    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
-    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
-
-    // Broadcast a
-    inner_shape.push_back(a.shape(-2));
-    inner_shape.push_back(a.shape(-1));
-    a = broadcast_to(a, inner_shape, s);
-
-    // Broadcast b
-    *(inner_shape.end() - 2) = b.shape(-2);
-    *(inner_shape.end() - 1) = b.shape(-1);
-    b = broadcast_to(b, inner_shape, s);
+    std::tie(a, b) = broadcast_arrays(a, b, {-2, -1}, s);
   }
 
   auto out_shape = a.shape();
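The matmul hunk above replaces the manual batch-shape surgery with a single call that broadcasts everything except the matrix axes. A small sketch of the behavior it preserves (shapes are illustrative):

#include <cassert>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto a = ones({2, 1, 3, 4});
  auto b = ones({5, 4, 6});
  // Batch dims (2, 1) and (5,) broadcast to (2, 5); the trailing matrix
  // axes are left alone, so the result is (2, 5, 3, 6).
  auto c = matmul(a, b);
  assert(c.shape() == Shape({2, 5, 3, 6}));
  return 0;
}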
@@ -3780,29 +3891,6 @@ array quantized_matmul(
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
       "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
 
-  // QuantizedMatmul handles w.ndim == 2 case.
-  if (x.ndim() > 2 && w.ndim() > 2) {
-    Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
-    Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
-    auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
-
-    // Broadcast x
-    inner_shape.push_back(x.shape(-2));
-    inner_shape.push_back(x.shape(-1));
-    x = broadcast_to(x, inner_shape, s);
-
-    // Broadcast w
-    *(inner_shape.end() - 2) = w.shape(-2);
-    *(inner_shape.end() - 1) = w.shape(-1);
-    w = broadcast_to(w, inner_shape, s);
-
-    *(inner_shape.end() - 1) = scales.shape(-1);
-    scales = broadcast_to(scales, inner_shape, s);
-
-    *(inner_shape.end() - 1) = biases.shape(-1);
-    biases = broadcast_to(biases, inner_shape, s);
-  }
-
   auto dtype = result_type(x, scales, biases);
   if (!issubdtype(dtype, floating)) {
     std::ostringstream msg;
@@ -3812,18 +3900,21 @@ array quantized_matmul(
        << " and biases.dtype() == " << biases.dtype();
     throw std::invalid_argument(msg.str());
   }
+  std::vector<array> inputs = {
+      astype(x, dtype), w, astype(scales, dtype), astype(biases, dtype)};
 
-  auto out_shape = x.shape();
+  if (x.ndim() > 2 && w.ndim() > 2) {
+    inputs = broadcast_arrays(inputs, {-2, -1}, s);
+  }
+
+  auto out_shape = inputs[0].shape();
   out_shape.back() = w_outer_dims;
   return array(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
-      {astype(x, dtype, s),
-       w,
-       astype(scales, dtype, s),
-       astype(biases, dtype, s)});
+      std::move(inputs));
 }
 
 std::tuple<array, array, array> quantize(
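Worth noting in the quantized_matmul hunk above: `x`, `w`, `scales`, and `biases` share batch dimensions but differ in their last two axes (packed weights and per-group scales/biases), which is exactly why the inputs are broadcast jointly with `{-2, -1}` ignored rather than unified to one full shape, replacing the per-input `broadcast_to` calls removed in the previous hunk.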
@@ -3866,13 +3957,11 @@ array gather_qmm(
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
   array rhs_indices = indices_or_default(rhs_indices_, w, s);
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
 
   // Compute the full output shape
-  auto out_shape = out_bsx_shape;
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(x.shape(-2));
   out_shape.push_back(w_outer_dims);
 
@@ -4374,13 +4463,10 @@ array gather_mm(
   int N = b.shape(-1);
   int K = a.shape(-1);
 
-  auto out_bsx_shape =
-      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
-
-  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
-  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);
-
-  auto out_shape = out_bsx_shape;
+  std::tie(lhs_indices, rhs_indices) =
+      broadcast_arrays(lhs_indices, rhs_indices, s);
+
+  auto out_shape = lhs_indices.shape();
   out_shape.push_back(M);
   out_shape.push_back(N);
 
@@ -4640,6 +4726,13 @@ array number_of_elements(
     ax = normal_axis;
   }
 
+  if (!detail::in_dynamic_tracing()) {
+    double numel = 1;
+    for (auto ax : axes) {
+      numel *= a.shape(ax);
+    }
+    return array(inverted ? 1.0 / numel : numel, dtype);
+  }
   return stop_gradient(array(
       Shape{},
       dtype,
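The number_of_elements hunk follows the same split: with static shapes the element count constant-folds, while under dynamic tracing the `NumberOfElements` node (behind `stop_gradient`) stays in the graph so the count is recomputed per call. A sketch of where this surfaces, again assuming the shapeless C++ `compile` entry point and that `mean` scales by the inverse element count internally (shapes illustrative):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {
    return {mean(inputs[0])};  // applies a 1/numel normalizer
  };
  auto cfun = compile(fun, /* shapeless = */ true);
  eval(cfun({ones({8})}));   // normalizer evaluates to 1/8
  eval(cfun({ones({16})}));  // same trace; normalizer re-evaluates to 1/16
  return 0;
}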
@@ -4673,7 +4766,7 @@ array bitwise_impl(
     throw std::runtime_error(msg.str());
   }
   auto inputs =
-      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
   auto& out_shape = inputs[0].shape();
   return array(
       out_shape,
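Finally, on the "add vjp, jvp for broadcast axes" item in the commit message: broadcasting is linear, so its vjp sums the cotangent over the broadcast dimensions and its jvp simply broadcasts the tangent. A hedged sketch of that contract through the public `vjp` transform (this exercises `Broadcast`; `BroadcastAxes` should behave the same on its non-ignored axes):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {
    return {broadcast_to(inputs[0], {4, 3})};
  };
  auto x = ones({3});
  auto cotan = ones({4, 3});
  // The cotangent is summed over the broadcast axis: grads[0] has shape
  // (3,) with every entry equal to 4.
  auto [outs, grads] = vjp(fun, {x}, {cotan});
  eval(grads);
  return 0;
}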