Reduce a little overhead (#871)

* some small overhead improvements

* use result_type in rms_norm

* remove release force

* fix + use non-vector version

* revert compile change

* fix ops

* a little more overhead

* a little more cleanup and overhead
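
The changes below follow a few recurring patterns: constructing primitives with std::make_shared instead of std::make_unique (array stores its primitive as a std::shared_ptr<Primitive>, per the ArrayDesc constructor below), moving rvalue shape/input vectors into the array constructor, and adding fixed-arity overloads of result_type and broadcast_arrays so binary ops stop materializing temporary vectors. A standalone sketch (not MLX code) of the make_shared point:

#include <memory>

struct Primitive {
  virtual ~Primitive() = default;
};
struct Add : Primitive {};

int main() {
  // Before: object and control block are allocated separately, plus a
  // unique_ptr -> shared_ptr conversion on every op construction.
  std::shared_ptr<Primitive> before = std::make_unique<Add>();
  // After: one allocation holds both the object and the control block.
  std::shared_ptr<Primitive> after = std::make_shared<Add>();
  return 0;
}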
Awni Hannun 2024-03-22 17:29:36 -07:00, committed by GitHub
parent 6ee1112f30
commit be98f4ab6b
13 changed files with 239 additions and 240 deletions

mlx/array.cpp

@ -12,16 +12,6 @@ namespace mlx::core {
namespace {
std::pair<size_t, std::vector<size_t>> cum_prod(const std::vector<int>& shape) {
std::vector<size_t> strides(shape.size());
size_t cum_prod = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = cum_prod;
cum_prod *= shape[i];
}
return {cum_prod, strides};
}
/** Return true if we are currently performing a function transformation in
* order to keep the graph when evaluating tracer arrays. */
bool in_tracing() {
@ -171,9 +161,21 @@ void array::move_shared_buffer(array other) {
move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
}
void array::ArrayDesc::init() {
strides.resize(shape.size());
size = 1;
for (int i = shape.size() - 1; i >= 0; --i) {
strides[i] = size;
size *= shape[i];
}
for (auto& in : inputs) {
is_tracer |= in.is_tracer();
}
}
array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
: shape(std::move(shape)), dtype(dtype) {
std::tie(size, strides) = cum_prod(this->shape);
init();
}
array::ArrayDesc::ArrayDesc(
@ -185,10 +187,7 @@ array::ArrayDesc::ArrayDesc(
dtype(dtype),
primitive(std::move(primitive)),
inputs(std::move(inputs)) {
std::tie(size, strides) = cum_prod(this->shape);
for (auto& in : this->inputs) {
is_tracer |= in.is_tracer();
}
init();
}
array::ArrayIterator::ArrayIterator(const array& arr, int idx)
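
The stride computation consolidated into ArrayDesc::init() above is plain row-major layout. A standalone sketch (not MLX code) with a worked example:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> shape{2, 3, 4};
  std::vector<size_t> strides(shape.size());
  size_t size = 1;
  for (int i = shape.size() - 1; i >= 0; --i) {
    strides[i] = size;  // stride of dim i = product of the trailing dims
    size *= shape[i];
  }
  assert(size == 24);
  assert((strides == std::vector<size_t>{12, 4, 1}));
  return 0;
}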

mlx/array.h

@ -392,6 +392,10 @@ class array {
Dtype dtype,
std::shared_ptr<Primitive> primitive,
std::vector<array> inputs);
private:
// Initialize size, strides, and other metadata
void init();
};
// The ArrayDesc contains the details of the materialized array including the

mlx/fast.cpp

@ -63,7 +63,7 @@ array rms_norm(
<< " dimensions.";
throw std::invalid_argument(msg.str());
}
auto out_type = result_type({x, weight});
auto out_type = result_type(x, weight);
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[rms_norm] Received unsupported type " << out_type << ".";
@ -88,7 +88,7 @@ array rms_norm(
return array(
x.shape(),
out_type,
std::make_unique<RMSNorm>(s, fallback, eps),
std::make_shared<RMSNorm>(s, fallback, eps),
{astype(x, out_type, s), astype(weight, out_type, s)});
}
return fallback({x, weight})[0];
@ -125,8 +125,8 @@ array layer_norm(
}
auto out_type = (weight.has_value())
? ((bias.has_value()) ? result_type({x, *weight, *bias})
: result_type({x, *weight}))
? ((bias.has_value()) ? result_type(x, *weight, *bias)
: result_type(x, *weight))
: x.dtype();
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
@ -170,7 +170,7 @@ array layer_norm(
return array(
x.shape(),
out_type,
std::make_unique<LayerNorm>(s, fallback, eps),
std::make_shared<LayerNorm>(s, fallback, eps),
{astype(x, out_type, s), passed_weight, passed_bias});
}
return fallback({x, passed_weight, passed_bias})[0];
@ -248,7 +248,7 @@ array rope(
return array(
x.shape(),
x.dtype(),
std::make_unique<RoPE>(
std::make_shared<RoPE>(
stream, fallback, dims, traditional, base, scale, offset),
{x});
}
@ -318,7 +318,7 @@ array scaled_dot_product_attention(
throw std::invalid_argument(msg.str());
}
auto final_type = result_type({queries, keys, values});
auto final_type = result_type(queries, keys, values);
if (!is_floating_point(final_type) || is_complex(final_type)) {
std::ostringstream msg;
msg << "[scaled_dot_product_attention] Received unsupported type "
@ -330,9 +330,6 @@ array scaled_dot_product_attention(
auto k = astype(keys, final_type, s);
auto v = astype(values, final_type, s);
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
/* generic implementation for use cases that Metal implementation does not
* support. For non-supported cases listed below, use MLX primitives:
* * CPU implementation
@ -381,10 +378,12 @@ array scaled_dot_product_attention(
// TODO, update routing conditions post further tuning
implementation_supports_use_case &= false;
if (implementation_supports_use_case) {
auto out_shape =
std::vector<int>({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)});
auto out = array(
out_shape,
std::move(out_shape),
final_type,
std::make_unique<ScaledDotProductAttention>(
std::make_shared<ScaledDotProductAttention>(
stream, fallback, scale, false),
{q, k, v});
return out;
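
A hedged usage sketch of the fast::rms_norm entry point changed above. The signature matches this diff; the "mlx/mlx.h" include and the random::normal / ones / eval helpers are assumed from the MLX C++ API:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto x = random::normal({8, 128});  // activations
  auto w = ones({128});               // per-feature scale
  // Promotes with result_type(x, weight) and constructs the RMSNorm
  // primitive via std::make_shared, per the change above.
  auto y = fast::rms_norm(x, w, /* eps = */ 1e-5f);
  eval({y});
  return 0;
}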

mlx/linalg.cpp

@ -195,7 +195,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
auto out = array::make_arrays(
{a.shape(), a.shape()},
{a.dtype(), a.dtype()},
std::make_unique<QRF>(to_stream(s)),
std::make_shared<QRF>(to_stream(s)),
{astype(a, a.dtype(), s)});
return std::make_pair(out[0], out[1]);
}
@ -234,7 +234,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
std::make_unique<SVD>(to_stream(s)),
std::make_shared<SVD>(to_stream(s)),
{a});
}
@ -258,7 +258,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
}
return array(
a.shape(), a.dtype(), std::make_unique<Inverse>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
}
} // namespace mlx::core::linalg

mlx/ops.cpp

@ -88,7 +88,7 @@ array arange(
return array(
{size},
dtype,
std::make_unique<Arange>(to_stream(s), start, stop, step),
std::make_shared<Arange>(to_stream(s), start, stop, step),
{});
}
array arange(
@ -163,7 +163,7 @@ array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) {
return a;
}
return array(
a.shape(), dtype, std::make_unique<AsType>(to_stream(s), dtype), {a});
a.shape(), dtype, std::make_shared<AsType>(to_stream(s), dtype), {a});
}
array as_strided(
@ -177,12 +177,12 @@ array as_strided(
return array(
shape,
a.dtype(),
std::make_unique<AsStrided>(to_stream(s), shape, strides, offset),
std::make_shared<AsStrided>(to_stream(s), shape, strides, offset),
{x});
}
array copy(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), a.dtype(), std::make_unique<Copy>(to_stream(s)), {a});
return array(a.shape(), a.dtype(), std::make_shared<Copy>(to_stream(s)), {a});
}
array full(
@ -194,7 +194,7 @@ array full(
throw std::invalid_argument("[full] Negative dimensions not allowed.");
}
auto in = broadcast_to(astype(vals, dtype, s), shape, s);
return array(shape, dtype, std::make_unique<Full>(to_stream(s)), {in});
return array(shape, dtype, std::make_shared<Full>(to_stream(s)), {in});
}
array full(
@ -313,9 +313,8 @@ array reshape(
<< " into shape " << shape << ".";
throw std::invalid_argument(msg.str());
}
return array(
shape, a.dtype(), std::make_unique<Reshape>(to_stream(s), shape), {a});
auto p = std::make_shared<Reshape>(to_stream(s), shape);
return array(std::move(shape), a.dtype(), std::move(p), {a});
}
array flatten(
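
The reshape change above is the first instance of a pattern repeated through this file: build the primitive once, then move rvalue locals into the array constructor. A standalone sketch (hypothetical Arr/Prim types, not MLX code):

#include <memory>
#include <utility>
#include <vector>

struct Prim {};

struct Arr {
  std::vector<int> shape;
  std::shared_ptr<Prim> prim;
  Arr(std::vector<int> s, std::shared_ptr<Prim> p)
      : shape(std::move(s)), prim(std::move(p)) {}
};

Arr make_reshaped(std::vector<int> shape) {
  auto p = std::make_shared<Prim>();
  // Both arguments are moved: no copy of the shape buffer, and no atomic
  // refcount bump on the primitive's shared_ptr.
  return Arr(std::move(shape), std::move(p));
}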
@ -408,6 +407,20 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) {
return squeeze(a, axes, s);
}
array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) {
int out_dim = a.ndim() + 1;
int ax = axis < 0 ? axis + out_dim : axis;
if (ax < 0 || ax >= out_dim) {
std::ostringstream msg;
msg << "[expand_dims] Invalid axes " << axis << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
auto shape = a.shape();
shape.insert(shape.begin() + ax, 1);
return reshape(a, std::move(shape), s);
}
array expand_dims(
const array& a,
const std::vector<int>& axes,
@ -425,7 +438,7 @@ array expand_dims(
ax = ax < 0 ? ax + out_ndim : ax;
if (ax < 0 || ax >= out_ndim) {
std::ostringstream msg;
msg << "[squeeze] Invalid axes " << ax << " for output array with "
msg << "[expand_dims] Invalid axes " << ax << " for output array with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}
@ -442,7 +455,7 @@ array expand_dims(
for (int i = 0; i < sorted_axes.size(); ++i) {
out_shape.insert(out_shape.begin() + sorted_axes[i], 1);
}
return reshape(a, out_shape, s);
return reshape(a, std::move(out_shape), s);
}
// Slice helper
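
Assumed usage of the new single-int expand_dims overload added above (MLX C++ API; zeros and eval taken as the usual helpers):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto a = zeros({4, 5});
  auto b = expand_dims(a, 0);   // shape (1, 4, 5)
  auto c = expand_dims(a, -1);  // shape (4, 5, 1); negative axis normalized
  eval({b, c});
  return 0;
}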
@ -523,7 +536,7 @@ array slice(
return array(
out_shape,
a.dtype(),
std::make_unique<Slice>(
std::make_shared<Slice>(
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
{a});
}
@ -568,7 +581,7 @@ array slice_update(
return array(
src.shape(),
src.dtype(),
std::make_unique<SliceUpdate>(
std::make_shared<SliceUpdate>(
to_stream(s), std::move(start), std::move(stop), std::move(strides)),
{src, update_broadcasted});
}
@ -743,7 +756,10 @@ array concatenate(
auto dtype = result_type(arrays);
return array(
shape, dtype, std::make_unique<Concatenate>(to_stream(s), ax), arrays);
std::move(shape),
dtype,
std::make_shared<Concatenate>(to_stream(s), ax),
std::move(arrays));
}
array concatenate(
@ -886,7 +902,7 @@ array pad(
return array(
out_shape,
a.dtype(),
std::make_unique<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
{a, astype(pad_value, a.dtype(), s)});
}
@ -975,7 +991,7 @@ array swapaxes(
std::vector<int> reorder(a.ndim());
std::iota(reorder.begin(), reorder.end(), 0);
std::swap(reorder[axis1], reorder[axis2]);
return transpose(a, reorder, s);
return transpose(a, std::move(reorder), s);
}
array transpose(
@ -1000,9 +1016,9 @@ array transpose(
shape.push_back(a.shape()[ax]);
}
return array(
shape,
std::move(shape),
a.dtype(),
std::make_unique<Transpose>(to_stream(s), std::move(axes)),
std::make_shared<Transpose>(to_stream(s), std::move(axes)),
{a});
}
@ -1029,7 +1045,16 @@ array broadcast_to(
throw std::invalid_argument(msg.str());
}
return array(
shape, a.dtype(), std::make_unique<Broadcast>(to_stream(s), shape), {a});
std::move(bxshape),
a.dtype(),
std::make_shared<Broadcast>(to_stream(s), shape),
{a});
}
std::vector<array>
broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
}
std::vector<array> broadcast_arrays(
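
A standalone sketch (toy types and broadcast rule, not MLX code) of the new two-array overload: callers get the broadcast pair back without building a vector first.

#include <algorithm>
#include <vector>

struct arr {
  std::vector<int> shape;
};

std::vector<arr> broadcast_arrays(const arr& a, const arr& b) {
  // Toy broadcast rule: elementwise max, assuming equal rank.
  std::vector<int> shape(a.shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    shape[i] = std::max(a.shape[i], b.shape[i]);
  }
  return {arr{shape}, arr{shape}};
}

int main() {
  auto out = broadcast_arrays(arr{{1, 4}}, arr{{3, 1}});
  // out[0].shape == out[1].shape == {3, 4}
  return 0;
}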
@ -1048,38 +1073,29 @@ std::vector<array> broadcast_arrays(
array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
if (a.shape() != b.shape()) {
inputs = broadcast_arrays(inputs, s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(), bool_, std::make_unique<Equal>(to_stream(s)), inputs);
shape, bool_, std::make_shared<Equal>(to_stream(s)), std::move(inputs));
}
array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
if (a.shape() != b.shape()) {
inputs = broadcast_arrays(inputs, s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
bool_,
std::make_unique<NotEqual>(to_stream(s)),
inputs);
std::make_shared<NotEqual>(to_stream(s)),
std::move(inputs));
}
array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
if (a.shape() != b.shape()) {
inputs = broadcast_arrays(inputs, s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
bool_,
std::make_unique<Greater>(to_stream(s)),
inputs);
shape, bool_, std::make_shared<Greater>(to_stream(s)), std::move(inputs));
}
array greater_equal(
@ -1087,38 +1103,32 @@ array greater_equal(
const array& b,
StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
if (a.shape() != b.shape()) {
inputs = broadcast_arrays(inputs, s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
bool_,
std::make_unique<GreaterEqual>(to_stream(s)),
inputs);
std::make_shared<GreaterEqual>(to_stream(s)),
std::move(inputs));
}
array less(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
if (a.shape() != b.shape()) {
inputs = broadcast_arrays(inputs, s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(), bool_, std::make_unique<Less>(to_stream(s)), inputs);
shape, bool_, std::make_shared<Less>(to_stream(s)), std::move(inputs));
}
array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
std::vector<array> inputs = {astype(a, dtype, s), astype(b, dtype, s)};
if (a.shape() != b.shape()) {
inputs = broadcast_arrays(inputs, s);
}
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
bool_,
std::make_unique<LessEqual>(to_stream(s)),
inputs);
std::make_shared<LessEqual>(to_stream(s)),
std::move(inputs));
}
array array_equal(
@ -1135,7 +1145,7 @@ array array_equal(
array(
a.shape(),
bool_,
std::make_unique<Equal>(to_stream(s), equal_nan),
std::make_shared<Equal>(to_stream(s), equal_nan),
{astype(a, dtype, s), astype(b, dtype, s)}),
false,
s);
@ -1180,7 +1190,7 @@ array where(
return array(
inputs[0].shape(),
out_dtype,
std::make_unique<Select>(to_stream(s)),
std::make_shared<Select>(to_stream(s)),
inputs);
}
@ -1246,7 +1256,7 @@ array all(
auto out = array(
out_shape,
bool_,
std::make_unique<Reduce>(to_stream(s), Reduce::And, sorted_axes),
std::make_shared<Reduce>(to_stream(s), Reduce::And, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
@ -1280,7 +1290,7 @@ array any(
auto out = array(
out_shape,
bool_,
std::make_unique<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
std::make_shared<Reduce>(to_stream(s), Reduce::Or, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
@ -1315,7 +1325,7 @@ array sum(
auto out = array(
out_shape,
out_type,
std::make_unique<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
std::make_shared<Reduce>(to_stream(s), Reduce::Sum, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
@ -1424,7 +1434,7 @@ array prod(
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
std::make_shared<Reduce>(to_stream(s), Reduce::Prod, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
@ -1461,7 +1471,7 @@ array max(
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
std::make_shared<Reduce>(to_stream(s), Reduce::Max, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
@ -1498,7 +1508,7 @@ array min(
auto out = array(
out_shape,
a.dtype(),
std::make_unique<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
std::make_shared<Reduce>(to_stream(s), Reduce::Min, sorted_axes),
{a});
if (!keepdims) {
out = squeeze(out, sorted_axes, s);
@ -1538,7 +1548,7 @@ array argmin(
auto out = array(
out_shape,
uint32,
std::make_unique<ArgReduce>(
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMin, sorted_axes[0]),
{a});
if (!keepdims) {
@ -1571,7 +1581,7 @@ array argmax(
auto out = array(
out_shape,
uint32,
std::make_unique<ArgReduce>(
std::make_shared<ArgReduce>(
to_stream(s), ArgReduce::ArgMax, sorted_axes[0]),
{a});
if (!keepdims) {
@ -1607,7 +1617,7 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
}
return array(
a.shape(), a.dtype(), std::make_unique<Sort>(to_stream(s), axis), {a});
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
}
/** Returns indices that sort the flattened array. */
@ -1637,7 +1647,7 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
}
return array(
a.shape(), uint32, std::make_unique<ArgSort>(to_stream(s), axis), {a});
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
}
/**
@ -1677,7 +1687,7 @@ array partition(
return array(
a.shape(),
a.dtype(),
std::make_unique<Partition>(to_stream(s), kth_, axis_),
std::make_shared<Partition>(to_stream(s), kth_, axis_),
{a});
}
@ -1718,7 +1728,7 @@ array argpartition(
return array(
a.shape(),
uint32,
std::make_unique<ArgPartition>(to_stream(s), kth_, axis_),
std::make_shared<ArgPartition>(to_stream(s), kth_, axis_),
{a});
}
@ -1787,7 +1797,7 @@ array logsumexp(
array abs(const array& a, StreamOrDevice s /* = {} */) {
auto out =
array(a.shape(), a.dtype(), std::make_unique<Abs>(to_stream(s)), {a});
array(a.shape(), a.dtype(), std::make_shared<Abs>(to_stream(s)), {a});
if (a.dtype() == complex64) {
out = astype(out, float32, s);
}
@ -1800,33 +1810,33 @@ array negative(const array& a, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg);
}
return array(
a.shape(), a.dtype(), std::make_unique<Negative>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<Negative>(to_stream(s)), {a});
}
array operator-(const array& a) {
return negative(a);
}
array sign(const array& a, StreamOrDevice s /* = {} */) {
return array(a.shape(), a.dtype(), std::make_unique<Sign>(to_stream(s)), {a});
return array(a.shape(), a.dtype(), std::make_shared<Sign>(to_stream(s)), {a});
}
array logical_not(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
bool_,
std::make_unique<LogicalNot>(to_stream(s)),
std::make_shared<LogicalNot>(to_stream(s)),
{astype(a, bool_, s)});
}
array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Broadcast arrays to a common shape
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
bool_,
std::make_unique<LogicalAnd>(to_stream(s)),
inputs);
std::make_shared<LogicalAnd>(to_stream(s)),
std::move(inputs));
}
array operator&&(const array& a, const array& b) {
return logical_and(a, b);
@ -1834,13 +1844,13 @@ array operator&&(const array& a, const array& b) {
array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Broadcast arrays to a common shape
auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s);
auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
bool_,
std::make_unique<LogicalOr>(to_stream(s)),
inputs);
std::make_shared<LogicalOr>(to_stream(s)),
std::move(inputs));
}
array operator||(const array& a, const array& b) {
return logical_or(a, b);
@ -1854,9 +1864,10 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) {
array add(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(), out_type, std::make_unique<Add>(to_stream(s)), inputs);
shape, out_type, std::make_shared<Add>(to_stream(s)), std::move(inputs));
}
array operator+(const array& a, const array& b) {
@ -1866,12 +1877,13 @@ array operator+(const array& a, const array& b) {
array subtract(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
out_type,
std::make_unique<Subtract>(to_stream(s)),
inputs);
std::make_shared<Subtract>(to_stream(s)),
std::move(inputs));
}
array operator-(const array& a, const array& b) {
@ -1881,12 +1893,13 @@ array operator-(const array& a, const array& b) {
array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
out_type,
std::make_unique<Multiply>(to_stream(s)),
inputs);
std::make_shared<Multiply>(to_stream(s)),
std::move(inputs));
}
array operator*(const array& a, const array& b) {
@ -1895,10 +1908,11 @@ array operator*(const array& a, const array& b) {
array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto inputs =
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
}
array operator/(const array& a, const array& b) {
return divide(a, b);
@ -1919,20 +1933,22 @@ array floor_divide(
return floor(divide(a, b, s), s);
}
auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s);
auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(), dtype, std::make_unique<Divide>(to_stream(s)), inputs);
shape, dtype, std::make_shared<Divide>(to_stream(s)), std::move(inputs));
}
array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto dtype = promote_types(a.dtype(), b.dtype());
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto inputs =
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
dtype,
std::make_unique<Remainder>(to_stream(s)),
inputs);
std::make_shared<Remainder>(to_stream(s)),
std::move(inputs));
}
array operator%(const array& a, const array& b) {
return remainder(a, b);
@ -1944,8 +1960,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
if (is_complex(dtype)) {
throw std::invalid_argument("[divmod] Complex type not supported.");
}
auto inputs = broadcast_arrays(
{astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s);
auto inputs =
broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s);
return array::make_arrays(
{inputs[0].shape(), inputs[0].shape()},
{inputs[0].dtype(), inputs[0].dtype()},
@ -1956,23 +1972,25 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) {
array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
out_type,
std::make_unique<Maximum>(to_stream(s)),
inputs);
std::make_shared<Maximum>(to_stream(s)),
std::move(inputs));
}
array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) {
auto out_type = promote_types(a.dtype(), b.dtype());
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
out_type,
std::make_unique<Minimum>(to_stream(s)),
inputs);
std::make_shared<Minimum>(to_stream(s)),
std::move(inputs));
}
array floor(const array& a, StreamOrDevice s /* = {} */) {
@ -1980,103 +1998,103 @@ array floor(const array& a, StreamOrDevice s /* = {} */) {
throw std::invalid_argument("[floor] Not supported for complex64.");
}
return array(
a.shape(), a.dtype(), std::make_unique<Floor>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<Floor>(to_stream(s)), {a});
}
array ceil(const array& a, StreamOrDevice s /* = {} */) {
if (a.dtype() == complex64) {
throw std::invalid_argument("[floor] Not supported for complex64.");
}
return array(a.shape(), a.dtype(), std::make_unique<Ceil>(to_stream(s)), {a});
return array(a.shape(), a.dtype(), std::make_shared<Ceil>(to_stream(s)), {a});
}
array square(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(), a.dtype(), std::make_unique<Square>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<Square>(to_stream(s)), {a});
}
array exp(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Exp>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Exp>(to_stream(s)), {input});
}
array sin(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Sin>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Sin>(to_stream(s)), {input});
}
array cos(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Cos>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Cos>(to_stream(s)), {input});
}
array tan(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Tan>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Tan>(to_stream(s)), {input});
}
array arcsin(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<ArcSin>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<ArcSin>(to_stream(s)), {input});
}
array arccos(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<ArcCos>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<ArcCos>(to_stream(s)), {input});
}
array arctan(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<ArcTan>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<ArcTan>(to_stream(s)), {input});
}
array sinh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Sinh>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Sinh>(to_stream(s)), {input});
}
array cosh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Cosh>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Cosh>(to_stream(s)), {input});
}
array tanh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(a.shape(), dtype, std::make_unique<Tanh>(to_stream(s)), {input});
return array(a.shape(), dtype, std::make_shared<Tanh>(to_stream(s)), {input});
}
array arcsinh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<ArcSinh>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<ArcSinh>(to_stream(s)), {input});
}
array arccosh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<ArcCosh>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<ArcCosh>(to_stream(s)), {input});
}
array arctanh(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<ArcTanh>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<ArcTanh>(to_stream(s)), {input});
}
array log(const array& a, StreamOrDevice s /* = {} */) {
@ -2085,7 +2103,7 @@ array log(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<Log>(to_stream(s), Log::Base::e),
std::make_shared<Log>(to_stream(s), Log::Base::e),
{input});
}
@ -2095,7 +2113,7 @@ array log2(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<Log>(to_stream(s), Log::Base::two),
std::make_shared<Log>(to_stream(s), Log::Base::two),
{input});
}
@ -2105,7 +2123,7 @@ array log10(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<Log>(to_stream(s), Log::Base::ten),
std::make_shared<Log>(to_stream(s), Log::Base::ten),
{input});
}
@ -2113,26 +2131,27 @@ array log1p(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<Log1p>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<Log1p>(to_stream(s)), {input});
}
array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) {
// Make sure out type is floating point
auto out_type = at_least_float(promote_types(a.dtype(), b.dtype()));
auto inputs =
broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s);
broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s);
auto& shape = inputs[0].shape();
return array(
inputs[0].shape(),
shape,
out_type,
std::make_unique<LogAddExp>(to_stream(s)),
inputs);
std::make_shared<LogAddExp>(to_stream(s)),
std::move(inputs));
}
array sigmoid(const array& a, StreamOrDevice s /* = {} */) {
auto dtype = at_least_float(a.dtype());
auto input = astype(a, dtype, s);
return array(
a.shape(), dtype, std::make_unique<Sigmoid>(to_stream(s)), {input});
a.shape(), dtype, std::make_shared<Sigmoid>(to_stream(s)), {input});
}
array erf(const array& a, StreamOrDevice s /* = {} */) {
@ -2140,7 +2159,7 @@ array erf(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<Erf>(to_stream(s)),
std::make_shared<Erf>(to_stream(s)),
{astype(a, dtype, s)});
}
@ -2149,19 +2168,19 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<ErfInv>(to_stream(s)),
std::make_shared<ErfInv>(to_stream(s)),
{astype(a, dtype, s)});
}
array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(), a.dtype(), std::make_unique<StopGradient>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
}
array round(const array& a, int decimals, StreamOrDevice s /* = {} */) {
if (decimals == 0) {
return array(
a.shape(), a.dtype(), std::make_unique<Round>(to_stream(s)), {a});
a.shape(), a.dtype(), std::make_shared<Round>(to_stream(s)), {a});
}
auto dtype = at_least_float(a.dtype());
@ -2226,7 +2245,7 @@ array matmul(
auto out = array(
{a.shape(0), b.shape(1)},
out_type,
std::make_unique<Matmul>(to_stream(s)),
std::make_shared<Matmul>(to_stream(s)),
{a, b});
return reshape(out, out_shape, s);
}
@ -2250,17 +2269,17 @@ array matmul(
auto out_shape = a.shape();
out_shape.back() = b.shape(-1);
auto out = array(
out_shape, out_type, std::make_unique<Matmul>(to_stream(s)), {a, b});
auto p = std::make_shared<Matmul>(to_stream(s));
// Remove the possibly inserted singleton dimensions
if (in_a.ndim() == 1 || in_b.ndim() == 1) {
auto out = array(out_shape, out_type, std::move(p), {a, b});
out_shape.erase(
out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1),
out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1));
out = reshape(out, out_shape, s);
return reshape(out, std::move(out_shape), s);
}
return out;
return array(std::move(out_shape), out_type, std::move(p), {a, b});
}
array gather(
@ -2332,7 +2351,7 @@ array gather(
return array(
out_shape,
a.dtype(),
std::make_unique<Gather>(to_stream(s), axes, slice_sizes),
std::make_shared<Gather>(to_stream(s), axes, slice_sizes),
inputs);
}
@ -2516,7 +2535,7 @@ array scatter(
return array(
a.shape(),
a.dtype(),
std::make_unique<Scatter>(to_stream(s), mode, axes),
std::make_shared<Scatter>(to_stream(s), mode, axes),
inputs);
}
@ -2570,7 +2589,7 @@ array sqrt(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<Sqrt>(to_stream(s)),
std::make_shared<Sqrt>(to_stream(s)),
{astype(a, dtype, s)});
}
@ -2579,7 +2598,7 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(),
dtype,
std::make_unique<Sqrt>(to_stream(s), true),
std::make_shared<Sqrt>(to_stream(s), true),
{astype(a, dtype, s)});
}
@ -2592,7 +2611,7 @@ array softmax(
return array(
a.shape(),
dtype,
std::make_unique<Softmax>(to_stream(s)),
std::make_shared<Softmax>(to_stream(s)),
{astype(a, dtype, s)});
} else {
auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s);
@ -2614,7 +2633,7 @@ array power(const array& a, const array& b, StreamOrDevice s /* = {} */) {
inputs = broadcast_arrays(inputs, s);
}
return array(
inputs[0].shape(), dtype, std::make_unique<Power>(to_stream(s)), inputs);
inputs[0].shape(), dtype, std::make_shared<Power>(to_stream(s)), inputs);
}
array cumsum(
@ -2635,7 +2654,7 @@ array cumsum(
return array(
a.shape(),
out_type,
std::make_unique<Scan>(
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive),
{a});
}
@ -2657,7 +2676,7 @@ array cumprod(
return array(
a.shape(),
a.dtype(),
std::make_unique<Scan>(
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive),
{a});
}
@ -2679,7 +2698,7 @@ array cummax(
return array(
a.shape(),
a.dtype(),
std::make_unique<Scan>(
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive),
{a});
}
@ -2701,7 +2720,7 @@ array cummin(
return array(
a.shape(),
a.dtype(),
std::make_unique<Scan>(
std::make_shared<Scan>(
to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive),
{a});
}
@ -2965,7 +2984,7 @@ array conv_general(
return array(
out_shape,
in.dtype(),
std::make_unique<Convolution>(
std::make_shared<Convolution>(
to_stream(s),
stride,
padding_lo,
@ -3042,7 +3061,7 @@ array quantized_matmul(
throw std::invalid_argument(msg.str());
}
auto dtype = result_type({x, scales, biases});
auto dtype = result_type(x, scales, biases);
if (!is_floating_point(dtype) || is_complex(dtype)) {
std::ostringstream msg;
msg << "[quantized_matmul] Only real floating types are supported but "
@ -3055,7 +3074,7 @@ array quantized_matmul(
auto out = array(
{x.shape(0), w_outer_dims},
dtype,
std::make_unique<QuantizedMatmul>(
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
{astype(x, dtype, s),
w,
@ -3065,7 +3084,7 @@ array quantized_matmul(
// If needed reshape x to the original batch shape
if (original_shape.size() != 1) {
original_shape.push_back(w_outer_dims);
out = reshape(out, original_shape, s);
out = reshape(out, std::move(original_shape), s);
}
return out;
@ -3344,7 +3363,7 @@ array addmm(
}
// Type promotion
auto out_type = result_type({a, b, c});
auto out_type = result_type(a, b, c);
if (!is_floating_point(out_type) || is_complex(out_type)) {
std::ostringstream msg;
msg << "[addmm] Only real floating point types are supported but "
@ -3373,7 +3392,7 @@ array addmm(
auto out = array(
{a.shape(0), b.shape(1)},
out_type,
std::make_unique<AddMM>(to_stream(s), alpha, beta),
std::make_shared<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
return reshape(out, out_shape, s);
}
@ -3425,7 +3444,7 @@ array addmm(
auto out = array(
out_shape,
out_type,
std::make_unique<AddMM>(to_stream(s), alpha, beta),
std::make_shared<AddMM>(to_stream(s), alpha, beta),
{a, b, c});
// Remove the possibly inserted singleton dimensions
@ -3621,7 +3640,7 @@ array number_of_elements(
return stop_gradient(array(
std::vector<int>{},
dtype,
std::make_unique<NumberOfElements>(
std::make_shared<NumberOfElements>(
to_stream(s), std::move(axes), inverted, dtype),
{a}));
}

mlx/ops.h

@ -158,9 +158,7 @@ array expand_dims(
StreamOrDevice s = {});
/** Add a singleton dimension at the given axis. */
inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) {
return expand_dims(a, std::vector<int>{axis}, s);
}
array expand_dims(const array& a, int axis, StreamOrDevice s = {});
/** Slice an array. */
array slice(

mlx/random.cpp

@ -66,7 +66,7 @@ array bits(
return array(
shape,
get_dtype(),
std::make_unique<RandomBits>(to_stream(s), shape, width),
std::make_shared<RandomBits>(to_stream(s), shape, width),
{key});
}

mlx/utils.h

@ -54,6 +54,12 @@ struct PrintFormatter {
extern PrintFormatter global_formatter;
/** The type from promoting the arrays' types with one another. */
inline Dtype result_type(const array& a, const array& b) {
return promote_types(a.dtype(), b.dtype());
}
inline Dtype result_type(const array& a, const array& b, const array& c) {
return promote_types(result_type(a, b), c.dtype());
}
Dtype result_type(const std::vector<array>& arrays);
std::vector<int> broadcast_shapes(
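
A standalone sketch (toy Dtype and promotion rule, not MLX code) of why the fixed-arity result_type overloads above avoid work: no std::vector<array> is materialized just to promote two or three dtypes.

#include <vector>

enum class Dtype { float16, float32 };

Dtype promote_types(Dtype a, Dtype b) {
  return (a == b) ? a : Dtype::float32;  // toy promotion rule
}

struct array {
  Dtype d;
  Dtype dtype() const { return d; }
};

inline Dtype result_type(const array& a, const array& b) {
  return promote_types(a.dtype(), b.dtype());
}

inline Dtype result_type(const array& a, const array& b, const array& c) {
  return promote_types(result_type(a, b), c.dtype());
}

// The vector overload stays for variadic call sites, at the cost of a heap
// allocation when invoked as result_type({x, y, z}).
Dtype result_type(const std::vector<array>& arrays) {
  Dtype t = arrays[0].dtype();
  for (const auto& a : arrays) {
    t = promote_types(t, a.dtype());
  }
  return t;
}

int main() {
  array x{Dtype::float16}, w{Dtype::float32};
  return result_type(x, w) == Dtype::float32 ? 0 : 1;
}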

python/mlx/nn/layers/base.py

@ -131,8 +131,8 @@ class Module(dict):
        return value

    def __getattr__(self, key: str):
        if key in self:
            return self[key]
        if (value := self.get(key, None)) is not None:
            return value
        else:
            super(Module, self).__getattribute__(key)

python/mlx/nn/layers/linear.py

@ -64,9 +64,9 @@ class Linear(Module):
    def __call__(self, x: mx.array) -> mx.array:
        if "bias" in self:
            x = mx.addmm(self.bias, x, self.weight.T)
            x = mx.addmm(self["bias"], x, self["weight"].T)
        else:
            x = x @ self.weight.T
            x = x @ self["weight"].T
        return x

python/mlx/nn/layers/normalization.py

@ -140,7 +140,7 @@ class RMSNorm(Module):
return f"{self.weight.shape[0]}, eps={self.eps}"
def __call__(self, x):
return mx.fast.rms_norm(x, self.weight, self.eps)
return mx.fast.rms_norm(x, self["weight"], self.eps)
class GroupNorm(Module):

python/mlx/nn/layers/quantized.py

@ -81,15 +81,15 @@ class QuantizedLinear(Module):
    def __call__(self, x):
        x = mx.quantized_matmul(
            x,
            self.weight,
            scales=self.scales,
            biases=self.biases,
            self["weight"],
            scales=self["scales"],
            biases=self["biases"],
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + self.bias
            x = x + self["bias"]
        return x

    @classmethod

python/src/fast.cpp

@ -17,12 +17,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"rms_norm",
[](const array& x,
const array& weight,
float eps,
const StreamOrDevice& s /* = {} */) {
return fast::rms_norm(x, weight, eps, s);
},
&fast::rms_norm,
"x"_a,
"weight"_a,
"eps"_a,
@ -48,13 +43,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"layer_norm",
[](const array& x,
const std::optional<array>& weight,
const std::optional<array>& bias,
float eps,
const StreamOrDevice& s /* = {} */) {
return fast::layer_norm(x, weight, bias, eps, s);
},
&fast::layer_norm,
"x"_a,
"weight"_a.none(),
"bias"_a.none(),
@ -84,15 +73,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"rope",
[](const array& a,
int dims,
bool traditional,
float base,
float scale,
int offset,
const StreamOrDevice& s /* = {} */) {
return fast::rope(a, dims, traditional, base, scale, offset, s);
},
&fast::rope,
"a"_a,
"dims"_a,
nb::kw_only(),
@ -123,14 +104,7 @@ void init_fast(nb::module_& parent_module) {
m.def(
"scaled_dot_product_attention",
[](const array& q,
const array& k,
const array& v,
const float scale,
const std::optional<array>& mask,
const StreamOrDevice& s) {
return fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
},
&fast::scaled_dot_product_attention,
"q"_a,
"k"_a,
"v"_a,