Reduce a little overhead (#871)

* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
Awni Hannun 2024-03-22 17:29:36 -07:00 committed by GitHub
parent 6ee1112f30
commit be98f4ab6b
13 changed files with 239 additions and 240 deletions

View File

@@ -12,16 +12,6 @@ namespace mlx::core {
 namespace {
-std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
-  std::vector<size_t> strides(shape.size());
-  size_t cum_prod = 1;
-  for (int i = shape.size() - 1; i >= 0; --i) {
-    strides[i] = cum_prod;
-    cum_prod *= shape[i];
-  }
-  return {cum_prod, strides};
-}
-
 /** Return true if we are currently performing a function transformation in
  * order to keep the graph when evaluating tracer arrays. */
 bool in_tracing() {
@@ -171,9 +161,21 @@ void array::move_shared_buffer(array other) {
   move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
 }
 
+void array::ArrayDesc::init() {
+  strides.resize(shape.size());
+  size = 1;
+  for (int i = shape.size() - 1; i >= 0; --i) {
+    strides[i] = size;
+    size *= shape[i];
+  }
+  for (auto& in : inputs) {
+    is_tracer |= in.is_tracer();
+  }
+}
+
 array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
     : shape(std::move(shape)), dtype(dtype) {
-  std::tie(size, strides) = cum_prod(this->shape);
+  init();
 }
 
 array::ArrayDesc::ArrayDesc(
@@ -185,10 +187,7 @@ array::ArrayDesc::ArrayDesc(
       dtype(dtype),
       primitive(std::move(primitive)),
       inputs(std::move(inputs)) {
-  std::tie(size, strides) = cum_prod(this->shape);
-  for (auto& in : this->inputs) {
-    is_tracer |= in.is_tracer();
-  }
+  init();
 }
 
 array::ArrayIterator::ArrayIterator(const array& arr, int idx)
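
Note: ArrayDesc::init() replaces the deleted cum_prod helper with a single pass that fills strides in place and accumulates size, skipping the temporary pair and tuple assignment. A standalone sketch of the same row-major stride computation, with illustrative names (not MLX API):

    #include <cassert>
    #include <vector>

    int main() {
      // Same loop as init(): the last axis is contiguous, and each earlier
      // axis strides over the product of the axes after it.
      std::vector<int> shape{2, 3, 4};
      std::vector<size_t> strides(shape.size());
      size_t size = 1;
      for (int i = shape.size() - 1; i >= 0; --i) {
        strides[i] = size;
        size *= shape[i];
      }
      assert(size == 24);
      assert((strides == std::vector<size_t>{12, 4, 1}));
    }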

View File

@@ -392,6 +392,10 @@ class array {
       Dtype dtype,
       std::shared_ptr<Primitive> primitive,
       std::vector<array> inputs);
+
+  private:
+   // Initialize size, strides, and other metadata
+   void init();
 };
 
 // The ArrayDesc contains the details of the materialized array including the

View File

@@ -63,7 +63,7 @@ array rms_norm(
         << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
-  auto out_type = result_type({x, weight});
+  auto out_type = result_type(x, weight);
   if (!is_floating_point(out_type) || is_complex(out_type)) {
     std::ostringstream msg;
     msg << "[rms_norm] Received unsupported type " << out_type << ".";
@@ -88,7 +88,7 @@ array rms_norm(
     return array(
         x.shape(),
         out_type,
-        std::make_unique<RMSNorm>(s, fallback, eps),
+        std::make_shared<RMSNorm>(s, fallback, eps),
         {astype(x, out_type, s), astype(weight, out_type, s)});
   }
   return fallback({x, weight})[0];
@@ -125,8 +125,8 @@ array layer_norm(
   }
   auto out_type = (weight.has_value())
-      ? ((bias.has_value()) ? result_type({x, *weight, *bias})
-                            : result_type({x, *weight}))
+      ? ((bias.has_value()) ? result_type(x, *weight, *bias)
+                            : result_type(x, *weight))
       : x.dtype();
   if (!is_floating_point(out_type) || is_complex(out_type)) {
     std::ostringstream msg;
@@ -170,7 +170,7 @@ array layer_norm(
     return array(
         x.shape(),
         out_type,
-        std::make_unique<LayerNorm>(s, fallback, eps),
+        std::make_shared<LayerNorm>(s, fallback, eps),
         {astype(x, out_type, s), passed_weight, passed_bias});
   }
   return fallback({x, passed_weight, passed_bias})[0];
@@ -248,7 +248,7 @@ array rope(
     return array(
         x.shape(),
         x.dtype(),
-        std::make_unique<RoPE>(
+        std::make_shared<RoPE>(
             stream, fallback, dims, traditional, base, scale, offset),
         {x});
   }
@@ -318,7 +318,7 @@ array scaled_dot_product_attention(
     throw std::invalid_argument(msg.str());
   }
 
-  auto final_type = result_type({queries, keys, values});
+  auto final_type = result_type(queries, keys, values);
   if (!is_floating_point(final_type) || is_complex(final_type)) {
     std::ostringstream msg;
     msg << "[scaled_dot_product_attention] Received unsupported type "
@@ -330,9 +330,6 @@ array scaled_dot_product_attention(
   auto k = astype(keys, final_type, s);
   auto v = astype(values, final_type, s);
 
-  auto out_shape =
-      std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
-
   /* generic implementation for use cases that Metal implementation does not
    * support. For non-supported cases listed below, use MLX primitives:
    * * CPU implementation
@@ -381,10 +378,12 @@ array scaled_dot_product_attention(
   // TODO, update routing conditions post further tuning
   implementation_supports_use_case &= false;
   if (implementation_supports_use_case) {
+    auto out_shape =
+        std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
     auto out = array(
-        out_shape,
+        std::move(out_shape),
         final_type,
-        std::make_unique<ScaledDotProductAttention>(
+        std::make_shared<ScaledDotProductAttention>(
            stream, fallback, scale, false),
        {q, k, v});
    return out;
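
Note on the make_unique -> make_shared swap repeated throughout this diff: array stores its primitive as a std::shared_ptr<Primitive>, so constructing with make_unique forces a unique_ptr-to-shared_ptr conversion that heap-allocates a separate control block, while make_shared allocates the object and control block together. A standalone sketch under that assumption (Primitive and RMSNormLike are stand-ins, not the MLX classes):

    #include <memory>

    struct Primitive {
      virtual ~Primitive() = default;
    };
    struct RMSNormLike : Primitive {};

    int main() {
      // Two allocations: one for the object, one for the control block
      // created when the unique_ptr is converted to a shared_ptr.
      std::shared_ptr<Primitive> a = std::make_unique<RMSNormLike>();

      // One combined allocation for object plus control block.
      std::shared_ptr<Primitive> b = std::make_shared<RMSNormLike>();
    }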

View File

@@ -195,7 +195,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   auto out = array::make_arrays(
       {a.shape(), a.shape()},
       {a.dtype(), a.dtype()},
-      std::make_unique<QRF>(to_stream(s)),
+      std::make_shared<QRF>(to_stream(s)),
       {astype(a, a.dtype(), s)});
   return std::make_pair(out[0], out[1]);
 }
@@ -234,7 +234,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   return array::make_arrays(
       {u_shape, s_shape, vt_shape},
       {a.dtype(), a.dtype(), a.dtype()},
-      std::make_unique<SVD>(to_stream(s)),
+      std::make_shared<SVD>(to_stream(s)),
       {a});
 }
@@ -258,7 +258,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
   }
   return array(
-      a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
+      a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
 }
 
 } // namespace mlx::core::linalg

View File

@@ -88,7 +88,7 @@ array arange(
   return array(
       {size},
       dtype,
-      std::make_unique<Arange>(to_stream(s), start, stop, step),
+      std::make_shared<Arange>(to_stream(s), start, stop, step),
       {});
 }
 
 array arange(
@@ -163,7 +163,7 @@ array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
     return a;
   }
   return array(
-      a.shape(), dtype, std::make_unique<AsType>(to_stream(s), dtype), {a});
+      a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
 }
 
 array as_strided(
@@ -177,12 +177,12 @@ array as_strided(
   return array(
       shape,
       a.dtype(),
-      std::make_unique<AsStrided>(to_stream(s), shape, strides, offset),
+      std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
       {x});
 }
 
 array copy(const array& a, StreamOrDevice s /* = {} */) {
-  return array(a.shape(), a.dtype(), std::make_unique<Copy>(to_stream(s)), {a});
+  return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
 }
 
 array full(
@@ -194,7 +194,7 @@ array full(
     throw std::invalid_argument("[full] Negative dimensions not allowed.");
   }
   auto in = broadcast_to(astype(vals, dtype, s), shape, s);
-  return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
+  return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
 }
 
 array full(
@@ -313,9 +313,8 @@ array reshape(
         << " into shape " << shape << ".";
     throw std::invalid_argument(msg.str());
   }
-  return array(
-      shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
+  auto p = std::make_shared<Reshape>(to_stream(s), shape);
+  return array(std::move(shape), a.dtype(), std::move(p), {a});
 }
 
 array flatten(
@@ -408,6 +407,20 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
   return squeeze(a, axes, s);
 }
 
+array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
+  int out_dim = a.ndim() + 1;
+  int ax = axis < 0 ? axis + out_dim : axis;
+  if (ax < 0 || ax >= out_dim) {
+    std::ostringstream msg;
+    msg << "[expand_dims] Invalid axes " << axis << " for output array with "
+        << a.ndim() << " dimensions.";
+    throw std::invalid_argument(msg.str());
+  }
+  auto shape = a.shape();
+  shape.insert(shape.begin() + ax, 1);
+  return reshape(a, std::move(shape), s);
+}
+
 array expand_dims(
     const array& a,
     const std::vector<int>& axes,
@@ -425,7 +438,7 @@ array expand_dims(
     ax = ax < 0 ? ax + out_ndim : ax;
     if (ax < 0 || ax >= out_ndim) {
       std::ostringstream msg;
-      msg << "[squeeze] Invalid axes " << ax << " for output array with "
+      msg << "[expand_dims] Invalid axes " << ax << " for output array with "
          << a.ndim() << " dimensions.";
      throw std::invalid_argument(msg.str());
    }
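
Note: the new single-axis expand_dims overload normalizes a possibly negative axis against the output rank, then inserts a size-1 dimension and reshapes. A standalone sketch of just the shape arithmetic (the helper name is illustrative):

    #include <cassert>
    #include <vector>

    std::vector<int> expand_dims_shape(std::vector<int> shape, int axis) {
      int out_dim = static_cast<int>(shape.size()) + 1;
      int ax = axis < 0 ? axis + out_dim : axis; // normalize negative axes
      assert(ax >= 0 && ax < out_dim);
      shape.insert(shape.begin() + ax, 1); // size-1 dimension at position ax
      return shape;
    }

    int main() {
      assert((expand_dims_shape({2, 3}, 0) == std::vector<int>{1, 2, 3}));
      assert((expand_dims_shape({2, 3}, -1) == std::vector<int>{2, 3, 1}));
    }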
@@ -442,7 +455,7 @@ array expand_dims(
   for (int i = 0; i < sorted_axes.size(); ++i) {
     out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
   }
-  return reshape(a, out_shape, s);
+  return reshape(a, std::move(out_shape), s);
 }
 
 // Slice helper
@@ -523,7 +536,7 @@ array slice(
   return array(
       out_shape,
       a.dtype(),
-      std::make_unique<Slice>(
+      std::make_shared<Slice>(
           to_stream(s), std::move(start), std::move(stop), std::move(strides)),
       {a});
 }
@@ -568,7 +581,7 @@ array slice_update(
   return array(
       src.shape(),
       src.dtype(),
-      std::make_unique<SliceUpdate>(
+      std::make_shared<SliceUpdate>(
           to_stream(s), std::move(start), std::move(stop), std::move(strides)),
       {src, update_broadcasted});
 }
@@ -743,7 +756,10 @@ array concatenate(
   auto dtype = result_type(arrays);
   return array(
-      shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
+      std::move(shape),
+      dtype,
+      std::make_shared<Concatenate>(to_stream(s), ax),
+      std::move(arrays));
 }
 
 array concatenate(
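
Note: many call sites in this file now std::move shapes and input vectors into the array constructor; moving a std::vector hands over its heap buffer instead of copying every element. A standalone sketch of the pattern (Holder is a stand-in for any by-value constructor parameter):

    #include <cassert>
    #include <utility>
    #include <vector>

    // Stand-in for a constructor that takes std::vector by value, as the
    // array constructor in this diff does.
    struct Holder {
      explicit Holder(std::vector<int> shape) : shape_(std::move(shape)) {}
      std::vector<int> shape_;
    };

    int main() {
      std::vector<int> shape{2, 3, 4};
      const int* buf = shape.data();
      Holder h{std::move(shape)}; // buffer is transferred, not copied
      assert(h.shape_.data() == buf);
      // `shape` is now valid but unspecified; don't read its elements.
    }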
@@ -886,7 +902,7 @@ array pad(
   return array(
       out_shape,
       a.dtype(),
-      std::make_unique<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
+      std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
       {a, astype(pad_value, a.dtype(), s)});
 }
@@ -975,7 +991,7 @@ array swapaxes(
   std::vector<int> reorder(a.ndim());
   std::iota(reorder.begin(), reorder.end(), 0);
   std::swap(reorder[axis1], reorder[axis2]);
-  return transpose(a, reorder, s);
+  return transpose(a, std::move(reorder), s);
 }
 
 array transpose(
@@ -1000,9 +1016,9 @@ array transpose(
     shape.push_back(a.shape()[ax]);
   }
   return array(
-      shape,
+      std::move(shape),
       a.dtype(),
-      std::make_unique<Transpose>(to_stream(s), std::move(axes)),
+      std::make_shared<Transpose>(to_stream(s), std::move(axes)),
       {a});
 }
@@ -1029,7 +1045,16 @@ array broadcast_to(
     throw std::invalid_argument(msg.str());
   }
   return array(
-      shape, a.dtype(), std::make_unique<Broadcast>(to_stream(s), shape), {a});
+      std::move(bxshape),
+      a.dtype(),
+      std::make_shared<Broadcast>(to_stream(s), shape),
+      {a});
+}
+
+std::vector<array>
+broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
+  std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
+  return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
 }
 
 std::vector<array> broadcast_arrays(
@@ -1048,38 +1073,29 @@ std::vector<array> broadcast_arrays(
 array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
-  if (a.shape() != b.shape()) {
-    inputs = broadcast_arrays(inputs, s);
-  }
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(), bool_, std::make_unique<Equal>(to_stream(s)), inputs);
+      shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
 }
 
 array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
-  if (a.shape() != b.shape()) {
-    inputs = broadcast_arrays(inputs, s);
-  }
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       bool_,
-      std::make_unique<NotEqual>(to_stream(s)),
-      inputs);
+      std::make_shared<NotEqual>(to_stream(s)),
+      std::move(inputs));
 }
 
 array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
-  if (a.shape() != b.shape()) {
-    inputs = broadcast_arrays(inputs, s);
-  }
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
-      bool_,
-      std::make_unique<Greater>(to_stream(s)),
-      inputs);
+      shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
 }
 
 array greater_equal(
@@ -1087,38 +1103,32 @@ array greater_equal(
     const array& b,
     StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
-  if (a.shape() != b.shape()) {
-    inputs = broadcast_arrays(inputs, s);
-  }
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       bool_,
-      std::make_unique<GreaterEqual>(to_stream(s)),
-      inputs);
+      std::make_shared<GreaterEqual>(to_stream(s)),
+      std::move(inputs));
 }
 
 array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
-  if (a.shape() != b.shape()) {
-    inputs = broadcast_arrays(inputs, s);
-  }
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(), bool_, std::make_unique<Less>(to_stream(s)), inputs);
+      shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
 }
 
 array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
-  if (a.shape() != b.shape()) {
-    inputs = broadcast_arrays(inputs, s);
-  }
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       bool_,
-      std::make_unique<LessEqual>(to_stream(s)),
-      inputs);
+      std::make_shared<LessEqual>(to_stream(s)),
+      std::move(inputs));
 }
 
 array array_equal(
@@ -1135,7 +1145,7 @@ array array_equal(
       array(
           a.shape(),
           bool_,
-          std::make_unique<Equal>(to_stream(s), equal_nan),
+          std::make_shared<Equal>(to_stream(s), equal_nan),
           {astype(a, dtype, s), astype(b, dtype, s)}),
       false,
       s);
@@ -1180,7 +1190,7 @@ array where(
   return array(
       inputs[0].shape(),
       out_dtype,
-      std::make_unique<Select>(to_stream(s)),
+      std::make_shared<Select>(to_stream(s)),
       inputs);
 }
@@ -1246,7 +1256,7 @@ array all(
   auto out = array(
       out_shape,
       bool_,
-      std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
+      std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
       {a});
   if (!keepdims) {
     out = squeeze(out, sorted_axes, s);
@@ -1280,7 +1290,7 @@ array any(
   auto out = array(
       out_shape,
       bool_,
-      std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
+      std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
       {a});
   if (!keepdims) {
     out = squeeze(out, sorted_axes, s);
@@ -1315,7 +1325,7 @@ array sum(
   auto out = array(
       out_shape,
       out_type,
-      std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
+      std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
       {a});
   if (!keepdims) {
     out = squeeze(out, sorted_axes, s);
@@ -1424,7 +1434,7 @@ array prod(
   auto out = array(
       out_shape,
       a.dtype(),
-      std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
+      std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
       {a});
   if (!keepdims) {
     out = squeeze(out, sorted_axes, s);
@@ -1461,7 +1471,7 @@ array max(
   auto out = array(
       out_shape,
       a.dtype(),
-      std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
+      std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
       {a});
   if (!keepdims) {
     out = squeeze(out, sorted_axes, s);
@@ -1498,7 +1508,7 @@ array min(
   auto out = array(
       out_shape,
       a.dtype(),
-      std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
+      std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
       {a});
   if (!keepdims) {
     out = squeeze(out, sorted_axes, s);
@@ -1538,7 +1548,7 @@ array argmin(
   auto out = array(
       out_shape,
       uint32,
-      std::make_unique<ArgReduce>(
+      std::make_shared<ArgReduce>(
           to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
       {a});
   if (!keepdims) {
@@ -1571,7 +1581,7 @@ array argmax(
   auto out = array(
       out_shape,
       uint32,
-      std::make_unique<ArgReduce>(
+      std::make_shared<ArgReduce>(
          to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
      {a});
   if (!keepdims) {
@@ -1607,7 +1617,7 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
   }
   return array(
-      a.shape(), a.dtype(), std::make_unique<Sort>(to_stream(s), axis), {a});
+      a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
 }
 
 /** Returns indices that sort the flattened array. */
@@ -1637,7 +1647,7 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
   }
   return array(
-      a.shape(), uint32, std::make_unique<ArgSort>(to_stream(s), axis), {a});
+      a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
 }
 
 /**
@@ -1677,7 +1687,7 @@ array partition(
   return array(
       a.shape(),
       a.dtype(),
-      std::make_unique<Partition>(to_stream(s), kth_, axis_),
+      std::make_shared<Partition>(to_stream(s), kth_, axis_),
       {a});
 }
@@ -1718,7 +1728,7 @@ array argpartition(
   return array(
       a.shape(),
       uint32,
-      std::make_unique<ArgPartition>(to_stream(s), kth_, axis_),
+      std::make_shared<ArgPartition>(to_stream(s), kth_, axis_),
       {a});
 }
@@ -1787,7 +1797,7 @@ array logsumexp(
 array abs(const array& a, StreamOrDevice s /* = {} */) {
   auto out =
-      array(a.shape(), a.dtype(), std::make_unique<Abs>(to_stream(s)), {a});
+      array(a.shape(), a.dtype(), std::make_shared<Abs>(to_stream(s)), {a});
   if (a.dtype() == complex64) {
     out = astype(out, float32, s);
   }
@@ -1800,33 +1810,33 @@ array negative(const array& a, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument(msg);
   }
   return array(
-      a.shape(), a.dtype(), std::make_unique<Negative>(to_stream(s)), {a});
+      a.shape(), a.dtype(), std::make_shared<Negative>(to_stream(s)), {a});
 }
 array operator-(const array& a) {
   return negative(a);
 }
 
 array sign(const array& a, StreamOrDevice s /* = {} */) {
-  return array(a.shape(), a.dtype(), std::make_unique<Sign>(to_stream(s)), {a});
+  return array(a.shape(), a.dtype(), std::make_shared<Sign>(to_stream(s)), {a});
 }
 
 array logical_not(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(),
       bool_,
-      std::make_unique<LogicalNot>(to_stream(s)),
+      std::make_shared<LogicalNot>(to_stream(s)),
       {astype(a, bool_, s)});
 }
 
 array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
+  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       bool_,
-      std::make_unique<LogicalAnd>(to_stream(s)),
-      inputs);
+      std::make_shared<LogicalAnd>(to_stream(s)),
+      std::move(inputs));
 }
 array operator&&(const array& a, const array& b) {
   return logical_and(a, b);
@@ -1834,13 +1844,13 @@ array operator&&(const array& a, const array& b) {
 array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Broadcast arrays to a common shape
-  auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
+  auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       bool_,
-      std::make_unique<LogicalOr>(to_stream(s)),
-      inputs);
+      std::make_shared<LogicalOr>(to_stream(s)),
+      std::move(inputs));
 }
 array operator||(const array& a, const array& b) {
   return logical_or(a, b);
@@ -1854,9 +1864,10 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
 array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(), out_type, std::make_unique<Add>(to_stream(s)), inputs);
+      shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
 }
 
 array operator+(const array& a, const array& b) {
@@ -1866,12 +1877,13 @@ array operator+(const array& a, const array& b) {
 array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       out_type,
-      std::make_unique<Subtract>(to_stream(s)),
-      inputs);
+      std::make_shared<Subtract>(to_stream(s)),
+      std::move(inputs));
 }
 
 array operator-(const array& a, const array& b) {
@@ -1881,12 +1893,13 @@ array operator-(const array& a, const array& b) {
 array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       out_type,
-      std::make_unique<Multiply>(to_stream(s)),
-      inputs);
+      std::make_shared<Multiply>(to_stream(s)),
+      std::move(inputs));
 }
 
 array operator*(const array& a, const array& b) {
@@ -1895,10 +1908,11 @@ array operator*(const array& a, const array& b) {
 array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
-  auto inputs = broadcast_arrays(
-      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
+  auto inputs =
+      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
+      shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
 }
 array operator/(const array& a, const array& b) {
   return divide(a, b);
@@ -1919,20 +1933,22 @@ array floor_divide(
     return floor(divide(a, b, s), s);
   }
 
-  auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
+  auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
+      shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
 }
 
 array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto dtype = promote_types(a.dtype(), b.dtype());
-  auto inputs = broadcast_arrays(
-      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
+  auto inputs =
+      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       dtype,
-      std::make_unique<Remainder>(to_stream(s)),
-      inputs);
+      std::make_shared<Remainder>(to_stream(s)),
+      std::move(inputs));
 }
 array operator%(const array& a, const array& b) {
   return remainder(a, b);
@@ -1944,8 +1960,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   if (is_complex(dtype)) {
     throw std::invalid_argument("[divmod] Complex type not supported.");
   }
-  auto inputs = broadcast_arrays(
-      {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
+  auto inputs =
+      broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
   return array::make_arrays(
       {inputs[0].shape(), inputs[0].shape()},
       {inputs[0].dtype(), inputs[0].dtype()},
@@ -1956,23 +1972,25 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
 array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       out_type,
-      std::make_unique<Maximum>(to_stream(s)),
-      inputs);
+      std::make_shared<Maximum>(to_stream(s)),
+      std::move(inputs));
 }
 
 array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   auto out_type = promote_types(a.dtype(), b.dtype());
   auto inputs =
-      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       out_type,
-      std::make_unique<Minimum>(to_stream(s)),
-      inputs);
+      std::make_shared<Minimum>(to_stream(s)),
+      std::move(inputs));
 }
 
 array floor(const array& a, StreamOrDevice s /* = {} */) {
@@ -1980,103 +1998,103 @@ array floor(const array& a, StreamOrDevice s /* = {} */) {
     throw std::invalid_argument("[floor] Not supported for complex64.");
   }
   return array(
-      a.shape(), a.dtype(), std::make_unique<Floor>(to_stream(s)), {a});
+      a.shape(), a.dtype(), std::make_shared<Floor>(to_stream(s)), {a});
 }
 
 array ceil(const array& a, StreamOrDevice s /* = {} */) {
   if (a.dtype() == complex64) {
     throw std::invalid_argument("[floor] Not supported for complex64.");
   }
-  return array(a.shape(), a.dtype(), std::make_unique<Ceil>(to_stream(s)), {a});
+  return array(a.shape(), a.dtype(), std::make_shared<Ceil>(to_stream(s)), {a});
 }
 
 array square(const array& a, StreamOrDevice s /* = {} */) {
   return array(
-      a.shape(), a.dtype(), std::make_unique<Square>(to_stream(s)), {a});
+      a.shape(), a.dtype(), std::make_shared<Square>(to_stream(s)), {a});
 }
 
 array exp(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Exp>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
 }
 
 array sin(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Sin>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Sin>(to_stream(s)), {input});
 }
 
 array cos(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Cos>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Cos>(to_stream(s)), {input});
 }
 
 array tan(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Tan>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Tan>(to_stream(s)), {input});
 }
 
 array arcsin(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<ArcSin>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<ArcSin>(to_stream(s)), {input});
 }
 
 array arccos(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<ArcCos>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<ArcCos>(to_stream(s)), {input});
 }
 
 array arctan(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<ArcTan>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
 }
 
 array sinh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Sinh>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Sinh>(to_stream(s)), {input});
 }
 
 array cosh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Cosh>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Cosh>(to_stream(s)), {input});
 }
 
 array tanh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
-  return array(a.shape(), dtype, std::make_unique<Tanh>(to_stream(s)), {input});
+  return array(a.shape(), dtype, std::make_shared<Tanh>(to_stream(s)), {input});
 }
 
 array arcsinh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<ArcSinh>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<ArcSinh>(to_stream(s)), {input});
 }
 
 array arccosh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<ArcCosh>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<ArcCosh>(to_stream(s)), {input});
 }
 
 array arctanh(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<ArcTanh>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<ArcTanh>(to_stream(s)), {input});
 }
 
 array log(const array& a, StreamOrDevice s /* = {} */) {
@@ -2085,7 +2103,7 @@ array log(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(),
       dtype,
-      std::make_unique<Log>(to_stream(s), Log::Base::e),
+      std::make_shared<Log>(to_stream(s), Log::Base::e),
       {input});
 }
@@ -2095,7 +2113,7 @@ array log2(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(),
       dtype,
-      std::make_unique<Log>(to_stream(s), Log::Base::two),
+      std::make_shared<Log>(to_stream(s), Log::Base::two),
       {input});
 }
@@ -2105,7 +2123,7 @@ array log10(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(),
       dtype,
-      std::make_unique<Log>(to_stream(s), Log::Base::ten),
+      std::make_shared<Log>(to_stream(s), Log::Base::ten),
       {input});
 }
@@ -2113,26 +2131,27 @@ array log1p(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<Log1p>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<Log1p>(to_stream(s)), {input});
 }
 
 array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
   // Make sure out type is floating point
   auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
   auto inputs =
-      broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
+      broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
+  auto& shape = inputs[0].shape();
   return array(
-      inputs[0].shape(),
+      shape,
       out_type,
-      std::make_unique<LogAddExp>(to_stream(s)),
-      inputs);
+      std::make_shared<LogAddExp>(to_stream(s)),
+      std::move(inputs));
 }
 
 array sigmoid(const array& a, StreamOrDevice s /* = {} */) {
   auto dtype = at_least_float(a.dtype());
   auto input = astype(a, dtype, s);
   return array(
-      a.shape(), dtype, std::make_unique<Sigmoid>(to_stream(s)), {input});
+      a.shape(), dtype, std::make_shared<Sigmoid>(to_stream(s)), {input});
 }
array erf(const array& a, StreamOrDevice s /* = {} */) { array erf(const array& a, StreamOrDevice s /* = {} */) {
@ -2140,7 +2159,7 @@ array erf(const array& a, StreamOrDevice s /* = {} */) {
return array( return array(
a.shape(), a.shape(),
dtype, dtype,
std::make_unique<Erf>(to_stream(s)), std::make_shared<Erf>(to_stream(s)),
{astype(a, dtype, s)}); {astype(a, dtype, s)});
} }
@ -2149,19 +2168,19 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
return array( return array(
a.shape(), a.shape(),
dtype, dtype,
std::make_unique<ErfInv>(to_stream(s)), std::make_shared<ErfInv>(to_stream(s)),
{astype(a, dtype, s)}); {astype(a, dtype, s)});
} }
array stop_gradient(const array& a, StreamOrDevice s /* = {} */) { array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
return array( return array(
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a}); a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
} }
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) { array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
if (decimals == 0) { if (decimals == 0) {
return array( return array(
a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a}); a.shape(), a.dtype(), std::make_shared<Round>(to_stream(s)), {a});
} }
auto dtype = at_least_float(a.dtype()); auto dtype = at_least_float(a.dtype());
@@ -2226,7 +2245,7 @@ array matmul(
     auto out = array(
         {a.shape(0), b.shape(1)},
         out_type,
-        std::make_unique<Matmul>(to_stream(s)),
+        std::make_shared<Matmul>(to_stream(s)),
         {a, b});
     return reshape(out, out_shape, s);
   }
@@ -2250,17 +2269,17 @@ array matmul(
   auto out_shape = a.shape();
   out_shape.back() = b.shape(-1);
 
-  auto out = array(
-      out_shape, out_type, std::make_unique<Matmul>(to_stream(s)), {a, b});
+  auto p = std::make_shared<Matmul>(to_stream(s));
 
   // Remove the possibly inserted singleton dimensions
   if (in_a.ndim() == 1 || in_b.ndim() == 1) {
+    auto out = array(out_shape, out_type, std::move(p), {a, b});
     out_shape.erase(
         out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
         out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
-    out = reshape(out, out_shape, s);
+    return reshape(out, std::move(out_shape), s);
   }
-  return out;
+  return array(std::move(out_shape), out_type, std::move(p), {a, b});
 }
 
 array gather(
@@ -2332,7 +2351,7 @@ array gather(
   return array(
       out_shape,
       a.dtype(),
-      std::make_unique<Gather>(to_stream(s), axes, slice_sizes),
+      std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
       inputs);
 }
@@ -2516,7 +2535,7 @@ array scatter(
   return array(
       a.shape(),
       a.dtype(),
-      std::make_unique<Scatter>(to_stream(s), mode, axes),
+      std::make_shared<Scatter>(to_stream(s), mode, axes),
       inputs);
 }
@@ -2570,7 +2589,7 @@ array sqrt(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(),
       dtype,
-      std::make_unique<Sqrt>(to_stream(s)),
+      std::make_shared<Sqrt>(to_stream(s)),
       {astype(a, dtype, s)});
 }
@@ -2579,7 +2598,7 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
   return array(
       a.shape(),
       dtype,
-      std::make_unique<Sqrt>(to_stream(s), true),
+      std::make_shared<Sqrt>(to_stream(s), true),
       {astype(a, dtype, s)});
 }
@@ -2592,7 +2611,7 @@ array softmax(
     return array(
         a.shape(),
         dtype,
-        std::make_unique<Softmax>(to_stream(s)),
+        std::make_shared<Softmax>(to_stream(s)),
         {astype(a, dtype, s)});
   } else {
     auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
@@ -2614,7 +2633,7 @@ array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
     inputs = broadcast_arrays(inputs, s);
   }
   return array(
-      inputs[0].shape(), dtype, std::make_unique<Power>(to_stream(s)), inputs);
+      inputs[0].shape(), dtype, std::make_shared<Power>(to_stream(s)), inputs);
 }
 
 array cumsum(
@@ -2635,7 +2654,7 @@ array cumsum(
   return array(
       a.shape(),
       out_type,
-      std::make_unique<Scan>(
+      std::make_shared<Scan>(
           to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
       {a});
 }
@@ -2657,7 +2676,7 @@ array cumprod(
   return array(
       a.shape(),
       a.dtype(),
-      std::make_unique<Scan>(
+      std::make_shared<Scan>(
           to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
       {a});
 }
@@ -2679,7 +2698,7 @@ array cummax(
   return array(
       a.shape(),
       a.dtype(),
-      std::make_unique<Scan>(
+      std::make_shared<Scan>(
           to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),
       {a});
 }
@@ -2701,7 +2720,7 @@ array cummin(
   return array(
       a.shape(),
       a.dtype(),
-      std::make_unique<Scan>(
+      std::make_shared<Scan>(
          to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),
      {a});
 }
@@ -2965,7 +2984,7 @@ array conv_general(
   return array(
       out_shape,
       in.dtype(),
-      std::make_unique<Convolution>(
+      std::make_shared<Convolution>(
           to_stream(s),
           stride,
           padding_lo,
@@ -3042,7 +3061,7 @@ array quantized_matmul(
     throw std::invalid_argument(msg.str());
   }
 
-  auto dtype = result_type({x, scales, biases});
+  auto dtype = result_type(x, scales, biases);
   if (!is_floating_point(dtype) || is_complex(dtype)) {
     std::ostringstream msg;
     msg << "[quantized_matmul] Only real floating types are supported but "
@@ -3055,7 +3074,7 @@ array quantized_matmul(
   auto out = array(
       {x.shape(0), w_outer_dims},
       dtype,
-      std::make_unique<QuantizedMatmul>(
+      std::make_shared<QuantizedMatmul>(
           to_stream(s), group_size, bits, transpose),
       {astype(x, dtype, s),
        w,
@@ -3065,7 +3084,7 @@ array quantized_matmul(
   // If needed reshape x to the original batch shape
   if (original_shape.size() != 1) {
     original_shape.push_back(w_outer_dims);
-    out = reshape(out, original_shape, s);
+    out = reshape(out, std::move(original_shape), s);
   }
 
   return out;
@@ -3344,7 +3363,7 @@ array addmm(
   }
 
   // Type promotion
-  auto out_type = result_type({a, b, c});
+  auto out_type = result_type(a, b, c);
   if (!is_floating_point(out_type) || is_complex(out_type)) {
     std::ostringstream msg;
     msg << "[addmm] Only real floating point types are supported but "
@@ -3373,7 +3392,7 @@ array addmm(
     auto out = array(
         {a.shape(0), b.shape(1)},
         out_type,
-        std::make_unique<AddMM>(to_stream(s), alpha, beta),
+        std::make_shared<AddMM>(to_stream(s), alpha, beta),
         {a, b, c});
     return reshape(out, out_shape, s);
   }
@@ -3425,7 +3444,7 @@ array addmm(
   auto out = array(
       out_shape,
       out_type,
-      std::make_unique<AddMM>(to_stream(s), alpha, beta),
+      std::make_shared<AddMM>(to_stream(s), alpha, beta),
       {a, b, c});
 
   // Remove the possibly inserted singleton dimensions
@@ -3621,7 +3640,7 @@ array number_of_elements(
   return stop_gradient(array(
       std::vector<int>{},
       dtype,
-      std::make_unique<NumberOfElements>(
+      std::make_shared<NumberOfElements>(
          to_stream(s), std::move(axes), inverted, dtype),
      {a}));
 }

View File

@@ -158,9 +158,7 @@ array expand_dims(
     StreamOrDevice s = {});
 
 /** Add a singleton dimension at the given axis. */
-inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
-  return expand_dims(a, std::vector<int>{axis}, s);
-}
+array expand_dims(const array& a, int axis, StreamOrDevice s = {});
 
 /** Slice an array. */
 array slice(

View File

@@ -66,7 +66,7 @@ array bits(
   return array(
       shape,
       get_dtype(),
-      std::make_unique<RandomBits>(to_stream(s), shape, width),
+      std::make_shared<RandomBits>(to_stream(s), shape, width),
       {key});
 }

View File

@@ -54,6 +54,12 @@ struct PrintFormatter {
 extern PrintFormatter global_formatter;
 
 /** The type from promoting the arrays' types with one another. */
+inline Dtype result_type(const array& a, const array& b) {
+  return promote_types(a.dtype(), b.dtype());
+}
+inline Dtype result_type(const array& a, const array& b, const array& c) {
+  return promote_types(result_type(a, b), c.dtype());
+}
 Dtype result_type(const std::vector<array>& arrays);
 
 std::vector<int> broadcast_shapes(
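
Note: the fixed-arity result_type overloads chain promote_types directly, so two- and three-array call sites no longer build a std::vector<array>. A toy standalone sketch of the chaining, with Dtype and the promotion rule simplified to stand-ins (the real promote_types consults a promotion table over MLX dtypes):

    #include <cassert>

    enum class Dtype { bool_, int32, float32 };

    // Toy promotion: the wider enumerator wins.
    Dtype promote_types(Dtype a, Dtype b) {
      return a > b ? a : b;
    }

    Dtype result_type(Dtype a, Dtype b) {
      return promote_types(a, b);
    }

    Dtype result_type(Dtype a, Dtype b, Dtype c) {
      return promote_types(result_type(a, b), c);
    }

    int main() {
      assert(result_type(Dtype::bool_, Dtype::int32, Dtype::float32) ==
             Dtype::float32);
    }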

View File

@@ -131,8 +131,8 @@ class Module(dict):
         return value
 
     def __getattr__(self, key: str):
-        if key in self:
-            return self[key]
+        if (value := self.get(key, None)) is not None:
+            return value
         else:
             super(Module, self).__getattribute__(key)

View File

@@ -64,9 +64,9 @@ class Linear(Module):
     def __call__(self, x: mx.array) -> mx.array:
         if "bias" in self:
-            x = mx.addmm(self.bias, x, self.weight.T)
+            x = mx.addmm(self["bias"], x, self["weight"].T)
         else:
-            x = x @ self.weight.T
+            x = x @ self["weight"].T
         return x

View File

@@ -140,7 +140,7 @@ class RMSNorm(Module):
         return f"{self.weight.shape[0]}, eps={self.eps}"
 
     def __call__(self, x):
-        return mx.fast.rms_norm(x, self.weight, self.eps)
+        return mx.fast.rms_norm(x, self["weight"], self.eps)
 
 
 class GroupNorm(Module):

View File

@@ -81,15 +81,15 @@ class QuantizedLinear(Module):
     def __call__(self, x):
         x = mx.quantized_matmul(
             x,
-            self.weight,
-            scales=self.scales,
-            biases=self.biases,
+            self["weight"],
+            scales=self["scales"],
+            biases=self["biases"],
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
         )
         if "bias" in self:
-            x = x + self.bias
+            x = x + self["bias"]
         return x
 
     @classmethod

View File

@@ -17,12 +17,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "rms_norm",
-      [](const array& x,
-         const array& weight,
-         float eps,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::rms_norm(x, weight, eps, s);
-      },
+      &fast::rms_norm,
       "x"_a,
       "weight"_a,
       "eps"_a,
@@ -48,13 +43,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "layer_norm",
-      [](const array& x,
-         const std::optional<array>& weight,
-         const std::optional<array>& bias,
-         float eps,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::layer_norm(x, weight, bias, eps, s);
-      },
+      &fast::layer_norm,
       "x"_a,
       "weight"_a.none(),
       "bias"_a.none(),
@@ -84,15 +73,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "rope",
-      [](const array& a,
-         int dims,
-         bool traditional,
-         float base,
-         float scale,
-         int offset,
-         const StreamOrDevice& s /* = {} */) {
-        return fast::rope(a, dims, traditional, base, scale, offset, s);
-      },
+      &fast::rope,
       "a"_a,
       "dims"_a,
       nb::kw_only(),
@@ -123,14 +104,7 @@ void init_fast(nb::module_& parent_module) {
   m.def(
       "scaled_dot_product_attention",
-      [](const array& q,
-         const array& k,
-         const array& v,
-         const float scale,
-         const std::optional<array>& mask,
-         const StreamOrDevice& s) {
-        return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
-      },
+      &fast::scaled_dot_product_attention,
       "q"_a,
       "k"_a,
       "v"_a,