From be98f4ab6be2e466e2ebf59622e5f4866b6527e2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 22 Mar 2024 17:29:36 -0700 Subject: [PATCH] Reduce a little overhead (#871) * some small overhead improvements * use result_type in rms_norm * remove release force * fix + use non-vector version * revert compile change * fix ops * a little more overhead * a little more cleanup and overhead --- mlx/array.cpp | 29 +-- mlx/array.h | 4 + mlx/fast.cpp | 23 +- mlx/linalg.cpp | 6 +- mlx/ops.cpp | 353 ++++++++++++++------------ mlx/ops.h | 4 +- mlx/random.cpp | 2 +- mlx/utils.h | 6 + python/mlx/nn/layers/base.py | 4 +- python/mlx/nn/layers/linear.py | 4 +- python/mlx/nn/layers/normalization.py | 2 +- python/mlx/nn/layers/quantized.py | 8 +- python/src/fast.cpp | 34 +-- 13 files changed, 239 insertions(+), 240 deletions(-) diff --git a/mlx/array.cpp b/mlx/array.cpp index e14f7bec7..771e72676 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -12,16 +12,6 @@ namespace mlx::core { namespace { -std::pair> cum_prod(const std::vector& shape) { - std::vector strides(shape.size()); - size_t cum_prod = 1; - for (int i = shape.size() - 1; i >= 0; --i) { - strides[i] = cum_prod; - cum_prod *= shape[i]; - } - return {cum_prod, strides}; -} - /** Return true if we are currently performing a function transformation in * order to keep the graph when evaluating tracer arrays. */ bool in_tracing() { @@ -171,9 +161,21 @@ void array::move_shared_buffer(array other) { move_shared_buffer(other, other.strides(), other.flags(), other.data_size()); } +void array::ArrayDesc::init() { + strides.resize(shape.size()); + size = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + strides[i] = size; + size *= shape[i]; + } + for (auto& in : inputs) { + is_tracer |= in.is_tracer(); + } +} + array::ArrayDesc::ArrayDesc(std::vector shape, Dtype dtype) : shape(std::move(shape)), dtype(dtype) { - std::tie(size, strides) = cum_prod(this->shape); + init(); } array::ArrayDesc::ArrayDesc( @@ -185,10 +187,7 @@ array::ArrayDesc::ArrayDesc( dtype(dtype), primitive(std::move(primitive)), inputs(std::move(inputs)) { - std::tie(size, strides) = cum_prod(this->shape); - for (auto& in : this->inputs) { - is_tracer |= in.is_tracer(); - } + init(); } array::ArrayIterator::ArrayIterator(const array& arr, int idx) diff --git a/mlx/array.h b/mlx/array.h index 686dc9b33..9b7cf0f8f 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -392,6 +392,10 @@ class array { Dtype dtype, std::shared_ptr primitive, std::vector inputs); + + private: + // Initialize size, strides, and other metadata + void init(); }; // The ArrayDesc contains the details of the materialized array including the diff --git a/mlx/fast.cpp b/mlx/fast.cpp index 568e2604f..589dbe5aa 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -63,7 +63,7 @@ array rms_norm( << " dimensions."; throw std::invalid_argument(msg.str()); } - auto out_type = result_type({x, weight}); + auto out_type = result_type(x, weight); if (!is_floating_point(out_type) || is_complex(out_type)) { std::ostringstream msg; msg << "[rms_norm] Received unsupported type " << out_type << "."; @@ -88,7 +88,7 @@ array rms_norm( return array( x.shape(), out_type, - std::make_unique(s, fallback, eps), + std::make_shared(s, fallback, eps), {astype(x, out_type, s), astype(weight, out_type, s)}); } return fallback({x, weight})[0]; @@ -125,8 +125,8 @@ array layer_norm( } auto out_type = (weight.has_value()) - ? ((bias.has_value()) ? result_type({x, *weight, *bias}) - : result_type({x, *weight})) + ? ((bias.has_value()) ? 
result_type(x, *weight, *bias) + : result_type(x, *weight)) : x.dtype(); if (!is_floating_point(out_type) || is_complex(out_type)) { std::ostringstream msg; @@ -170,7 +170,7 @@ array layer_norm( return array( x.shape(), out_type, - std::make_unique(s, fallback, eps), + std::make_shared(s, fallback, eps), {astype(x, out_type, s), passed_weight, passed_bias}); } return fallback({x, passed_weight, passed_bias})[0]; @@ -248,7 +248,7 @@ array rope( return array( x.shape(), x.dtype(), - std::make_unique( + std::make_shared( stream, fallback, dims, traditional, base, scale, offset), {x}); } @@ -318,7 +318,7 @@ array scaled_dot_product_attention( throw std::invalid_argument(msg.str()); } - auto final_type = result_type({queries, keys, values}); + auto final_type = result_type(queries, keys, values); if (!is_floating_point(final_type) || is_complex(final_type)) { std::ostringstream msg; msg << "[scaled_dot_product_attention] Received unsupported type " @@ -330,9 +330,6 @@ array scaled_dot_product_attention( auto k = astype(keys, final_type, s); auto v = astype(values, final_type, s); - auto out_shape = - std::vector({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}); - /* generic implementation for use cases that Metal implementation does not * support. For non-supported cases listed below, use MLX primitives: * * CPU implementation @@ -381,10 +378,12 @@ array scaled_dot_product_attention( // TODO, update routing conditions post further tuning implementation_supports_use_case &= false; if (implementation_supports_use_case) { + auto out_shape = + std::vector({q.shape(0), q.shape(1), q.shape(2), v.shape(-1)}); auto out = array( - out_shape, + std::move(out_shape), final_type, - std::make_unique( + std::make_shared( stream, fallback, scale, false), {q, k, v}); return out; diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 5d609e7f1..c000d2591 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -195,7 +195,7 @@ std::pair qr(const array& a, StreamOrDevice s /* = {} */) { auto out = array::make_arrays( {a.shape(), a.shape()}, {a.dtype(), a.dtype()}, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {astype(a, a.dtype(), s)}); return std::make_pair(out[0], out[1]); } @@ -234,7 +234,7 @@ std::vector svd(const array& a, StreamOrDevice s /* = {} */) { return array::make_arrays( {u_shape, s_shape, vt_shape}, {a.dtype(), a.dtype(), a.dtype()}, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {a}); } @@ -258,7 +258,7 @@ array inv(const array& a, StreamOrDevice s /* = {} */) { } return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } } // namespace mlx::core::linalg diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 2050fa374..c7c858550 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -88,7 +88,7 @@ array arange( return array( {size}, dtype, - std::make_unique(to_stream(s), start, stop, step), + std::make_shared(to_stream(s), start, stop, step), {}); } array arange( @@ -163,7 +163,7 @@ array astype(const array& a, Dtype dtype, StreamOrDevice s /* = {} */) { return a; } return array( - a.shape(), dtype, std::make_unique(to_stream(s), dtype), {a}); + a.shape(), dtype, std::make_shared(to_stream(s), dtype), {a}); } array as_strided( @@ -177,12 +177,12 @@ array as_strided( return array( shape, a.dtype(), - std::make_unique(to_stream(s), shape, strides, offset), + std::make_shared(to_stream(s), shape, strides, offset), {x}); } array copy(const array& a, StreamOrDevice s /* = {} */) { - return 
array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array full( @@ -194,7 +194,7 @@ array full( throw std::invalid_argument("[full] Negative dimensions not allowed."); } auto in = broadcast_to(astype(vals, dtype, s), shape, s); - return array(shape, dtype, std::make_unique(to_stream(s)), {in}); + return array(shape, dtype, std::make_shared(to_stream(s)), {in}); } array full( @@ -313,9 +313,8 @@ array reshape( << " into shape " << shape << "."; throw std::invalid_argument(msg.str()); } - - return array( - shape, a.dtype(), std::make_unique(to_stream(s), shape), {a}); + auto p = std::make_shared(to_stream(s), shape); + return array(std::move(shape), a.dtype(), std::move(p), {a}); } array flatten( @@ -408,6 +407,20 @@ array squeeze(const array& a, StreamOrDevice s /* = {} */) { return squeeze(a, axes, s); } +array expand_dims(const array& a, int axis, StreamOrDevice s /* = {} */) { + int out_dim = a.ndim() + 1; + int ax = axis < 0 ? axis + out_dim : axis; + if (ax < 0 || ax >= out_dim) { + std::ostringstream msg; + msg << "[expand_dims] Invalid axes " << axis << " for output array with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + auto shape = a.shape(); + shape.insert(shape.begin() + ax, 1); + return reshape(a, std::move(shape), s); +} + array expand_dims( const array& a, const std::vector& axes, @@ -425,7 +438,7 @@ array expand_dims( ax = ax < 0 ? ax + out_ndim : ax; if (ax < 0 || ax >= out_ndim) { std::ostringstream msg; - msg << "[squeeze] Invalid axes " << ax << " for output array with " + msg << "[expand_dims] Invalid axes " << ax << " for output array with " << a.ndim() << " dimensions."; throw std::invalid_argument(msg.str()); } @@ -442,7 +455,7 @@ array expand_dims( for (int i = 0; i < sorted_axes.size(); ++i) { out_shape.insert(out_shape.begin() + sorted_axes[i], 1); } - return reshape(a, out_shape, s); + return reshape(a, std::move(out_shape), s); } // Slice helper @@ -523,7 +536,7 @@ array slice( return array( out_shape, a.dtype(), - std::make_unique( + std::make_shared( to_stream(s), std::move(start), std::move(stop), std::move(strides)), {a}); } @@ -568,7 +581,7 @@ array slice_update( return array( src.shape(), src.dtype(), - std::make_unique( + std::make_shared( to_stream(s), std::move(start), std::move(stop), std::move(strides)), {src, update_broadcasted}); } @@ -743,7 +756,10 @@ array concatenate( auto dtype = result_type(arrays); return array( - shape, dtype, std::make_unique(to_stream(s), ax), arrays); + std::move(shape), + dtype, + std::make_shared(to_stream(s), ax), + std::move(arrays)); } array concatenate( @@ -886,7 +902,7 @@ array pad( return array( out_shape, a.dtype(), - std::make_unique(to_stream(s), axes, low_pad_size, high_pad_size), + std::make_shared(to_stream(s), axes, low_pad_size, high_pad_size), {a, astype(pad_value, a.dtype(), s)}); } @@ -975,7 +991,7 @@ array swapaxes( std::vector reorder(a.ndim()); std::iota(reorder.begin(), reorder.end(), 0); std::swap(reorder[axis1], reorder[axis2]); - return transpose(a, reorder, s); + return transpose(a, std::move(reorder), s); } array transpose( @@ -1000,9 +1016,9 @@ array transpose( shape.push_back(a.shape()[ax]); } return array( - shape, + std::move(shape), a.dtype(), - std::make_unique(to_stream(s), std::move(axes)), + std::make_shared(to_stream(s), std::move(axes)), {a}); } @@ -1029,7 +1045,16 @@ array broadcast_to( throw std::invalid_argument(msg.str()); } return array( - shape, 
a.dtype(), std::make_unique(to_stream(s), shape), {a}); + std::move(bxshape), + a.dtype(), + std::make_shared(to_stream(s), shape), + {a}); +} + +std::vector +broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) { + std::vector shape = broadcast_shapes(a.shape(), b.shape()); + return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)}; } std::vector broadcast_arrays( @@ -1048,38 +1073,29 @@ std::vector broadcast_arrays( array equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; - if (a.shape() != b.shape()) { - inputs = broadcast_arrays(inputs, s); - } + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), bool_, std::make_unique(to_stream(s)), inputs); + shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array not_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; - if (a.shape() != b.shape()) { - inputs = broadcast_arrays(inputs, s); - } + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, bool_, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array greater(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; - if (a.shape() != b.shape()) { - inputs = broadcast_arrays(inputs, s); - } + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), - bool_, - std::make_unique(to_stream(s)), - inputs); + shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array greater_equal( @@ -1087,38 +1103,32 @@ array greater_equal( const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; - if (a.shape() != b.shape()) { - inputs = broadcast_arrays(inputs, s); - } + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, bool_, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array less(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; - if (a.shape() != b.shape()) { - inputs = broadcast_arrays(inputs, s); - } + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), bool_, std::make_unique(to_stream(s)), inputs); + shape, bool_, std::make_shared(to_stream(s)), std::move(inputs)); } array less_equal(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - std::vector inputs = {astype(a, dtype, s), astype(b, dtype, s)}; - if (a.shape() != b.shape()) { - inputs = broadcast_arrays(inputs, s); - } + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = 
inputs[0].shape(); return array( - inputs[0].shape(), + shape, bool_, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array array_equal( @@ -1135,7 +1145,7 @@ array array_equal( array( a.shape(), bool_, - std::make_unique(to_stream(s), equal_nan), + std::make_shared(to_stream(s), equal_nan), {astype(a, dtype, s), astype(b, dtype, s)}), false, s); @@ -1180,7 +1190,7 @@ array where( return array( inputs[0].shape(), out_dtype, - std::make_unique(to_stream(s)), inputs); } @@ -1246,7 +1256,7 @@ array all( auto out = array( out_shape, bool_, - std::make_unique(to_stream(s), Reduce::And, sorted_axes), + std::make_shared(to_stream(s), Reduce::And, sorted_axes), {a}); if (!keepdims) { out = squeeze(out, sorted_axes, s); @@ -1280,7 +1290,7 @@ array any( auto out = array( out_shape, bool_, - std::make_unique(to_stream(s), Reduce::Or, sorted_axes), + std::make_shared(to_stream(s), Reduce::Or, sorted_axes), {a}); if (!keepdims) { out = squeeze(out, sorted_axes, s); @@ -1315,7 +1325,7 @@ array sum( auto out = array( out_shape, out_type, - std::make_unique(to_stream(s), Reduce::Sum, sorted_axes), + std::make_shared(to_stream(s), Reduce::Sum, sorted_axes), {a}); if (!keepdims) { out = squeeze(out, sorted_axes, s); @@ -1424,7 +1434,7 @@ array prod( auto out = array( out_shape, a.dtype(), - std::make_unique(to_stream(s), Reduce::Prod, sorted_axes), + std::make_shared(to_stream(s), Reduce::Prod, sorted_axes), {a}); if (!keepdims) { out = squeeze(out, sorted_axes, s); @@ -1461,7 +1471,7 @@ array max( auto out = array( out_shape, a.dtype(), - std::make_unique(to_stream(s), Reduce::Max, sorted_axes), + std::make_shared(to_stream(s), Reduce::Max, sorted_axes), {a}); if (!keepdims) { out = squeeze(out, sorted_axes, s); @@ -1498,7 +1508,7 @@ array min( auto out = array( out_shape, a.dtype(), - std::make_unique(to_stream(s), Reduce::Min, sorted_axes), + std::make_shared(to_stream(s), Reduce::Min, sorted_axes), {a}); if (!keepdims) { out = squeeze(out, sorted_axes, s); @@ -1538,7 +1548,7 @@ array argmin( auto out = array( out_shape, uint32, - std::make_unique( + std::make_shared( to_stream(s), ArgReduce::ArgMin, sorted_axes[0]), {a}); if (!keepdims) { @@ -1571,7 +1581,7 @@ array argmax( auto out = array( out_shape, uint32, - std::make_unique( + std::make_shared( to_stream(s), ArgReduce::ArgMax, sorted_axes[0]), {a}); if (!keepdims) { @@ -1607,7 +1617,7 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) { } return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s), axis), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s), axis), {a}); } /** Returns indices that sort the flattened array. 
*/ @@ -1637,7 +1647,7 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) { } return array( - a.shape(), uint32, std::make_unique(to_stream(s), axis), {a}); + a.shape(), uint32, std::make_shared(to_stream(s), axis), {a}); } /** @@ -1677,7 +1687,7 @@ array partition( return array( a.shape(), a.dtype(), - std::make_unique(to_stream(s), kth_, axis_), + std::make_shared(to_stream(s), kth_, axis_), {a}); } @@ -1718,7 +1728,7 @@ array argpartition( return array( a.shape(), uint32, - std::make_unique(to_stream(s), kth_, axis_), + std::make_shared(to_stream(s), kth_, axis_), {a}); } @@ -1787,7 +1797,7 @@ array logsumexp( array abs(const array& a, StreamOrDevice s /* = {} */) { auto out = - array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); if (a.dtype() == complex64) { out = astype(out, float32, s); } @@ -1800,33 +1810,33 @@ array negative(const array& a, StreamOrDevice s /* = {} */) { throw std::invalid_argument(msg); } return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array operator-(const array& a) { return negative(a); } array sign(const array& a, StreamOrDevice s /* = {} */) { - return array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array logical_not(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), bool_, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {astype(a, bool_, s)}); } array logical_and(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Broadcast arrays to a common shape - auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s); - + auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, bool_, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array operator&&(const array& a, const array& b) { return logical_and(a, b); @@ -1834,13 +1844,13 @@ array operator&&(const array& a, const array& b) { array logical_or(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Broadcast arrays to a common shape - auto inputs = broadcast_arrays({astype(a, bool_, s), astype(b, bool_, s)}, s); - + auto inputs = broadcast_arrays(astype(a, bool_, s), astype(b, bool_, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, bool_, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array operator||(const array& a, const array& b) { return logical_or(a, b); @@ -1854,9 +1864,10 @@ array reciprocal(const array& a, StreamOrDevice s /* = {} */) { array add(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), out_type, std::make_unique(to_stream(s)), inputs); + shape, out_type, std::make_shared(to_stream(s)), std::move(inputs)); } array operator+(const array& a, const array& b) { @@ -1866,12 +1877,13 @@ array operator+(const array& a, const array& b) { array subtract(const array& a, const array& b, StreamOrDevice s /* = {} 
*/) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, out_type, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array operator-(const array& a, const array& b) { @@ -1881,12 +1893,13 @@ array operator-(const array& a, const array& b) { array multiply(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, out_type, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array operator*(const array& a, const array& b) { @@ -1895,10 +1908,11 @@ array operator*(const array& a, const array& b) { array divide(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(promote_types(a.dtype(), b.dtype())); - auto inputs = broadcast_arrays( - {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); + auto inputs = + broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), dtype, std::make_unique(to_stream(s)), inputs); + shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); } array operator/(const array& a, const array& b) { return divide(a, b); @@ -1919,20 +1933,22 @@ array floor_divide( return floor(divide(a, b, s), s); } - auto inputs = broadcast_arrays({astype(a, dtype, s), astype(b, dtype, s)}, s); + auto inputs = broadcast_arrays(astype(a, dtype, s), astype(b, dtype, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), dtype, std::make_unique(to_stream(s)), inputs); + shape, dtype, std::make_shared(to_stream(s)), std::move(inputs)); } array remainder(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto dtype = promote_types(a.dtype(), b.dtype()); - auto inputs = broadcast_arrays( - {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); + auto inputs = + broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, dtype, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array operator%(const array& a, const array& b) { return remainder(a, b); @@ -1944,8 +1960,8 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) { if (is_complex(dtype)) { throw std::invalid_argument("[divmod] Complex type not supported."); } - auto inputs = broadcast_arrays( - {astype(a, dtype, s), astype(b, dtype, to_stream(s))}, s); + auto inputs = + broadcast_arrays(astype(a, dtype, s), astype(b, dtype, to_stream(s)), s); return array::make_arrays( {inputs[0].shape(), inputs[0].shape()}, {inputs[0].dtype(), inputs[0].dtype()}, @@ -1956,23 +1972,25 @@ divmod(const array& a, const array& b, StreamOrDevice s /* = {} */) { array maximum(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + broadcast_arrays(astype(a, 
out_type, s), astype(b, out_type, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, out_type, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array minimum(const array& a, const array& b, StreamOrDevice s /* = {} */) { auto out_type = promote_types(a.dtype(), b.dtype()); auto inputs = - broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, out_type, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array floor(const array& a, StreamOrDevice s /* = {} */) { @@ -1980,103 +1998,103 @@ array floor(const array& a, StreamOrDevice s /* = {} */) { throw std::invalid_argument("[floor] Not supported for complex64."); } return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array ceil(const array& a, StreamOrDevice s /* = {} */) { if (a.dtype() == complex64) { throw std::invalid_argument("[floor] Not supported for complex64."); } - return array(a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + return array(a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array square(const array& a, StreamOrDevice s /* = {} */) { return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array exp(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array sin(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array cos(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array tan(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array arcsin(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array arccos(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array arctan(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array sinh(const 
array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array cosh(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array tanh(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); - return array(a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + return array(a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array arcsinh(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array arccosh(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array arctanh(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array log(const array& a, StreamOrDevice s /* = {} */) { @@ -2085,7 +2103,7 @@ array log(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s), Log::Base::e), + std::make_shared(to_stream(s), Log::Base::e), {input}); } @@ -2095,7 +2113,7 @@ array log2(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s), Log::Base::two), + std::make_shared(to_stream(s), Log::Base::two), {input}); } @@ -2105,7 +2123,7 @@ array log10(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s), Log::Base::ten), + std::make_shared(to_stream(s), Log::Base::ten), {input}); } @@ -2113,26 +2131,27 @@ array log1p(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array logaddexp(const array& a, const array& b, StreamOrDevice s /* = {} */) { // Make sure out type is floating point auto out_type = at_least_float(promote_types(a.dtype(), b.dtype())); auto inputs = - broadcast_arrays({astype(a, out_type, s), astype(b, out_type, s)}, s); + broadcast_arrays(astype(a, out_type, s), astype(b, out_type, s), s); + auto& shape = inputs[0].shape(); return array( - inputs[0].shape(), + shape, out_type, - std::make_unique(to_stream(s)), - inputs); + std::make_shared(to_stream(s)), + std::move(inputs)); } array sigmoid(const array& a, StreamOrDevice s /* = {} */) { auto dtype = at_least_float(a.dtype()); auto input = astype(a, dtype, s); return array( - a.shape(), dtype, std::make_unique(to_stream(s)), {input}); + a.shape(), dtype, std::make_shared(to_stream(s)), {input}); } array erf(const array& a, StreamOrDevice s /* = {} */) { @@ 
-2140,7 +2159,7 @@ array erf(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {astype(a, dtype, s)}); } @@ -2149,19 +2168,19 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {astype(a, dtype, s)}); } array stop_gradient(const array& a, StreamOrDevice s /* = {} */) { return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } array round(const array& a, int decimals, StreamOrDevice s /* = {} */) { if (decimals == 0) { return array( - a.shape(), a.dtype(), std::make_unique(to_stream(s)), {a}); + a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } auto dtype = at_least_float(a.dtype()); @@ -2226,7 +2245,7 @@ array matmul( auto out = array( {a.shape(0), b.shape(1)}, out_type, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {a, b}); return reshape(out, out_shape, s); } @@ -2250,17 +2269,17 @@ array matmul( auto out_shape = a.shape(); out_shape.back() = b.shape(-1); - auto out = array( - out_shape, out_type, std::make_unique(to_stream(s)), {a, b}); + auto p = std::make_shared(to_stream(s)); // Remove the possibly inserted singleton dimensions if (in_a.ndim() == 1 || in_b.ndim() == 1) { + auto out = array(out_shape, out_type, std::move(p), {a, b}); out_shape.erase( out_shape.end() - ((in_a.ndim() == 1) ? 2 : 1), out_shape.end() - ((in_b.ndim() == 1) ? 0 : 1)); - out = reshape(out, out_shape, s); + return reshape(out, std::move(out_shape), s); } - return out; + return array(std::move(out_shape), out_type, std::move(p), {a, b}); } array gather( @@ -2332,7 +2351,7 @@ array gather( return array( out_shape, a.dtype(), - std::make_unique(to_stream(s), axes, slice_sizes), + std::make_shared(to_stream(s), axes, slice_sizes), inputs); } @@ -2516,7 +2535,7 @@ array scatter( return array( a.shape(), a.dtype(), - std::make_unique(to_stream(s), mode, axes), + std::make_shared(to_stream(s), mode, axes), inputs); } @@ -2570,7 +2589,7 @@ array sqrt(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {astype(a, dtype, s)}); } @@ -2579,7 +2598,7 @@ array rsqrt(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), dtype, - std::make_unique(to_stream(s), true), + std::make_shared(to_stream(s), true), {astype(a, dtype, s)}); } @@ -2592,7 +2611,7 @@ array softmax( return array( a.shape(), dtype, - std::make_unique(to_stream(s)), + std::make_shared(to_stream(s)), {astype(a, dtype, s)}); } else { auto a_max = stop_gradient(max(a, axes, /*keepdims = */ true, s), s); @@ -2614,7 +2633,7 @@ array power(const array& a, const array& b, StreamOrDevice s /* = {} */) { inputs = broadcast_arrays(inputs, s); } return array( - inputs[0].shape(), dtype, std::make_unique(to_stream(s)), inputs); + inputs[0].shape(), dtype, std::make_shared(to_stream(s)), inputs); } array cumsum( @@ -2635,7 +2654,7 @@ array cumsum( return array( a.shape(), out_type, - std::make_unique( + std::make_shared( to_stream(s), Scan::ReduceType::Sum, axis, reverse, inclusive), {a}); } @@ -2657,7 +2676,7 @@ array cumprod( return array( a.shape(), a.dtype(), - std::make_unique( + std::make_shared( to_stream(s), Scan::ReduceType::Prod, axis, reverse, inclusive), {a}); } @@ -2679,7 +2698,7 @@ array cummax( return array( a.shape(), 
a.dtype(), - std::make_unique( + std::make_shared( to_stream(s), Scan::ReduceType::Max, axis, reverse, inclusive), {a}); } @@ -2701,7 +2720,7 @@ array cummin( return array( a.shape(), a.dtype(), - std::make_unique( + std::make_shared( to_stream(s), Scan::ReduceType::Min, axis, reverse, inclusive), {a}); } @@ -2965,7 +2984,7 @@ array conv_general( return array( out_shape, in.dtype(), - std::make_unique( + std::make_shared( to_stream(s), stride, padding_lo, @@ -3042,7 +3061,7 @@ array quantized_matmul( throw std::invalid_argument(msg.str()); } - auto dtype = result_type({x, scales, biases}); + auto dtype = result_type(x, scales, biases); if (!is_floating_point(dtype) || is_complex(dtype)) { std::ostringstream msg; msg << "[quantized_matmul] Only real floating types are supported but " @@ -3055,7 +3074,7 @@ array quantized_matmul( auto out = array( {x.shape(0), w_outer_dims}, dtype, - std::make_unique( + std::make_shared( to_stream(s), group_size, bits, transpose), {astype(x, dtype, s), w, @@ -3065,7 +3084,7 @@ array quantized_matmul( // If needed reshape x to the original batch shape if (original_shape.size() != 1) { original_shape.push_back(w_outer_dims); - out = reshape(out, original_shape, s); + out = reshape(out, std::move(original_shape), s); } return out; @@ -3344,7 +3363,7 @@ array addmm( } // Type promotion - auto out_type = result_type({a, b, c}); + auto out_type = result_type(a, b, c); if (!is_floating_point(out_type) || is_complex(out_type)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " @@ -3373,7 +3392,7 @@ array addmm( auto out = array( {a.shape(0), b.shape(1)}, out_type, - std::make_unique(to_stream(s), alpha, beta), + std::make_shared(to_stream(s), alpha, beta), {a, b, c}); return reshape(out, out_shape, s); } @@ -3425,7 +3444,7 @@ array addmm( auto out = array( out_shape, out_type, - std::make_unique(to_stream(s), alpha, beta), + std::make_shared(to_stream(s), alpha, beta), {a, b, c}); // Remove the possibly inserted singleton dimensions @@ -3621,7 +3640,7 @@ array number_of_elements( return stop_gradient(array( std::vector{}, dtype, - std::make_unique( + std::make_shared( to_stream(s), std::move(axes), inverted, dtype), {a})); } diff --git a/mlx/ops.h b/mlx/ops.h index 0df68ecc9..ff9963ed7 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -158,9 +158,7 @@ array expand_dims( StreamOrDevice s = {}); /** Add a singleton dimension at the given axis. */ -inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) { - return expand_dims(a, std::vector{axis}, s); -} +array expand_dims(const array& a, int axis, StreamOrDevice s = {}); /** Slice an array. */ array slice( diff --git a/mlx/random.cpp b/mlx/random.cpp index 5e0682d32..6e823de39 100644 --- a/mlx/random.cpp +++ b/mlx/random.cpp @@ -66,7 +66,7 @@ array bits( return array( shape, get_dtype(), - std::make_unique(to_stream(s), shape, width), + std::make_shared(to_stream(s), shape, width), {key}); } diff --git a/mlx/utils.h b/mlx/utils.h index b2eedbdaf..8a5f2ccd9 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -54,6 +54,12 @@ struct PrintFormatter { extern PrintFormatter global_formatter; /** The type from promoting the arrays' types with one another. 
*/ +inline Dtype result_type(const array& a, const array& b) { + return promote_types(a.dtype(), b.dtype()); +} +inline Dtype result_type(const array& a, const array& b, const array& c) { + return promote_types(result_type(a, b), c.dtype()); +} Dtype result_type(const std::vector& arrays); std::vector broadcast_shapes( diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py index 50808475d..2efdf5e33 100644 --- a/python/mlx/nn/layers/base.py +++ b/python/mlx/nn/layers/base.py @@ -131,8 +131,8 @@ class Module(dict): return value def __getattr__(self, key: str): - if key in self: - return self[key] + if (value := self.get(key, None)) is not None: + return value else: super(Module, self).__getattribute__(key) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 42d2fce79..38eea7791 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -64,9 +64,9 @@ class Linear(Module): def __call__(self, x: mx.array) -> mx.array: if "bias" in self: - x = mx.addmm(self.bias, x, self.weight.T) + x = mx.addmm(self["bias"], x, self["weight"].T) else: - x = x @ self.weight.T + x = x @ self["weight"].T return x diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py index 7b53c2c5e..588c2ed2b 100644 --- a/python/mlx/nn/layers/normalization.py +++ b/python/mlx/nn/layers/normalization.py @@ -140,7 +140,7 @@ class RMSNorm(Module): return f"{self.weight.shape[0]}, eps={self.eps}" def __call__(self, x): - return mx.fast.rms_norm(x, self.weight, self.eps) + return mx.fast.rms_norm(x, self["weight"], self.eps) class GroupNorm(Module): diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index a52285633..15eccf0b1 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -81,15 +81,15 @@ class QuantizedLinear(Module): def __call__(self, x): x = mx.quantized_matmul( x, - self.weight, - scales=self.scales, - biases=self.biases, + self["weight"], + scales=self["scales"], + biases=self["biases"], transpose=True, group_size=self.group_size, bits=self.bits, ) if "bias" in self: - x = x + self.bias + x = x + self["bias"] return x @classmethod diff --git a/python/src/fast.cpp b/python/src/fast.cpp index cbdbcde47..306ef1bb4 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -17,12 +17,7 @@ void init_fast(nb::module_& parent_module) { m.def( "rms_norm", - [](const array& x, - const array& weight, - float eps, - const StreamOrDevice& s /* = {} */) { - return fast::rms_norm(x, weight, eps, s); - }, + &fast::rms_norm, "x"_a, "weight"_a, "eps"_a, @@ -48,13 +43,7 @@ void init_fast(nb::module_& parent_module) { m.def( "layer_norm", - [](const array& x, - const std::optional& weight, - const std::optional& bias, - float eps, - const StreamOrDevice& s /* = {} */) { - return fast::layer_norm(x, weight, bias, eps, s); - }, + &fast::layer_norm, "x"_a, "weight"_a.none(), "bias"_a.none(), @@ -84,15 +73,7 @@ void init_fast(nb::module_& parent_module) { m.def( "rope", - [](const array& a, - int dims, - bool traditional, - float base, - float scale, - int offset, - const StreamOrDevice& s /* = {} */) { - return fast::rope(a, dims, traditional, base, scale, offset, s); - }, + &fast::rope, "a"_a, "dims"_a, nb::kw_only(), @@ -123,14 +104,7 @@ void init_fast(nb::module_& parent_module) { m.def( "scaled_dot_product_attention", - [](const array& q, - const array& k, - const array& v, - const float scale, - const std::optional& mask, - const StreamOrDevice& s) { - return 
fast::scaled_dot_product_attention(q, k, v, scale, mask, s);
-      },
+      &fast::scaled_dot_product_attention,
       "q"_a,
       "k"_a,
       "v"_a,
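The mlx/array.cpp hunk replaces the standalone cum_prod helper with ArrayDesc::init(), which computes the element count and the row-major strides in a single pass and also folds in the tracer-flag propagation that previously lived only in the multi-input constructor. A standalone sketch of that stride computation, using plain standard-library types rather than the MLX classes:

```cpp
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Row-major (C-order) strides: the last axis has stride 1 and each earlier
// axis strides over the product of the later extents. The running product
// is also the total element count, which is why the patch can fill in
// `size` and `strides` in one loop inside ArrayDesc::init().
std::pair<size_t, std::vector<size_t>> size_and_strides(
    const std::vector<int>& shape) {
  std::vector<size_t> strides(shape.size());
  size_t size = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    strides[i] = size;
    size *= shape[i];
  }
  return {size, strides};
}

int main() {
  auto [size, strides] = size_and_strides({2, 3, 4});
  std::cout << "size = " << size << "\n"; // 24
  for (auto s : strides) {
    std::cout << s << " "; // 12 4 1
  }
  std::cout << "\n";
}
```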
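A recurring change across fast.cpp, linalg.cpp, ops.cpp, and random.cpp is that the primitive handed to the array constructor is now built with std::make_shared instead of std::make_unique. Since the constructor ultimately stores the primitive behind a shared_ptr, constructing it as a shared_ptr up front avoids the extra control-block allocation paid when a unique_ptr is converted. A minimal sketch with a stand-in Primitive and Node type (names here are illustrative, not the MLX classes):

```cpp
#include <memory>
#include <utility>

// Stand-in for a primitive object; the real MLX Primitive carries a stream
// and op-specific parameters.
struct Primitive {
  explicit Primitive(int stream) : stream(stream) {}
  int stream;
};

// Stand-in for the array constructor, which takes ownership as a shared_ptr.
struct Node {
  explicit Node(std::shared_ptr<Primitive> p) : prim(std::move(p)) {}
  std::shared_ptr<Primitive> prim;
};

int main() {
  // Two heap allocations: one for the Primitive, and a second for the
  // shared_ptr control block created when the unique_ptr is converted.
  Node a(std::make_unique<Primitive>(0));

  // One allocation: make_shared places the object and its control block
  // in a single block, which is the form the patch switches to.
  Node b(std::make_shared<Primitive>(0));
  (void)a;
  (void)b;
}
```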
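Several ops (reshape, expand_dims, concatenate, transpose, matmul, quantized_matmul, and others) now std::move shape vectors they no longer need instead of copying them, and ops.cpp gains an out-of-line single-axis expand_dims so the common case no longer wraps one axis in a temporary vector. A toy sketch of both ideas, using a placeholder Tensor type rather than mlx::core::array:

```cpp
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

// Placeholder for an array type that stores its shape; only meant to show
// why the patch moves shape vectors it has just finished building.
struct Tensor {
  explicit Tensor(std::vector<int> shape) : shape(std::move(shape)) {}
  std::vector<int> shape;
};

// Mirrors the new single-axis expand_dims overload: normalize a possibly
// negative axis, insert a length-1 extent, and hand the shape off by value.
Tensor expand_dims(const Tensor& t, int axis) {
  int out_dim = static_cast<int>(t.shape.size()) + 1;
  int ax = axis < 0 ? axis + out_dim : axis;
  if (ax < 0 || ax >= out_dim) {
    throw std::invalid_argument("expand_dims: axis out of range");
  }
  auto shape = t.shape;
  shape.insert(shape.begin() + ax, 1);
  // Moving avoids copying the freshly built shape a second time.
  return Tensor(std::move(shape));
}

int main() {
  Tensor t({2, 3});
  auto u = expand_dims(t, -1);
  for (auto d : u.shape) {
    std::cout << d << " "; // 2 3 1
  }
  std::cout << "\n";
}
```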
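The binary ops (equal, add, subtract, multiply, divide, maximum, logaddexp, and the comparisons) now call a new two-argument broadcast_arrays(a, b, s) overload rather than packing both operands into a std::vector<array> first. The overload rests on the usual trailing-axis broadcasting rule; a standalone sketch of that shape computation (not the MLX implementation):

```cpp
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <vector>

// Two-shape broadcasting: align shapes from the trailing axes; each pair of
// extents must be equal or one of them must be 1, and the output takes the
// larger extent on every axis.
std::vector<int> broadcast_shapes(
    const std::vector<int>& a,
    const std::vector<int>& b) {
  size_t ndim = std::max(a.size(), b.size());
  std::vector<int> out(ndim, 1);
  for (size_t i = 0; i < ndim; ++i) {
    int da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::invalid_argument("shapes cannot be broadcast");
    }
    out[ndim - 1 - i] = std::max(da, db);
  }
  return out;
}

int main() {
  auto s = broadcast_shapes({8, 1, 3}, {4, 3});
  for (auto d : s) {
    std::cout << d << " "; // 8 4 3
  }
  std::cout << "\n";
}
```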
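The reductions in ops.cpp (all, any, sum, prod, max, min) keep their existing pattern: build the output shape with every reduced axis set to 1, construct the Reduce primitive at that shape, and squeeze afterwards only when keepdims is false. A shape-only sketch of that pattern; the actual reduction is elided:

```cpp
#include <iostream>
#include <set>
#include <vector>

// Compute the reduced output shape: non-reduced axes keep their extent,
// reduced axes either collapse to 1 (keepdims) or are dropped (squeeze).
std::vector<int> reduced_shape(
    const std::vector<int>& shape,
    const std::vector<int>& axes,
    bool keepdims) {
  std::set<int> ax(axes.begin(), axes.end());
  std::vector<int> out;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (ax.count(i) == 0) {
      out.push_back(shape[i]);
    } else if (keepdims) {
      out.push_back(1);
    }
  }
  return out;
}

int main() {
  for (auto d : reduced_shape({2, 3, 4}, {1}, true)) {
    std::cout << d << " "; // 2 1 4
  }
  std::cout << "\n";
  for (auto d : reduced_shape({2, 3, 4}, {1}, false)) {
    std::cout << d << " "; // 2 4
  }
  std::cout << "\n";
}
```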
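The unary floating-point ops (exp, sin, cos, tan, the inverse and hyperbolic variants, log, log1p, sigmoid, erf, sqrt) all promote their input with at_least_float and astype before constructing the primitive, so integer inputs yield floating-point outputs. A placeholder sketch of that promotion rule; the enum and lattice here are illustrative, not MLX's dtype system:

```cpp
#include <iostream>

// Toy dtype set: integers promote to float32, floating types pass through.
enum class Dtype { int32, float16, float32 };

bool is_floating(Dtype t) {
  return t == Dtype::float16 || t == Dtype::float32;
}

Dtype at_least_float(Dtype t) {
  return is_floating(t) ? t : Dtype::float32;
}

int main() {
  std::cout << (at_least_float(Dtype::int32) == Dtype::float32) << "\n";   // 1
  std::cout << (at_least_float(Dtype::float16) == Dtype::float16) << "\n"; // 1
}
```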
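The matmul hunk reorders the vector-promotion handling so the primitive and output array are constructed only once, on whichever shape (reshaped or not) is actually returned. The underlying shape bookkeeping treats a 1-D left operand as (1, n) and a 1-D right operand as (n, 1), then drops the inserted axes from the result. A shape-only sketch of that bookkeeping; batching and the product itself are elided:

```cpp
#include <iostream>
#include <vector>

// Compute the matmul output shape with NumPy-style vector promotion.
std::vector<int> matmul_out_shape(std::vector<int> a, std::vector<int> b) {
  bool a_vec = a.size() == 1;
  bool b_vec = b.size() == 1;
  if (a_vec) {
    a.insert(a.begin(), 1); // (n) -> (1, n)
  }
  if (b_vec) {
    b.push_back(1); // (n) -> (n, 1)
  }
  std::vector<int> out(a.begin(), a.end() - 1); // leading dims and rows of a
  out.push_back(b.back());                      // columns of b
  // Remove the singleton axes that were only inserted for the product.
  if (a_vec) {
    out.erase(out.end() - 2);
  }
  if (b_vec) {
    out.erase(out.end() - 1);
  }
  return out;
}

int main() {
  for (auto d : matmul_out_shape({3}, {3, 4})) {
    std::cout << d << " "; // 4
  }
  std::cout << "\n";
  for (auto d : matmul_out_shape({2, 3}, {3})) {
    std::cout << d << " "; // 2
  }
  std::cout << "\n";
}
```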
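Finally, utils.h gains inline two- and three-argument result_type overloads so hot paths such as rms_norm, layer_norm, scaled_dot_product_attention, quantized_matmul, and addmm can promote a fixed number of dtypes without first allocating a std::vector<array>. A self-contained sketch of the chaining, with a placeholder Dtype and promotion rule rather than MLX's:

```cpp
#include <algorithm>
#include <iostream>

// Placeholder dtype lattice: higher enum value wins the promotion.
enum class Dtype { int32 = 0, float16 = 1, float32 = 2 };

Dtype promote_types(Dtype a, Dtype b) {
  return static_cast<Dtype>(
      std::max(static_cast<int>(a), static_cast<int>(b)));
}

// Fixed-arity overloads: pairwise promotion, no container needed.
inline Dtype result_type(Dtype a, Dtype b) {
  return promote_types(a, b);
}

inline Dtype result_type(Dtype a, Dtype b, Dtype c) {
  return promote_types(result_type(a, b), c);
}

int main() {
  auto t = result_type(Dtype::int32, Dtype::float16, Dtype::float32);
  std::cout << (t == Dtype::float32) << "\n"; // 1
}
```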