diff --git a/mlx/array.cpp b/mlx/array.cpp
index d90c7d446..70ecab40d 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -31,7 +31,7 @@ array::array(const std::complex<float>& val, Dtype dtype /* = complex64 */)
 }
 
 array::array(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
@@ -42,7 +42,7 @@ array::array(
           std::move(inputs))) {}
 
 std::vector<array> array::make_arrays(
-    std::vector<std::vector<int>> shapes,
+    std::vector<Shape> shapes,
     const std::vector<Dtype>& dtypes,
     const std::shared_ptr<Primitive>& primitive,
     const std::vector<array>& inputs) {
@@ -74,11 +74,7 @@ array::array(std::initializer_list<int> data, Dtype dtype)
 }
 
 /* Build an array from a shared buffer */
-array::array(
-    allocator::Buffer data,
-    std::vector<int> shape,
-    Dtype dtype,
-    deleter_t deleter)
+array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   set_data(data, deleter);
 }
@@ -126,7 +122,7 @@ bool array::is_tracer() const {
   return array_desc_->is_tracer && in_tracing() || retain_graph();
 }
 
-void array::set_data(allocator::Buffer buffer, deleter_t d) {
+void array::set_data(allocator::Buffer buffer, Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
   array_desc_->data_ptr = buffer.raw_ptr();
   array_desc_->data_size = size();
@@ -139,9 +135,9 @@ void array::set_data(allocator::Buffer buffer, deleter_t d) {
 void array::set_data(
     allocator::Buffer buffer,
     size_t data_size,
-    std::vector<size_t> strides,
+    Strides strides,
     Flags flags,
-    deleter_t d) {
+    Deleter d) {
   array_desc_->data = std::make_shared<Data>(buffer, d);
   array_desc_->data_ptr = buffer.raw_ptr();
   array_desc_->data_size = data_size;
@@ -151,7 +147,7 @@ void array::set_data(
 
 void array::copy_shared_buffer(
     const array& other,
-    const std::vector<size_t>& strides,
+    const Strides& strides,
     Flags flags,
     size_t data_size,
     size_t offset /* = 0 */) {
@@ -170,7 +166,7 @@ void array::copy_shared_buffer(const array& other) {
 
 void array::move_shared_buffer(
     array other,
-    const std::vector<size_t>& strides,
+    const Strides& strides,
     Flags flags,
     size_t data_size,
     size_t offset /* = 0 */) {
@@ -237,13 +233,13 @@ void array::ArrayDesc::init() {
   }
 }
 
-array::ArrayDesc::ArrayDesc(std::vector<int> shape, Dtype dtype)
+array::ArrayDesc::ArrayDesc(Shape shape, Dtype dtype)
     : shape(std::move(shape)), dtype(dtype), status(Status::available) {
   init();
 }
 
 array::ArrayDesc::ArrayDesc(
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype,
     std::shared_ptr<Primitive> primitive,
     std::vector<array> inputs)
diff --git a/mlx/array.h b/mlx/array.h
index f41baf568..8c1f7e933 100644
--- a/mlx/array.h
+++ b/mlx/array.h
@@ -15,7 +15,10 @@ namespace mlx::core {
 
 // Forward declaration
 class Primitive;
-using deleter_t = std::function<void(allocator::Buffer)>;
+
+using Deleter = std::function<void(allocator::Buffer)>;
+using Shape = std::vector<int32_t>;
+using Strides = std::vector<size_t>;
 
 class array {
   /* An array is really a node in a graph. It contains a shared ArrayDesc
@@ -33,7 +36,7 @@ class array {
   template <typename It>
   array(
       It data,
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype =
           TypeToDtype<typename std::iterator_traits<It>::value_type>());
 
@@ -49,15 +52,15 @@ class array {
   template <typename T>
   array(
       std::initializer_list<T> data,
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype = TypeToDtype<T>());
 
   /* Build an array from a buffer */
   array(
       allocator::Buffer data,
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype,
-      deleter_t deleter = allocator::free);
+      Deleter deleter = allocator::free);
 
   /** Assignment to rvalue does not compile. */
   array& operator=(const array& other) && = delete;
@@ -96,7 +99,7 @@ class array {
   }
 
   /** The shape of the array as a vector of integers. */
-  const std::vector<int>& shape() const {
+  const Shape& shape() const {
     return array_desc_->shape;
   }
 
@@ -105,12 +108,12 @@ class array {
    *
    * This function supports negative indexing and provides
    * bounds checking. */
-  int shape(int dim) const {
+  auto shape(int dim) const {
     return shape().at(dim < 0 ? dim + ndim() : dim);
   }
 
   /** The strides of the array. */
-  const std::vector<size_t>& strides() const {
+  const Strides& strides() const {
     return array_desc_->strides;
   }
 
@@ -119,7 +122,7 @@ class array {
    *
    * This function supports negative indexing and provides
    * bounds checking. */
-  size_t strides(int dim) const {
+  auto strides(int dim) const {
     return strides().at(dim < 0 ? dim + ndim() : dim);
   }
 
@@ -184,13 +187,13 @@ class array {
    */
   array(
-      std::vector<int> shape,
+      Shape shape,
       Dtype dtype,
       std::shared_ptr<Primitive> primitive,
       std::vector<array> inputs);
 
   static std::vector<array> make_arrays(
-      std::vector<std::vector<int>> shapes,
+      std::vector<Shape> shapes,
       const std::vector<Dtype>& dtypes,
       const std::shared_ptr<Primitive>& primitive,
       const std::vector<array>& inputs);
@@ -207,8 +210,8 @@ class array {
 
   struct Data {
     allocator::Buffer buffer;
-    deleter_t d;
-    Data(allocator::Buffer buffer, deleter_t d = allocator::free)
+    Deleter d;
+    Data(allocator::Buffer buffer, Deleter d = allocator::free)
         : buffer(buffer), d(d) {}
     // Not copyable
     Data(const Data& d) = delete;
@@ -397,18 +400,18 @@ class array {
   // Check if the array is a tracer array
   bool is_tracer() const;
 
-  void set_data(allocator::Buffer buffer, deleter_t d = allocator::free);
+  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
 
   void set_data(
       allocator::Buffer buffer,
       size_t data_size,
-      std::vector<size_t> strides,
+      Strides strides,
       Flags flags,
-      deleter_t d = allocator::free);
+      Deleter d = allocator::free);
 
   void copy_shared_buffer(
       const array& other,
-      const std::vector<size_t>& strides,
+      const Strides& strides,
       Flags flags,
       size_t data_size,
       size_t offset = 0);
@@ -417,7 +420,7 @@ class array {
 
   void move_shared_buffer(
       array other,
-      const std::vector<size_t>& strides,
+      const Strides& strides,
       Flags flags,
       size_t data_size,
       size_t offset = 0);
@@ -436,8 +439,8 @@ class array {
   void init(const It src);
 
   struct ArrayDesc {
-    std::vector<int> shape;
-    std::vector<size_t> strides;
+    Shape shape;
+    Strides strides;
     size_t size;
     Dtype dtype;
     std::shared_ptr<Primitive> primitive;
@@ -471,10 +474,10 @@ class array {
     // The arrays position in the output list
     uint32_t position{0};
 
-    explicit ArrayDesc(std::vector<int> shape, Dtype dtype);
+    explicit ArrayDesc(Shape shape, Dtype dtype);
 
     explicit ArrayDesc(
-        std::vector<int> shape,
+        Shape shape,
         Dtype dtype,
         std::shared_ptr<Primitive> primitive,
         std::vector<array> inputs);
@@ -502,7 +505,7 @@ array::array(T val, Dtype dtype /* = TypeToDtype<T>() */)
 template <typename It>
 array::array(
     It data,
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype /* = TypeToDtype<typename std::iterator_traits<It>::value_type>() */)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   init(data);
@@ -521,7 +524,7 @@ array::array(
 template <typename T>
 array::array(
     std::initializer_list<T> data,
-    std::vector<int> shape,
+    Shape shape,
     Dtype dtype /* = TypeToDtype<T>() */)
     : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
   if (data.size() != size()) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index cbe8a6861..2e346236a 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -16,10 +16,9 @@ namespace mlx::core {
 
 namespace {
 
-std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, bool>
-compute_reduce_shape(
+std::tuple<std::vector<int>, Shape, Shape, bool> compute_reduce_shape(
     const std::vector<int>& axes,
-    const std::vector<int>& shape) {
+    const Shape& shape) {
   bool is_noop = true;
   std::set<int> axes_set;
   auto ndim = shape.size();
@@ -36,8 +35,8 @@ compute_reduce_shape(
   if (axes_set.size() != axes.size()) {
     throw std::invalid_argument("Duplicate axes detected in reduction.");
   }
-  std::vector<int> out_shape;
-  std::vector<int> squeezed_shape;
+  Shape out_shape;
+  Shape squeezed_shape;
   for (int i = 0; i < ndim; ++i) {
     if (axes_set.count(i) == 0) {
       out_shape.push_back(shape[i]);
@@ -63,7 +62,7 @@ array indices_or_default(
     return indices.value();
   }
 
-  std::vector<int> shape(x.shape().begin(), x.shape().end() - 2);
+  Shape shape(x.shape().begin(), x.shape().end() - 2);
   int total =
       std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
   return reshape(arange(total, uint32, s), shape, s);
@@ -254,8 +253,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s /* = {} */) {
 
 array as_strided(
     array a,
-    std::vector<int> shape,
-    std::vector<size_t> strides,
+    Shape shape,
+    Strides strides,
     size_t offset,
     StreamOrDevice s /* = {} */) {
   auto copied_shape = shape; // |shape| will be moved
@@ -279,12 +278,8 @@ array copy(array a, StreamOrDevice s /* = {} */) {
       {std::move(a)});
 }
 
-array full(
-    std::vector<int> shape,
-    array vals,
-    Dtype dtype,
-    StreamOrDevice s /* = {} */) {
-  if (std::any_of(shape.begin(), shape.end(), [](int i) { return i < 0; })) {
+array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s /* = {} */) {
+  if (std::any_of(shape.begin(), shape.end(), [](auto i) { return i < 0; })) {
     throw std::invalid_argument("[full] Negative dimensions not allowed.");
   }
   auto copied_shape = shape; // |shape| will be moved
@@ -295,15 +290,12 @@ array full(
       {broadcast_to(astype(std::move(vals), dtype, s), std::move(shape), s)});
 }
 
-array full(std::vector<int> shape, array vals, StreamOrDevice s /* = {} */) {
+array full(Shape shape, array vals, StreamOrDevice s /* = {} */) {
   auto dtype = vals.dtype(); // |vals| will be moved
   return full(std::move(shape), std::move(vals), dtype, to_stream(s));
 }
 
-array zeros(
-    const std::vector<int>& shape,
-    Dtype dtype,
-    StreamOrDevice s /* = {} */) {
+array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
   return full(shape, array(0, dtype), to_stream(s));
 }
 
@@ -311,10 +303,7 @@ array zeros_like(const array& a, StreamOrDevice s /* = {} */) {
   return zeros(a.shape(), a.dtype(), to_stream(s));
 }
 
-array ones(
-    const std::vector<int>& shape,
-    Dtype dtype,
-    StreamOrDevice s /* = {} */) {
+array ones(const Shape& shape, Dtype dtype, StreamOrDevice s /* = {} */) {
   return full(shape, array(1, dtype), to_stream(s));
 }
 
@@ -368,10 +357,7 @@ array triu(array x, int k /* = 0 */, StreamOrDevice s /* = {} */) {
   return where(mask, zeros_like(x, s), x, s);
 }
 
-array reshape(
-    const array& a,
-    std::vector<int> shape,
-    StreamOrDevice s /* = {} */) {
+array reshape(const array& a, Shape shape, StreamOrDevice s /* = {} */) {
   if (a.shape() == shape) {
     return a;
   }
@@ -445,11 +431,11 @@ array flatten(
   if (start_ax == end_ax) {
     return a;
   }
-  std::vector<int> new_shape(a.shape().begin(), a.shape().begin() + start_ax);
+  Shape new_shape(a.shape().begin(), a.shape().begin() + start_ax);
   new_shape.push_back(-1);
   new_shape.insert(
       new_shape.end(), a.shape().begin() + end_ax + 1, a.shape().end());
-  return reshape(a, new_shape, s);
+  return reshape(a, std::move(new_shape), s);
 }
 
 array flatten(const array& a, StreamOrDevice s /* = {} */) {
@@ -496,7 +482,7 @@ array squeeze(
     throw std::invalid_argument("[squeeze] Received duplicate axes.");
   }
   std::vector<int> sorted_axes(unique_axes.begin(), unique_axes.end());
-  std::vector<int> shape;
+  Shape shape;
   for (int i = 0, j = 0; i < a.ndim(); ++i) {
     if (j < sorted_axes.size() && i == sorted_axes[j]) {
       j++;
@@ -584,12 +570,9 @@ array expand_dims(
 
 // Slice helper
 namespace {
 
-inline auto normalize_slice(
-    const std::vector<int>& shape,
-    std::vector<int>& start,
-    std::vector<int>& stop,
-    std::vector<int>& strides) {
-  std::vector<int> out_shape(shape.size());
+inline auto
+normalize_slice(const Shape& shape, Shape& start, Shape& stop, Shape& strides) {
+  Shape out_shape(shape.size());
 
   bool has_neg_strides = false;
 
   for (int i = 0; i < shape.size(); ++i) {
@@ -641,9 +624,9 @@ inline auto normalize_slice(
 
 array slice(
     const array& a,
-    std::vector<int> start,
-    std::vector<int> stop,
-    std::vector<int> strides,
+    Shape start,
+    Shape stop,
+    Shape strides,
     StreamOrDevice s /* = {} */) {
   if (start.size() != a.ndim() || stop.size() != a.ndim() ||
       strides.size() != a.ndim()) {
@@ -670,24 +653,20 @@ array slice(
 
 array slice(
     const array& a,
-    std::vector<int> start,
-    std::vector<int> stop,
+    Shape start,
+    Shape stop,
     StreamOrDevice s /* = {} */) {
   return slice(
-      a,
-      std::move(start),
-      std::move(stop),
-      std::vector<int>(a.ndim(), 1),
-      to_stream(s));
+      a, std::move(start), std::move(stop), Shape(a.ndim(), 1), to_stream(s));
 }
 
 /** Update a slice from the source array */
 array slice_update(
     const array& src,
     const array& update,
-    std::vector<int> start,
-    std::vector<int> stop,
-    std::vector<int> strides,
+    Shape start,
+    Shape stop,
+    Shape strides,
     StreamOrDevice s /* = {} */) {
   // Check dimensions
   if (start.size() != src.ndim() || stop.size() != src.ndim() ||
@@ -721,12 +700,11 @@ array slice_update(
 array slice_update(
     const array& src,
     const array& update,
-    std::vector<int> start,
-    std::vector<int> stop,
+    Shape start,
+    Shape stop,
     StreamOrDevice s /* = {} */) {
-  auto strides = std::vector<int>(src.ndim(), 1);
   return slice_update(
-      src, update, std::move(start), std::move(stop), std::move(strides), s);
+      src, update, std::move(start), std::move(stop), Shape(src.ndim(), 1), s);
 }
 
 std::vector<array> split(
@@ -750,7 +728,7 @@ std::vector<array> split(
       std::is_sorted(indices.begin(), indices.end(), std::less<>{}) &&
       indices[0] > 0 && indices.back() < a.shape(ax)) {
     std::vector<Dtype> dtypes(indices.size() + 1, a.dtype());
-    std::vector<std::vector<int>> shapes(indices.size() + 1, a.shape());
+    std::vector<Shape> shapes(indices.size() + 1, a.shape());
     shapes[0][ax] = indices[0];
     for (int i = 1; i < indices.size(); i++) {
       shapes[i][ax] = indices[i] - indices[i - 1];
@@ -765,8 +743,7 @@ std::vector<array> split(
   }
 
   std::vector<array> res;
-  auto out_shape = a.shape();
-  auto start_indices = std::vector<int>(a.ndim(), 0);
+  auto start_indices = Shape(a.ndim(), 0);
   auto stop_indices = a.shape();
   for (int i = 0; i < indices.size() + 1; ++i) {
     stop_indices[ax] = i < indices.size() ? indices[i] : a.shape(ax);
@@ -826,13 +803,13 @@ std::vector<array> meshgrid(
   auto ndim = arrays.size();
   std::vector<array> outputs;
   for (int i = 0; i < ndim; ++i) {
-    std::vector<int> shape(ndim, 1);
+    Shape shape(ndim, 1);
     shape[i] = -1;
     outputs.push_back(reshape(arrays[i], std::move(shape), s));
   }
 
   if (indexing == "xy" and ndim > 1) {
-    std::vector<int> shape(ndim, 1);
+    Shape shape(ndim, 1);
 
     shape[1] = arrays[0].size();
     outputs[0] = reshape(arrays[0], shape, s);
@@ -895,7 +872,7 @@ array concatenate(
     throw std::invalid_argument(msg.str());
   };
 
-  std::vector<int> shape = arrays[0].shape();
+  auto shape = arrays[0].shape();
   shape[ax] = 0;
   // Make the output shape and validate that all arrays have the same shape
   // except for the concatenation axis.
@@ -980,7 +957,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
   }
 
   // Broadcast to (S_1, S_2, ..., S_axis, repeats, S_axis+1, ...)
-  std::vector<int> shape(arr.shape());
+  auto shape = arr.shape();
   shape.insert(shape.begin() + axis + 1, repeats);
   array out = expand_dims(arr, axis + 1, s);
   out = broadcast_to(out, shape, s);
@@ -1009,9 +986,9 @@ array tile(
     shape.insert(shape.begin(), reps.size() - shape.size(), 1);
   }
 
-  std::vector<int> expand_shape;
-  std::vector<int> broad_shape;
-  std::vector<int> final_shape;
+  Shape expand_shape;
+  Shape broad_shape;
+  Shape final_shape;
   for (int i = 0; i < shape.size(); i++) {
     if (reps[i] != 1) {
       expand_shape.push_back(1);
@@ -1022,17 +999,17 @@ array tile(
     final_shape.push_back(reps[i] * shape[i]);
   }
 
-  auto x = reshape(arr, expand_shape, s);
-  x = broadcast_to(x, broad_shape, s);
-  return reshape(x, final_shape, s);
+  auto x = reshape(arr, std::move(expand_shape), s);
+  x = broadcast_to(x, std::move(broad_shape), s);
+  return reshape(x, std::move(final_shape), s);
 }
 
 array edge_pad(
     const array& a,
     const std::vector<int>& axes,
-    const std::vector<int>& low_pad_size,
-    const std::vector<int>& high_pad_size,
-    const std::vector<int>& out_shape,
+    const Shape& low_pad_size,
+    const Shape& high_pad_size,
+    const Shape& out_shape,
     StreamOrDevice s /* = {}*/) {
   array out = zeros(out_shape, a.dtype(), s);
   auto stops = a.shape();
@@ -1044,7 +1021,7 @@ array edge_pad(
 
   for (int axis = 0; axis < a.ndim(); axis++) {
     if (low_pad_size[axis] > 0) {
-      std::vector<int> starts(a.ndim(), 0);
+      Shape starts(a.ndim(), 0);
       starts[axis] = low_pad_size[axis];
       auto stops = out.shape();
       stops[axis] = low_pad_size[axis] + 1;
@@ -1058,7 +1035,7 @@ array edge_pad(
     }
 
     if (high_pad_size[axis] > 0) {
-      std::vector<int> starts(a.ndim(), 0);
+      Shape starts(a.ndim(), 0);
       starts[axis] = -high_pad_size[axis] - 1;
       auto stops = out.shape();
       stops[axis] = -high_pad_size[axis];
@@ -1075,9 +1052,9 @@ array edge_pad(
 /** Pad an array with a constant value */
 array pad(
     const array& a,
-    const std::vector<int>& axes,
-    const std::vector<int>& low_pad_size,
-    const std::vector<int>& high_pad_size,
+    const Shape& axes,
+    const Shape& low_pad_size,
+    const Shape& high_pad_size,
     const array& pad_value /*= array(0)*/,
     const std::string mode /*= "constant"*/,
     StreamOrDevice s /* = {}*/) {
@@ -1089,7 +1066,7 @@ array pad(
     throw std::invalid_argument(msg.str());
   }
 
-  std::vector<int> out_shape = a.shape();
+  auto out_shape = a.shape();
 
   for (int i = 0; i < axes.size(); i++) {
     if (low_pad_size[i] < 0) {
@@ -1113,7 +1090,7 @@ array pad(
 
   if (mode == "constant") {
     return array(
-        out_shape,
+        std::move(out_shape),
         a.dtype(),
         std::make_shared<Pad>(to_stream(s), axes, low_pad_size, high_pad_size),
         {a, astype(pad_value, a.dtype(), s)});
@@ -1136,8 +1113,8 @@ array pad(
   std::vector<int> axes(a.ndim(), 0);
   std::iota(axes.begin(), axes.end(), 0);
 
-  std::vector<int> lows;
-  std::vector<int> highs;
+  Shape lows;
+  Shape highs;
 
   for (auto& pads : pad_width) {
     lows.push_back(pads.first);
@@ -1240,7 +1217,7 @@ array transpose(
   }
 
   // Check in bounds and for duplicates
-  std::vector<int> shape(axes.size(), 0);
+  Shape shape(axes.size(), 0);
   for (auto& ax : axes) {
     if (ax < 0 || ax >= a.ndim()) {
       std::ostringstream msg;
@@ -1272,7 +1249,7 @@ array transpose(const array& a, StreamOrDevice s /* = {} */) {
 
 array broadcast_to(
     const array& a,
-    const std::vector<int>& shape,
+    const Shape& shape,
     StreamOrDevice s /* = {} */) {
   if (a.shape() == shape) {
     return a;
@@ -1295,14 +1272,14 @@ array broadcast_to(
 
 std::vector<array>
 broadcast_arrays(const array& a, const array& b, StreamOrDevice s /* = {} */) {
-  std::vector<int> shape = broadcast_shapes(a.shape(), b.shape());
+  auto shape = broadcast_shapes(a.shape(), b.shape());
   return {broadcast_to(a, shape, s), broadcast_to(b, shape, s)};
 }
 
 std::vector<array> broadcast_arrays(
     const std::vector<array>& inputs,
     StreamOrDevice s /* = {} */) {
-  std::vector<int> shape{};
+  Shape shape{};
   for (const auto& in : inputs) {
     shape = broadcast_shapes(shape, in.shape());
   }
@@ -1913,7 +1890,7 @@ array argmax(const array& a, bool keepdims, StreamOrDevice s /* = {} */) {
   int size = a.size();
   auto result = argmax(reshape(a, {size}, s), 0, true, s);
   if (keepdims) {
-    result = reshape(result, std::vector<int>(a.shape().size(), 1), s);
+    result = reshape(result, Shape(a.shape().size(), 1), s);
   } else {
     result = squeeze(result, s);
   }
@@ -2098,8 +2075,8 @@ array topk(const array& a, int k, int axis, StreamOrDevice s /* = {}*/) {
   }
 
   array a_partitioned = partition(a, -k, axis_, s);
-  std::vector<int> slice_starts(a.ndim(), 0);
-  std::vector<int> slice_ends = a.shape();
+  Shape slice_starts(a.ndim(), 0);
+  auto slice_ends = a.shape();
   slice_starts[axis_] = a.shape(axis_) - k;
   return slice(a_partitioned, slice_starts, slice_ends, s);
 }
@@ -2613,8 +2590,8 @@ array matmul(
   }
 
   if (a.ndim() > 2 || b.ndim() > 2) {
-    std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
-    std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
+    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
+    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
     auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
 
     // Broadcast a
@@ -2648,7 +2625,7 @@ array gather(
     const array& a,
     const std::vector<array>& indices,
     const std::vector<int>& axes,
-    const std::vector<int>& slice_sizes,
+    const Shape& slice_sizes,
     StreamOrDevice s /* = {} */) {
   // Checks that indices, dimensions, and slice_sizes are all valid
   if (indices.size() > a.ndim()) {
@@ -2703,7 +2680,7 @@ array gather(
     idx = astype(idx, dtype, s);
   }
 
-  std::vector<int> out_shape;
+  Shape out_shape;
   if (!inputs.empty()) {
     out_shape = inputs[0].shape();
   }
@@ -2741,7 +2718,7 @@ array take(
   axis = axis < 0 ? a.ndim() + axis : axis;
 
   // Make slice sizes to pass to gather
-  std::vector<int> slice_sizes = a.shape();
+  Shape slice_sizes = a.shape();
   slice_sizes[axis] = indices.size() > 0 ? 1 : 0;
 
   auto out = gather(a, indices, axis, slice_sizes, s);
@@ -2759,7 +2736,7 @@ array take(
   }
 
   // Squeeze the axis we take over
-  std::vector<int> out_shape = out.shape();
+  auto out_shape = out.shape();
   out_shape.erase(out_shape.begin() + indices.ndim() + axis);
   return reshape(out, std::move(out_shape), s);
 }
@@ -2787,8 +2764,8 @@ array take(const array& a, int index, int axis, StreamOrDevice s /* = {} */) {
 
   // Handle negative axis
   axis = axis < 0 ? a.ndim() + axis : axis;
-  std::vector<int> starts(a.ndim(), 0);
-  std::vector<int> stops = a.shape();
+  Shape starts(a.ndim(), 0);
+  Shape stops = a.shape();
   starts[axis] = index;
   stops[axis] = index + 1;
   return squeeze(slice(a, std::move(starts), std::move(stops), s), axis, s);
@@ -2821,7 +2798,7 @@ array take_along_axis(
   axis = axis < 0 ? a.ndim() + axis : axis;
 
   std::vector<array> nd_indices;
-  std::vector<int> index_shape(a.ndim(), 1);
+  Shape index_shape(a.ndim(), 1);
   for (int i = 0; i < a.ndim(); ++i) {
     if (i == axis) {
       nd_indices.push_back(indices);
@@ -2834,12 +2811,11 @@ array take_along_axis(
   }
   std::vector<int> dims(a.ndim());
   std::iota(dims.begin(), dims.end(), 0);
-  std::vector<int> slice_sizes(a.ndim(), a.size() > 0);
+  Shape slice_sizes(a.ndim(), a.size() > 0);
   auto out = gather(a, nd_indices, dims, slice_sizes, s);
 
   // Squeeze out the slice shape
-  std::vector<int> out_shape(
-      out.shape().begin(), out.shape().begin() + a.ndim());
+  Shape out_shape(out.shape().begin(), out.shape().begin() + a.ndim());
   return reshape(out, std::move(out_shape), s);
 }
 
@@ -2867,7 +2843,7 @@ array put_along_axis(
   axis = axis < 0 ? a.ndim() + axis : axis;
 
   std::vector<array> nd_indices;
-  std::vector<int> index_shape(a.ndim(), 1);
+  Shape index_shape(a.ndim(), 1);
   for (int i = 0; i < a.ndim(); ++i) {
     if (i == axis) {
       nd_indices.push_back(indices);
@@ -2927,7 +2903,7 @@ array scatter(
 
   // Broadcast and cast indices if necessary
   auto inputs = broadcast_arrays(indices);
-  std::vector<int> idx_shape;
+  Shape idx_shape;
   if (!inputs.empty()) {
     idx_shape = inputs[0].shape();
   }
@@ -3198,7 +3174,7 @@ inline int dilate_size(int dim, int dil) {
   return 1 + dil * (dim - 1);
 }
 
-inline std::vector<int> conv_out_shape(
+Shape conv_out_shape(
     const std::vector<int>& in_shape,
     const std::vector<int>& wt_shape,
     const std::vector<int>& strides,
@@ -3208,7 +3184,7 @@ inline std::vector<int> conv_out_shape(
     const std::vector<int>& input_dilation) {
   int N = in_shape[0];
   int O = wt_shape[0];
-  std::vector<int> out_shape(in_shape.size());
+  Shape out_shape(in_shape.size());
   int i = 0;
   out_shape[i++] = N;
 
@@ -3577,8 +3553,8 @@ array conv_general(
 
   // Handle negative padding
   if (has_neg_padding) {
-    std::vector<int> starts(in.ndim(), 0);
-    std::vector<int> stops = in.shape();
+    Shape starts(in.ndim(), 0);
+    auto stops = in.shape();
 
     for (int i = 0; i < spatial_dims; i++) {
       if (padding_lo[i] < 0) {
@@ -3596,7 +3572,7 @@ array conv_general(
   }
 
   // Get output shapes
-  std::vector<int> out_shape = conv_out_shape(
+  auto out_shape = conv_out_shape(
       in.shape(),
       wt.shape(),
       stride,
@@ -3606,7 +3582,7 @@ array conv_general(
       input_dilation);
 
   return array(
-      out_shape,
+      std::move(out_shape),
       in.dtype(),
       std::make_shared<Convolution>(
           to_stream(s),
@@ -3634,8 +3610,8 @@ array quantized_matmul(
 
   // QuantizedMatmul handles w.ndim == 2 case.
   if (x.ndim() > 2 && w.ndim() > 2) {
-    std::vector<int> bsx_x(x.shape().begin(), x.shape().end() - 2);
-    std::vector<int> bsx_w(w.shape().begin(), w.shape().end() - 2);
+    Shape bsx_x(x.shape().begin(), x.shape().end() - 2);
+    Shape bsx_w(w.shape().begin(), w.shape().end() - 2);
     auto inner_shape = broadcast_shapes(bsx_x, bsx_w);
 
     // Broadcast x
@@ -3731,7 +3707,7 @@ array gather_qmm(
   // and output type
   auto out_type = result_type(x, scales, biases);
 
-  auto out = array(
+  return array(
       std::move(out_shape),
       out_type,
       std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
       {astype(x, out_type, s),
       w,
       astype(scales, out_type, s),
       astype(biases, out_type, s),
       lhs_indices,
       rhs_indices});
-
-  return out;
 }
 
 array tensordot(
@@ -3802,7 +3776,7 @@ array tensordot(
 
   std::vector<int> t1;
   std::vector<int> t2;
-  std::vector<int> rshape;
+  Shape rshape;
   int size1 = 1;
   int size2 = 1;
   for (int i = 0; i < a.ndim(); i++) {
@@ -3898,7 +3872,7 @@ array addmm(
 
   // We can batch the multiplication by reshaping a
   if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) {
-    std::vector<int> out_shape = a.shape();
+    auto out_shape = a.shape();
     a = reshape(a, {-1, out_shape.back()}, s);
     out_shape.back() = b.shape(-1);
 
@@ -3917,8 +3891,8 @@ array addmm(
   }
 
   if (a.ndim() > 2 || b.ndim() > 2) {
-    std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
-    std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
+    Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
+    Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
     auto inner_shape = broadcast_shapes(bsx_a, bsx_b);
 
     // Broadcast a
@@ -4042,8 +4016,8 @@ array block_masked_mm(
   b = astype(b, out_type, s);
 
   // Handle broadcasting
-  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
-  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
+  Shape bsx_a(a.shape().begin(), a.shape().end() - 2);
+  Shape bsx_b(b.shape().begin(), b.shape().end() - 2);
 
   auto bsx_shape = broadcast_shapes(bsx_a, bsx_b);
 
@@ -4079,7 +4053,7 @@ array block_masked_mm(
 
   // Broadcast and astype mask
   auto broadcast_mask = [](array mask,
-                           std::vector<int>& bs_shape,
+                           Shape& bs_shape,
                            int y,
                            int x,
                            Dtype mask_dtype,
@@ -4397,7 +4371,7 @@ std::vector<array> depends(
   Stream s = (inputs[0].has_primitive()) ? inputs[0].primitive().stream()
                                          : to_stream({});
   // Make the output info
-  std::vector<std::vector<int>> shapes;
+  std::vector<Shape> shapes;
   std::vector<Dtype> dtypes;
   for (const auto& in : inputs) {
     shapes.emplace_back(in.shape());
@@ -4434,7 +4408,7 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
     case 0:
       return reshape(a, {1, 1}, s);
     case 1:
-      return reshape(a, {1, static_cast<int>(a.size())}, s);
+      return reshape(a, {1, a.shape(0)}, s);
     default:
       return a;
   }
@@ -4456,7 +4430,7 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
     case 0:
       return reshape(a, {1, 1, 1}, s);
     case 1:
-      return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
+      return reshape(a, {1, a.shape(0), 1}, s);
     case 2:
       return reshape(a, {a.shape(0), a.shape(1), 1}, s);
     default:
@@ -4493,7 +4467,7 @@ array number_of_elements(
   }
 
   return stop_gradient(array(
-      std::vector<int>{},
+      Shape{},
       dtype,
       std::make_shared<NumberOfElements>(
           to_stream(s), std::move(axes), inverted, dtype),
@@ -4613,7 +4587,7 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
 
 array roll(
     const array& a,
-    const std::vector<int>& shift,
+    const Shape& shift,
     const std::vector<int>& axes,
     StreamOrDevice s /* = {} */) {
   if (axes.empty()) {
@@ -4627,7 +4601,6 @@ array roll(
     throw std::invalid_argument(msg.str());
   }
 
-  std::vector<array> parts;
   array result = a;
   for (int i = 0; i < axes.size(); i++) {
     int ax = axes[i];
@@ -4641,11 +4614,11 @@ array roll(
       throw std::invalid_argument(msg.str());
     }
 
-    int sh = shift[i];
-    int split_index =
+    auto sh = shift[i];
+    auto split_index =
         (sh < 0) ? (-sh) % a.shape(ax) : a.shape(ax) - sh % a.shape(ax);
 
-    parts = split(result, std::vector<int>{split_index}, ax, s);
+    auto parts = split(result, Shape{split_index}, ax, s);
     std::swap(parts[0], parts[1]);
     result = concatenate(parts, ax, s);
   }
@@ -4656,19 +4629,12 @@ array roll(
 array roll(const array& a, int shift, StreamOrDevice s /* = {} */) {
   auto shape = a.shape();
   return reshape(
-      roll(
-          reshape(a, std::vector<int>{-1}, s),
-          std::vector<int>{shift},
-          std::vector<int>{0},
-          s),
+      roll(reshape(a, Shape{-1}, s), Shape{shift}, std::vector<int>{0}, s),
       std::move(shape),
       s);
 }
 
-array roll(
-    const array& a,
-    const std::vector<int>& shift,
-    StreamOrDevice s /* = {} */) {
+array roll(const array& a, const Shape& shift, StreamOrDevice s /* = {} */) {
   int total_shift = 0;
   for (auto& s : shift) {
     total_shift += s;
@@ -4677,7 +4643,7 @@ array roll(
 }
 
 array roll(const array& a, int shift, int axis, StreamOrDevice s /* = {} */) {
-  return roll(a, std::vector<int>{shift}, std::vector<int>{axis}, s);
+  return roll(a, Shape{shift}, std::vector<int>{axis}, s);
 }
 
 array roll(
@@ -4685,20 +4651,20 @@ array roll(
     int shift,
     const std::vector<int>& axes,
     StreamOrDevice s /* = {} */) {
-  std::vector<int> shifts(axes.size(), shift);
+  Shape shifts(axes.size(), shift);
   return roll(a, shifts, axes, s);
 }
 
 array roll(
     const array& a,
-    const std::vector<int>& shift,
+    const Shape& shift,
     int axis,
     StreamOrDevice s /* = {} */) {
   int total_shift = 0;
   for (auto& s : shift) {
     total_shift += s;
   }
-  return roll(a, std::vector<int>{total_shift}, std::vector<int>{axis}, s);
+  return roll(a, Shape{total_shift}, std::vector<int>{axis}, s);
 }
 
 array real(const array& a, StreamOrDevice s /* = {} */) {
diff --git a/mlx/ops.h b/mlx/ops.h
index fdceeed0d..7e24b5820 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -49,8 +49,8 @@ array astype(array a, Dtype dtype, StreamOrDevice s = {});
 
 /** Create a view of an array with the given shape and strides. */
 array as_strided(
     array a,
-    std::vector<int> shape,
-    std::vector<size_t> strides,
+    Shape shape,
+    Strides strides,
     size_t offset,
     StreamOrDevice s = {});
 
@@ -58,31 +58,27 @@ array as_strided(
 array copy(array a, StreamOrDevice s = {});
 
 /** Fill an array of the given shape with the given value(s). */
-array full(
-    std::vector<int> shape,
-    array vals,
-    Dtype dtype,
-    StreamOrDevice s = {});
-array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
+array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
+array full(Shape shape, array vals, StreamOrDevice s = {});
 template <typename T>
-array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
+array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
   return full(std::move(shape), array(val, dtype), to_stream(s));
 }
 template <typename T>
-array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
+array full(Shape shape, T val, StreamOrDevice s = {}) {
   return full(std::move(shape), array(val), to_stream(s));
 }
 
 /** Fill an array of the given shape with zeros. */
-array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
-inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
+array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
+inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
   return zeros(shape, float32, s);
 }
 array zeros_like(const array& a, StreamOrDevice s = {});
 
 /** Fill an array of the given shape with ones. */
-array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
-inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
+array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
+inline array ones(const Shape& shape, StreamOrDevice s = {}) {
   return ones(shape, float32, s);
 }
 array ones_like(const array& a, StreamOrDevice s = {});
@@ -119,7 +115,7 @@ array tril(array x, int k = 0, StreamOrDevice s = {});
 array triu(array x, int k = 0, StreamOrDevice s = {});
 
 /** Reshape an array to the given shape. */
-array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});
+array reshape(const array& a, Shape shape, StreamOrDevice s = {});
 
 /** Flatten the dimensions in the range `[start_axis, end_axis]` . */
 array flatten(
@@ -161,33 +157,29 @@ array expand_dims(const array& a, int axis, StreamOrDevice s = {});
 
 /** Slice an array. */
 array slice(
     const array& a,
-    std::vector<int> start,
-    std::vector<int> stop,
-    std::vector<int> strides,
+    Shape start,
+    Shape stop,
+    Shape strides,
     StreamOrDevice s = {});
 
 /** Slice an array with a stride of 1 in each dimension. */
-array slice(
-    const array& a,
-    std::vector<int> start,
-    std::vector<int> stop,
-    StreamOrDevice s = {});
+array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
 
 /** Update a slice from the source array */
 array slice_update(
     const array& src,
     const array& update,
-    std::vector<int> start,
-    std::vector<int> stop,
-    std::vector<int> strides,
+    Shape start,
+    Shape stop,
+    Shape strides,
    StreamOrDevice s = {});
 
 /** Update a slice from the source array with stride 1 in each dimension */
 array slice_update(
     const array& src,
     const array& update,
-    std::vector<int> start,
-    std::vector<int> stop,
+    Shape start,
+    Shape stop,
     StreamOrDevice s = {});
 
 /** Split an array into sub-arrays along a given axis. */
@@ -288,10 +280,7 @@ array pad(
 array transpose(const array& a, StreamOrDevice s = {});
 
 /** Broadcast an array to a given shape. */
-array broadcast_to(
-    const array& a,
-    const std::vector<int>& shape,
-    StreamOrDevice s = {});
+array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
 
 /** Broadcast a vector of arrays against one another. */
 std::vector<array> broadcast_arrays(
@@ -917,13 +906,13 @@ array gather(
     const array& a,
     const std::vector<array>& indices,
     const std::vector<int>& axes,
-    const std::vector<int>& slice_sizes,
+    const Shape& slice_sizes,
     StreamOrDevice s = {});
 inline array gather(
     const array& a,
     const array& indices,
     int axis,
-    const std::vector<int>& slice_sizes,
+    const Shape& slice_sizes,
     StreamOrDevice s = {}) {
   return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
 }
@@ -1459,24 +1448,13 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
 
 /** Roll elements along an axis and introduce them on the other side */
 array roll(const array& a, int shift, StreamOrDevice s = {});
-array roll(
-    const array& a,
-    const std::vector<int>& shift,
-    StreamOrDevice s = {});
+array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
 array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
+array roll(const array& a, int shift, const Shape& axes, StreamOrDevice s = {});
+array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
 array roll(
     const array& a,
-    int shift,
-    const std::vector<int>& axes,
-    StreamOrDevice s = {});
-array roll(
-    const array& a,
-    const std::vector<int>& shift,
-    int axis,
-    StreamOrDevice s = {});
-array roll(
-    const array& a,
-    const std::vector<int>& shift,
+    const Shape& shift,
     const std::vector<int>& axes,
     StreamOrDevice s = {});
 
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 73ab4bf2e..3956e2a24 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -66,9 +66,7 @@ Dtype result_type(const std::vector<array>& arrays) {
   return t;
 }
 
-std::vector<int> broadcast_shapes(
-    const std::vector<int>& s1,
-    const std::vector<int>& s2) {
+Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
   // Use the same broadcasting rules as numpy
   // https://numpy.org/doc/1.20/user/theory.broadcasting.html
   // "The size of the trailing axes for both arrays in an operation must
@@ -79,7 +77,7 @@ std::vector<int> broadcast_shapes(
   int diff = std::abs(ndim1 - ndim2);
   const auto& big = ndim1 > ndim2 ? s1 : s2;
   const auto& small = ndim1 > ndim2 ? s2 : s1;
-  std::vector<int> out_shape(ndim);
+  Shape out_shape(ndim);
   for (int i = ndim - 1; i >= diff; --i) {
     int a = big[i];
     int b = small[i - diff];
@@ -158,10 +156,8 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) {
 
 namespace {
 
-inline size_t elem_to_loc(
-    int elem,
-    const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
+inline size_t
+elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
   size_t loc = 0;
   for (int i = shape.size() - 1; i >= 0; --i) {
     auto q_and_r = ldiv(elem, shape[i]);
@@ -199,7 +195,6 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
 
 template <typename T>
 void print_array(std::ostream& os, const array& a) {
-  std::vector<int> indices(a.ndim(), 0);
   os << std::boolalpha;
   os << "array(";
   if (a.ndim() == 0) {
@@ -310,7 +305,7 @@ std::ostream& operator<<(std::ostream& os, array a) {
   return os;
 }
 
-std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
+std::ostream& operator<<(std::ostream& os, const Shape& v) {
   os << "(";
   for (int i = 0; i < v.size(); ++i) {
     os << v[i] << ((i == v.size() - 1) ? "" : ",");
"" : ","); @@ -319,7 +314,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } -std::ostream& operator<<(std::ostream& os, const std::vector& v) { +std::ostream& operator<<(std::ostream& os, const Strides& v) { os << "("; for (int i = 0; i < v.size(); ++i) { os << v[i] << ((i == v.size() - 1) ? "" : ","); diff --git a/mlx/utils.h b/mlx/utils.h index e5d1ad9ae..eb194b71e 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -62,9 +62,7 @@ inline Dtype result_type(const array& a, const array& b, const array& c) { } Dtype result_type(const std::vector& arrays); -std::vector broadcast_shapes( - const std::vector& s1, - const std::vector& s2); +Shape broadcast_shapes(const Shape& s1, const Shape& s2); bool is_same_shape(const std::vector& arrays); @@ -96,8 +94,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s); std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, array a); -std::ostream& operator<<(std::ostream& os, const std::vector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); +std::ostream& operator<<(std::ostream& os, const Shape& v); +std::ostream& operator<<(std::ostream& os, const Strides& v); std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";