From 828c5f1137b358315e3bfc2e2014957bda08f23a Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 5 Aug 2025 09:41:03 +0900 Subject: [PATCH] Use SmallVector for shapes and strides (#2454) * Use SmallVector for shapes and strides * Convert SmallVector to tuple --- mlx/array.h | 5 +- mlx/backend/common/utils.h | 2 +- mlx/backend/cpu/compiled.cpp | 8 + mlx/backend/cpu/sort.cpp | 47 +- mlx/backend/cuda/conv.cpp | 22 +- mlx/backend/cuda/indexing.cpp | 10 +- mlx/backend/cuda/jit_module.h | 19 +- mlx/backend/cuda/kernel_utils.cuh | 2 +- mlx/backend/cuda/quantized/quantized.cpp | 19 +- mlx/backend/metal/device.h | 10 + mlx/backend/metal/quantized.cpp | 19 +- mlx/export.cpp | 10 +- mlx/ops.cpp | 6 +- mlx/primitives.cpp | 4 +- mlx/primitives.h | 10 +- mlx/small_vector.h | 528 +++++++++++++++++++++++ mlx/utils.cpp | 19 + mlx/utils.h | 2 + python/src/array.cpp | 3 +- python/src/distributed.cpp | 2 +- python/src/export.cpp | 1 + python/src/fast.cpp | 4 +- python/src/fft.cpp | 1 + python/src/linalg.cpp | 1 + python/src/load.cpp | 1 + python/src/metal.cpp | 2 + python/src/ops.cpp | 1 + python/src/random.cpp | 4 +- python/src/small_vector.h | 77 ++++ python/src/transforms.cpp | 1 + 30 files changed, 738 insertions(+), 102 deletions(-) create mode 100644 mlx/small_vector.h create mode 100644 python/src/small_vector.h diff --git a/mlx/array.h b/mlx/array.h index 98eef2e33e..4e9a5ae63a 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -10,6 +10,7 @@ #include "mlx/allocator.h" #include "mlx/dtype.h" #include "mlx/event.h" +#include "mlx/small_vector.h" namespace mlx::core { @@ -18,8 +19,8 @@ class Primitive; using Deleter = std::function; using ShapeElem = int32_t; -using Shape = std::vector; -using Strides = std::vector; +using Shape = SmallVector; +using Strides = SmallVector; class array { /* An array is really a node in a graph. It contains a shared ArrayDesc diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 04d804238d..db0da5e10e 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -200,7 +200,7 @@ void shared_buffer_reshape( array swapaxes_in_eval(const array& x, int axis1, int axis2); template -inline std::vector remove_index(std::vector vec, size_t index) { +inline SmallVector remove_index(SmallVector vec, size_t index) { vec.erase(std::next(vec.begin(), index)); return vec; } diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index d851149871..8aa296619d 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -288,6 +288,14 @@ void Compiled::eval_cpu( auto [contiguous, shape, strides] = compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); + // Force allocating shape/strides on heap so we can take their data() first + // and then std::move them. + // TODO: Refactor code to avoid heap allocation. + shape.grow(); + for (auto& s : strides) { + s.grow(); + } + // Collect function input arguments. std::vector args; int strides_index = 1; diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp index 089f7c425f..0b8471d32c 100644 --- a/mlx/backend/cpu/sort.cpp +++ b/mlx/backend/cpu/sort.cpp @@ -8,7 +8,7 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" - +#include "mlx/dtype_utils.h" #include "mlx/primitives.h" namespace mlx::core { @@ -333,47 +333,24 @@ void Sort::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; + int axis = axis_; + if (axis < 0) { + axis += in.ndim(); + } + // Copy input to output - CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0) + CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0) ? CopyType::Vector : CopyType::General; copy_cpu(in, out, ctype, stream()); auto& encoder = cpu::get_command_encoder(stream()); encoder.set_output_array(out); - encoder.dispatch( - [out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable { - switch (out.dtype()) { - case bool_: - return sort(out, axis_); - case uint8: - return sort(out, axis_); - case uint16: - return sort(out, axis_); - case uint32: - return sort(out, axis_); - case uint64: - return sort(out, axis_); - case int8: - return sort(out, axis_); - case int16: - return sort(out, axis_); - case int32: - return sort(out, axis_); - case int64: - return sort(out, axis_); - case float32: - return sort(out, axis_); - case float64: - return sort(out, axis_); - case float16: - return sort(out, axis_); - case bfloat16: - return sort(out, axis_); - case complex64: - return sort(out, axis_); - } - }); + encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable { + dispatch_all_types(out.dtype(), [&](auto type_tag) { + sort(out, axis); + }); + }); } void ArgPartition::eval_cpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 9833c348b3..1484e8c46b 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -55,13 +55,13 @@ auto& conv_cache() { return cache; } -template -inline std::vector convert_vector(const std::vector& vec) { - return std::vector(vec.begin(), vec.end()); +template +inline SmallVector convert_vector(const Vec& vec) { + return SmallVector(vec.begin(), vec.end()); } -template -inline std::array fixed_vector(const std::vector& vec) { +template class Vec> +inline std::array fixed_vector(const Vec& vec) { if (vec.size() > MAX_NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", MAX_NDIM)); @@ -78,7 +78,7 @@ auto nhwc_to_nchw(const array& x) { auto strides = convert_vector(x.strides()); strides.insert(strides.begin() + 1, strides.back()); strides.erase(strides.end() - 1); - return std::make_tuple(shape, strides); + return std::make_tuple(std::move(shape), std::move(strides)); } inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) { @@ -298,10 +298,10 @@ std::optional build_op_graph( array& x, array& w, array& y, - const std::vector& stride, - const std::vector& padding_lo, - const std::vector& padding_hi, - const std::vector& dilation) { + const SmallVector& stride, + const SmallVector& padding_lo, + const SmallVector& padding_hi, + const SmallVector& dilation) { try { auto compute_dtype = (dtype == float16 || dtype == bfloat16) ? CUDNN_DATA_FLOAT @@ -468,7 +468,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { // There is no reliable way to deduce the proper cuDNN backend for the // convolution, so we make a best guess and then try. - std::vector try_backends; + SmallVector try_backends; if (flip_) { // When weight is flipped, we assume it is backward input convolution. try_backends.push_back(CONV_BACKWARD_INPUT); diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 22cff87d7f..dd524a72dc 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -29,12 +29,12 @@ void append_indices_arg( const std::vector& inputs, int nidx, int idx_ndim) { - std::vector indices(nidx); + SmallVector indices(nidx); for (int i = 0; i < nidx; ++i) { indices[i] = inputs[i + 1].data(); } args.append(std::move(indices)); - std::vector indices_shape(nidx * idx_ndim); + SmallVector indices_shape(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( inputs[i + 1].shape().begin(), @@ -42,7 +42,7 @@ void append_indices_arg( indices_shape.data() + i * idx_ndim); } args.append(std::move(indices_shape)); - std::vector indices_strides(nidx * idx_ndim); + SmallVector indices_strides(nidx * idx_ndim); for (int i = 0; i < nidx; ++i) { std::copy_n( inputs[i + 1].strides().begin(), @@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { args.append(src.ndim()); args.append_ndim(slice_sizes_); args.append(slice_size); - args.append(axes_); + args.append(SmallVector(axes_.begin(), axes_.end())); append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( @@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { args.append_ndim(out.shape()); args.append_ndim(out.strides()); args.append(out.ndim()); - args.append(axes_); + args.append(SmallVector(axes_.begin(), axes_.end())); append_indices_arg(args, inputs, nidx, idx_ndim); std::string kernel_name = fmt::format( diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index 7fe3fa055f..df3f583528 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -40,19 +40,14 @@ struct KernelArgs { } template - void append(std::vector vec) { - if (vec.empty()) { - // The nullptr can not be used as arg, pass something not null. - append(std::monostate{}); - } else { - append_ptr(vec.data()); - storage_.emplace_back(std::move(vec)); - } + void append(SmallVector vec) { + storage_.emplace_back(std::move(vec)); + append_ptr(std::get>(storage_.back()).data()); } // Make sure the arg is copied to an array with size of NDIM. template - void append_ndim(std::vector vec) { + void append_ndim(SmallVector vec) { if (vec.size() > NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", NDIM)); @@ -76,9 +71,9 @@ struct KernelArgs { int32_t, uint32_t, int64_t, - std::vector, - std::vector, - std::vector>; + SmallVector, + SmallVector, + SmallVector>; std::deque storage_; }; diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index fbbca0a060..7a37361ea6 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v || is_complex_v; // Utility to copy data from vector to array in host. template -inline cuda::std::array const_param(const std::vector& vec) { +inline cuda::std::array const_param(const SmallVector& vec) { if (vec.size() > NDIM) { throw std::runtime_error( fmt::format("ndim can not be larger than {}.", NDIM)); diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 256f2c7d50..008001c508 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -28,15 +28,20 @@ inline array ensure_row_contiguous_matrix( const array& x, cu::CommandEncoder& enc, const Stream& s) { - auto stride_0 = x.strides()[x.ndim() - 2]; - auto stride_1 = x.strides()[x.ndim() - 1]; - if (stride_0 == x.shape(-1) && stride_1 == 1) { - return x; + if (x.ndim() < 2) { + if (x.strides()[0] == 1) { + return x; + } } else { - array x_copy = contiguous_copy_gpu(x, s); - enc.add_temporary(x_copy); - return x_copy; + auto stride_0 = x.strides()[x.ndim() - 2]; + auto stride_1 = x.strides()[x.ndim() - 1]; + if (stride_0 == x.shape(-1) && stride_1 == 1) { + return x; + } } + array x_copy = contiguous_copy_gpu(x, s); + enc.add_temporary(x_copy); + return x_copy; } } // namespace diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 52595e6e65..dfa21aa0a2 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -60,6 +60,16 @@ struct CommandEncoder { enc_->updateFence(fence); } + template + void set_vector_bytes(const SmallVector& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(T), idx); + } + template + void set_vector_bytes(const SmallVector& vec, int idx) { + return set_vector_bytes(vec, vec.size(), idx); + } + + // TODO: Code is duplicated but they should be deleted soon. template void set_vector_bytes(const std::vector& vec, size_t nelems, int idx) { enc_->setBytes(vec.data(), nelems * sizeof(T), idx); diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 39c208c033..999825043f 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -32,15 +32,20 @@ inline array ensure_row_contiguous_matrix( const array& x, metal::Device& d, const Stream& s) { - auto stride_0 = x.strides()[x.ndim() - 2]; - auto stride_1 = x.strides()[x.ndim() - 1]; - if (stride_0 == x.shape(-1) && stride_1 == 1) { - return x; + if (x.ndim() < 2) { + if (x.strides()[0] == 1) { + return x; + } } else { - array x_copy = contiguous_copy_gpu(x, s); - d.add_temporary(x_copy, s.index); - return x_copy; + auto stride_0 = x.strides()[x.ndim() - 2]; + auto stride_1 = x.strides()[x.ndim() - 1]; + if (stride_0 == x.shape(-1) && stride_1 == 1) { + return x; + } } + array x_copy = contiguous_copy_gpu(x, s); + d.add_temporary(x_copy, s.index); + return x_copy; } inline int get_qmv_batch_limit(int D, int O, metal::Device& d) { diff --git a/mlx/export.cpp b/mlx/export.cpp index 8eb385bb1c..7099f4864d 100644 --- a/mlx/export.cpp +++ b/mlx/export.cpp @@ -179,7 +179,7 @@ void serialize(Writer& os, const array& arr) { } template <> array deserialize(Reader& is) { - auto shape = deserialize>(is); + auto shape = deserialize(is); auto type = deserialize(is); return array(std::move(shape), type, nullptr, std::vector{}); } @@ -640,7 +640,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) { auto outputs = arr.outputs(); serialize(os, arrays_to_ids(outputs)); - std::vector> shapes; + std::vector shapes; std::vector dtypes; for (auto& o : outputs) { shapes.push_back(o.shape()); @@ -813,14 +813,14 @@ ImportedFunction::ImportedFunction(const std::string& file) std::shared_ptr prim = factory.load(is); auto num_siblings = deserialize(is); if (num_siblings == 0) { - auto shape = deserialize>(is); + auto shape = deserialize(is); auto type = deserialize(is); tape.emplace_back( std::move(shape), type, std::move(prim), std::move(inputs)); array_map.emplace(id, tape.back()); } else { auto ids = deserialize>(is); - auto shapes = deserialize>>(is); + auto shapes = deserialize>(is); auto types = deserialize>(is); auto arrays = array::make_arrays( std::move(shapes), @@ -841,7 +841,7 @@ ImportedFunction::ImportedFunction(const std::string& file) if (auto it = constants.find(id); it != constants.end()) { tape.push_back(it->second); } else { - auto shape = deserialize>(is); + auto shape = deserialize(is); auto type = deserialize(is); size_t offset = is.tell(); tape.push_back(array( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7161a39b2b..6c4f764243 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3027,9 +3027,9 @@ array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) { } int ndim = std::max(a.ndim(), b.ndim()); - std::vector a_shape(2 * ndim, 1); - std::vector b_shape(2 * ndim, 1); - std::vector out_shape(ndim, 1); + Shape a_shape(2 * ndim, 1); + Shape b_shape(2 * ndim, 1); + Shape out_shape(ndim, 1); for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) { a_shape[2 * i] = a.shape(j); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index a4fa011d51..980a1f7c32 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1205,8 +1205,8 @@ array conv_weight_backward_patches( auto in_padded = pad(in, padded_axes, - Shape(padding_lo), - Shape(padding_hi), + Shape(padding_lo.begin(), padding_lo.end()), + Shape(padding_hi.begin(), padding_hi.end()), array(0, in.dtype()), "constant", s); diff --git a/mlx/primitives.h b/mlx/primitives.h index d482a1bf97..277e42a0b0 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -592,7 +592,7 @@ class Broadcast : public UnaryPrimitive { static Shape output_shape(const std::vector& inputs); std::vector output_shapes(const std::vector& inputs) override; bool is_equivalent(const Primitive& other) const override; - std::vector state() const { + Shape state() const { return shape_; }; @@ -1146,7 +1146,7 @@ class Gather : public UnaryPrimitive { DEFINE_NAME(Gather) bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; - std::pair, std::vector> state() const { + std::pair, Shape> state() const { return {axes_, slice_sizes_}; } @@ -1668,7 +1668,7 @@ class RandomBits : public UnaryPrimitive { DEFINE_VMAP() DEFINE_NAME(RandomBits) bool is_equivalent(const Primitive& other) const override; - std::pair, int> state() const { + std::pair state() const { return {shape_, width_}; }; @@ -1703,7 +1703,7 @@ class Reshape : public UnaryPrimitive { DEFINE_GRADS() DEFINE_NAME(Reshape) bool is_equivalent(const Primitive& other) const override; - std::vector state() const { + Shape state() const { return shape_; }; static Shape output_shape(const array& input, Shape shape); @@ -2121,7 +2121,7 @@ class Split : public Primitive { DEFINE_GRADS() DEFINE_NAME(Split) bool is_equivalent(const Primitive& other) const override; - std::pair, int> state() const { + std::pair state() const { return {indices_, axis_}; }; diff --git a/mlx/small_vector.h b/mlx/small_vector.h new file mode 100644 index 0000000000..fc4c1f06cb --- /dev/null +++ b/mlx/small_vector.h @@ -0,0 +1,528 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2018 the V8 project authors. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core { + +#if defined(__has_builtin) +#define MLX_HAS_BUILTIN(x) __has_builtin(x) +#else +#define MLX_HAS_BUILTIN(x) 0 +#endif + +#if defined(__has_attribute) +#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +#define MLX_HAS_ATTRIBUTE(x) 0 +#endif + +#if MLX_HAS_BUILTIN(__builtin_expect) +#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1)) +#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) +#else +#define MLX_LIKELY(condition) (condition) +#define MLX_UNLIKELY(condition) (condition) +#endif + +#if MLX_HAS_ATTRIBUTE(noinline) +#define MLX_NOINLINE __attribute__((noinline)) +#else +#define MLX_NOINLINE +#endif + +template +struct is_iterator : std::false_type {}; + +template +struct is_iterator< + T, + std::void_t< + typename std::iterator_traits::difference_type, + typename std::iterator_traits::iterator_category, + typename std::iterator_traits::pointer, + typename std::iterator_traits::reference, + typename std::iterator_traits::value_type>> : std::true_type {}; + +template +constexpr bool is_iterator_v = is_iterator::value; + +// Minimal SmallVector implementation. Uses inline storage first, switches to +// dynamic storage when it overflows. +// +// Notes: +// * The default inline storage size is MAX_NDIM, as it is mainly used for +// shapes and strides, users should choose a better size for other cases. +// * The data() returns real address even for empty vector. +// * The pointer returned by data() will change after moving the vector as it +// points to the inline storage. +// * For trivial elements the storage will not be default constructed, +// i.e. SmallVector(10) will not be filled with 0 by default. +template > +class SmallVector { + public: + using value_type = T; + using reference = T&; + using const_reference = const T&; + using iterator = T*; + using const_iterator = const T*; + using difference_type = std::ptrdiff_t; + using size_type = std::size_t; + + SmallVector() = default; + + explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {} + + explicit SmallVector(size_t size, const Allocator& allocator = Allocator()) + : allocator_(allocator) { + resize(size); + } + + SmallVector( + size_t size, + const T& initial_value, + const Allocator& allocator = Allocator()) + : allocator_(allocator) { + resize(size, initial_value); + } + + SmallVector( + std::initializer_list init, + const Allocator& allocator = Allocator()) + : allocator_(allocator) { + if (init.size() > capacity()) { + grow(init.size()); + } + assert(capacity() >= init.size()); // sanity check + std::uninitialized_move(init.begin(), init.end(), begin_); + end_ = begin_ + init.size(); + } + + template >> + SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator()) + : allocator_(allocator) { + size_t size = std::distance(begin, end); + if (size > capacity()) { + grow(size); + } + assert(capacity() >= size); // sanity check + std::uninitialized_copy(begin, end, begin_); + end_ = begin_ + size; + } + + SmallVector(const SmallVector& other) : allocator_(other.allocator_) { + *this = other; + } + SmallVector(const SmallVector& other, const Allocator& allocator) + : allocator_(allocator) { + *this = other; + } + SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) { + *this = std::move(other); + } + SmallVector(SmallVector&& other, const Allocator& allocator) + : allocator_(allocator) { + *this = std::move(other); + } + + ~SmallVector() { + free_storage(); + } + + SmallVector& operator=(const SmallVector& other) { + if (this == &other) { + return *this; + } + size_t other_size = other.size(); + if (capacity() < other_size) { + // Create large-enough heap-allocated storage. + free_storage(); + begin_ = allocator_.allocate(other_size); + end_of_storage_ = begin_ + other_size; + std::uninitialized_copy(other.begin_, other.end_, begin_); + } else if constexpr (kHasTrivialElement) { + std::copy(other.begin_, other.end_, begin_); + } else { + ptrdiff_t to_copy = + std::min(static_cast(other_size), end_ - begin_); + std::copy(other.begin_, other.begin_ + to_copy, begin_); + if (other.begin_ + to_copy < other.end_) { + std::uninitialized_copy( + other.begin_ + to_copy, other.end_, begin_ + to_copy); + } else { + std::destroy_n(begin_ + to_copy, size() - to_copy); + } + } + end_ = begin_ + other_size; + return *this; + } + + SmallVector& operator=(SmallVector&& other) { + if (this == &other) { + return *this; + } + if (other.is_big()) { + free_storage(); + begin_ = other.begin_; + end_ = other.end_; + end_of_storage_ = other.end_of_storage_; + } else { + assert(capacity() >= other.size()); // sanity check + size_t other_size = other.size(); + if constexpr (kHasTrivialElement) { + std::move(other.begin_, other.end_, begin_); + } else { + ptrdiff_t to_move = + std::min(static_cast(other_size), end_ - begin_); + std::move(other.begin_, other.begin_ + to_move, begin_); + if (other.begin_ + to_move < other.end_) { + std::uninitialized_move( + other.begin_ + to_move, other.end_, begin_ + to_move); + } else { + std::destroy_n(begin_ + to_move, size() - to_move); + } + } + end_ = begin_ + other_size; + } + other.reset_to_inline_storage(); + return *this; + } + + bool operator==(const SmallVector& other) const { + if (size() != other.size()) { + return false; + } + return std::equal(begin_, end_, other.begin_); + } + + bool operator!=(const SmallVector& other) const { + return !(*this == other); + } + + T* data() { + return begin_; + } + const T* data() const { + return begin_; + } + + iterator begin() { + return begin_; + } + const_iterator begin() const { + return begin_; + } + + iterator end() { + return end_; + } + const_iterator end() const { + return end_; + } + + const_iterator cbegin() const { + return begin_; + } + + const_iterator cend() const { + return end_; + } + + auto rbegin() { + return std::make_reverse_iterator(end_); + } + auto rbegin() const { + return std::make_reverse_iterator(end_); + } + + auto rend() { + return std::make_reverse_iterator(begin_); + } + auto rend() const { + return std::make_reverse_iterator(begin_); + } + + size_t size() const { + return end_ - begin_; + } + bool empty() const { + return end_ == begin_; + } + size_t capacity() const { + return end_of_storage_ - begin_; + } + + T& front() { + assert(size() != 0); + return begin_[0]; + } + const T& front() const { + assert(size() != 0); + return begin_[0]; + } + + T& back() { + assert(size() != 0); + return end_[-1]; + } + const T& back() const { + assert(size() != 0); + return end_[-1]; + } + + T& at(size_t index) { + if (index >= size()) { + throw std::out_of_range("SmallVector out of range."); + } + return begin_[index]; + } + const T& at(size_t index) const { + return const_cast(this)->at(index); + } + + T& operator[](size_t index) { + assert(size() > index); + return begin_[index]; + } + const T& operator[](size_t index) const { + return const_cast(this)->operator[](index); + } + + template + void emplace_back(Args&&... args) { + if (MLX_UNLIKELY(end_ == end_of_storage_)) { + grow(); + } + void* storage = end_; + end_ += 1; + new (storage) T(std::forward(args)...); + } + + void push_back(T x) { + emplace_back(std::move(x)); + } + + void pop_back(size_t count = 1) { + assert(size() >= count); + end_ -= count; + std::destroy_n(end_, count); + } + + iterator insert(iterator pos, T value) { + return insert(pos, static_cast(1), std::move(value)); + } + + iterator insert(iterator pos, size_t count, T value) { + assert(pos <= end_); + size_t offset = pos - begin_; + size_t old_size = size(); + resize(old_size + count); + pos = begin_ + offset; + iterator old_end = begin_ + old_size; + assert(old_end <= end_); + std::move_backward(pos, old_end, end_); + if constexpr (kHasTrivialElement) { + std::fill_n(pos, count, value); + } else { + std::fill_n(pos + 1, count - 1, value); + *pos = std::move(value); + } + return pos; + } + + template >> + iterator insert(iterator pos, Iter begin, Iter end) { + if constexpr (std::is_same_v, iterator>) { + // The implementation can not take overlapping range. + assert(!(begin >= pos && begin < pos + std::distance(begin, end))); + assert(!(end > pos && end <= pos + std::distance(begin, end))); + } + + assert(pos <= end_); + size_t offset = pos - begin_; + size_t count = std::distance(begin, end); + size_t old_size = size(); + resize(old_size + count); + pos = begin_ + offset; + iterator old_end = begin_ + old_size; + assert(old_end <= end_); + std::move_backward(pos, old_end, end_); + std::copy(begin, end, pos); + return pos; + } + + iterator insert(iterator pos, std::initializer_list values) { + return insert(pos, values.begin(), values.end()); + } + + iterator erase(iterator erase_start, iterator erase_end) { + assert(erase_start >= begin_); + assert(erase_start <= erase_end); + assert(erase_end <= end_); + iterator new_end = std::move(erase_end, end_, erase_start); + std::destroy_n(new_end, std::distance(new_end, end_)); + end_ = new_end; + return erase_start; + } + + iterator erase(iterator pos) { + return erase(pos, pos + 1); + } + + void resize(size_t new_size) { + if (new_size > capacity()) { + grow(new_size); + } + T* new_end = begin_ + new_size; + if constexpr (!kHasTrivialElement) { + if (new_end > end_) { + std::uninitialized_default_construct(end_, new_end); + } else { + std::destroy_n(new_end, end_ - new_end); + } + } + end_ = new_end; + } + + void resize(size_t new_size, const T& initial_value) { + if (new_size > capacity()) { + grow(new_size); + } + T* new_end = begin_ + new_size; + if (new_end > end_) { + std::uninitialized_fill(end_, new_end, initial_value); + } else { + std::destroy_n(new_end, end_ - new_end); + } + end_ = new_end; + } + + void reserve(size_t new_capacity) { + if (new_capacity > capacity()) { + grow(new_capacity); + } + } + + // Clear without reverting back to inline storage. + void clear() { + std::destroy_n(begin_, end_ - begin_); + end_ = begin_; + } + + // Grows the backing store by a factor of two, and at least to {min_capacity}. + // TODO: Move to private after removing external code using this method. + MLX_NOINLINE void grow(size_t min_capacity = 0) { + size_t new_capacity = std::max(min_capacity, 2 * capacity()); + // Round up to power of 2. + new_capacity--; + new_capacity |= new_capacity >> 1; + new_capacity |= new_capacity >> 2; + new_capacity |= new_capacity >> 4; + new_capacity |= new_capacity >> 8; + new_capacity |= new_capacity >> 16; + if constexpr (sizeof(size_t) == sizeof(uint64_t)) { + new_capacity |= new_capacity >> 32; + } + new_capacity++; + + T* new_storage = allocator_.allocate(new_capacity); + if (new_storage == nullptr) { + throw std::bad_alloc(); + } + + size_t in_use = end_ - begin_; + std::uninitialized_move(begin_, end_, new_storage); + free_storage(); + begin_ = new_storage; + end_ = new_storage + in_use; + end_of_storage_ = new_storage + new_capacity; + } + + private: + MLX_NOINLINE void free_storage() { + std::destroy_n(begin_, end_ - begin_); + if (is_big()) { + allocator_.deallocate(begin_, end_of_storage_ - begin_); + } + } + + // Clear and go back to inline storage. Dynamic storage is *not* freed. For + // internal use only. + void reset_to_inline_storage() { + if constexpr (!kHasTrivialElement) { + if (!is_big()) + std::destroy_n(begin_, end_ - begin_); + } + begin_ = inline_storage_begin(); + end_ = begin_; + end_of_storage_ = begin_ + kSize; + } + + bool is_big() const { + return begin_ != inline_storage_begin(); + } + + T* inline_storage_begin() { + return reinterpret_cast(inline_storage_); + } + const T* inline_storage_begin() const { + return reinterpret_cast(inline_storage_); + } + + Allocator allocator_; + + // Invariants: + // 1. The elements in the range between `begin_` (included) and `end_` (not + // included) will be initialized at all times. + // 2. All other elements outside the range, both in the inline storage and in + // the dynamic storage (if it exists), will be uninitialized at all times. + + T* begin_ = inline_storage_begin(); + T* end_ = begin_; + T* end_of_storage_ = begin_ + kSize; + + alignas(T) char inline_storage_[sizeof(T) * kSize]; + + static constexpr bool kHasTrivialElement = + std::is_trivially_copyable::value && + std::is_trivially_destructible::value; +}; + +#undef MLX_HAS_BUILTIN +#undef MLX_HAS_ATTRIBUTE +#undef MLX_LIKELY +#undef MLX_UNLIKELY +#undef MLX_NOINLINE + +} // namespace mlx::core diff --git a/mlx/utils.cpp b/mlx/utils.cpp index e53a7a97fa..eac18239ee 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -259,6 +259,25 @@ std::ostream& operator<<(std::ostream& os, array a) { return os; } +std::ostream& operator<<(std::ostream& os, const SmallVector& v) { + os << "("; + for (int i = 0; i < v.size(); ++i) { + os << v[i] << ((i == v.size() - 1) ? "" : ","); + } + os << ")"; + return os; +} + +std::ostream& operator<<(std::ostream& os, const SmallVector& v) { + os << "("; + for (int i = 0; i < v.size(); ++i) { + os << v[i] << ((i == v.size() - 1) ? "" : ","); + } + os << ")"; + return os; +} + +// TODO: Code is duplicated but they should be deleted soon. std::ostream& operator<<(std::ostream& os, const std::vector& v) { os << "("; for (int i = 0; i < v.size(); ++i) { diff --git a/mlx/utils.h b/mlx/utils.h index f16bf0468d..4513935407 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -100,6 +100,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s); std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, array a); +std::ostream& operator<<(std::ostream& os, const SmallVector& v); +std::ostream& operator<<(std::ostream& os, const SmallVector& v); std::ostream& operator<<(std::ostream& os, const std::vector& v); std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { diff --git a/python/src/array.cpp b/python/src/array.cpp index 25889d775e..22ef8e2733 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -15,6 +15,7 @@ #include "python/src/buffer.h" #include "python/src/convert.h" #include "python/src/indexing.h" +#include "python/src/small_vector.h" #include "python/src/utils.h" #include "mlx/mlx.h" @@ -303,7 +304,7 @@ void init_array(nb::module_& m) { R"pbdoc(The number of bytes in the array.)pbdoc") .def_prop_ro( "shape", - [](const mx::array& a) { return nb::tuple(nb::cast(a.shape())); }, + [](const mx::array& a) { return nb::cast(a.shape()); }, R"pbdoc( The shape of the array as a Python tuple. diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index c9acc85836..b52fa86c07 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -9,7 +9,7 @@ #include "mlx/distributed/distributed.h" #include "mlx/distributed/ops.h" - +#include "python/src/small_vector.h" #include "python/src/utils.h" namespace mx = mlx::core; diff --git a/python/src/export.cpp b/python/src/export.cpp index 30062ae371..4428e7cc81 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -10,6 +10,7 @@ #include "mlx/array.h" #include "mlx/export.h" #include "mlx/graph_utils.h" +#include "python/src/small_vector.h" #include "python/src/trees.h" namespace mx = mlx::core; diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 8adba2a258..3d0bc41478 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -8,10 +8,10 @@ #include #include -#include "python/src/utils.h" - #include "mlx/fast.h" #include "mlx/ops.h" +#include "python/src/small_vector.h" +#include "python/src/utils.h" namespace mx = mlx::core; namespace nb = nanobind; diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 026f8139d2..522e1064c7 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -8,6 +8,7 @@ #include "mlx/fft.h" #include "mlx/ops.h" +#include "python/src/small_vector.h" namespace mx = mlx::core; namespace nb = nanobind; diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 634abaef47..83d76979f7 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -9,6 +9,7 @@ #include #include "mlx/linalg.h" +#include "python/src/small_vector.h" namespace mx = mlx::core; namespace nb = nanobind; diff --git a/python/src/load.cpp b/python/src/load.cpp index 66e8ecc5a8..e992f2077e 100644 --- a/python/src/load.cpp +++ b/python/src/load.cpp @@ -12,6 +12,7 @@ #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/load.h" +#include "python/src/small_vector.h" #include "python/src/utils.h" namespace mx = mlx::core; diff --git a/python/src/metal.cpp b/python/src/metal.cpp index 3b2f4a53ae..a56674428b 100644 --- a/python/src/metal.cpp +++ b/python/src/metal.cpp @@ -7,8 +7,10 @@ #include #include #include + #include "mlx/backend/metal/metal.h" #include "mlx/memory.h" +#include "python/src/small_vector.h" namespace mx = mlx::core; namespace nb = nanobind; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 5e00d3073c..af64d9dfcc 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -16,6 +16,7 @@ #include "mlx/ops.h" #include "mlx/utils.h" #include "python/src/load.h" +#include "python/src/small_vector.h" #include "python/src/utils.h" namespace mx = mlx::core; diff --git a/python/src/random.cpp b/python/src/random.cpp index 837f91616a..ebd09863a5 100644 --- a/python/src/random.cpp +++ b/python/src/random.cpp @@ -7,10 +7,10 @@ #include -#include "python/src/utils.h" - #include "mlx/ops.h" #include "mlx/random.h" +#include "python/src/small_vector.h" +#include "python/src/utils.h" namespace mx = mlx::core; namespace nb = nanobind; diff --git a/python/src/small_vector.h b/python/src/small_vector.h new file mode 100644 index 0000000000..c7b2687215 --- /dev/null +++ b/python/src/small_vector.h @@ -0,0 +1,77 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/small_vector.h" + +#include + +NAMESPACE_BEGIN(NB_NAMESPACE) +NAMESPACE_BEGIN(detail) + +template +struct type_caster> { + using List = mlx::core::SmallVector; + using Caster = make_caster; + + NB_TYPE_CASTER( + List, + const_name(NB_TYPING_TUPLE "[") + make_caster::Name + + const_name(", ...]")) + + bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept { + size_t size; + PyObject* temp; + + // Will initialize 'size' and 'temp'. All return values and + // return parameters are zero/NULL in the case of a failure. + PyObject** o = seq_get(src.ptr(), &size, &temp); + + value.clear(); + value.reserve(size); + + Caster caster; + bool success = o != nullptr; + + flags = flags_for_local_caster(flags); + + for (size_t i = 0; i < size; ++i) { + if (!caster.from_python(o[i], flags, cleanup) || + !caster.template can_cast()) { + success = false; + break; + } + + value.push_back(caster.operator cast_t()); + } + + Py_XDECREF(temp); + + return success; + } + + template + static handle from_cpp(T&& src, rv_policy policy, cleanup_list* cleanup) { + object ret = steal(PyTuple_New(src.size())); + + if (ret.is_valid()) { + Py_ssize_t index = 0; + + for (auto&& value : src) { + handle h = Caster::from_cpp(forward_like_(value), policy, cleanup); + + if (!h.is_valid()) { + ret.reset(); + break; + } + + NB_TUPLE_SET_ITEM(ret.ptr(), index++, h.ptr()); + } + } + + return ret.release(); + } +}; + +NAMESPACE_END(detail) +NAMESPACE_END(NB_NAMESPACE) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 2506f50b06..d88bc5f190 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -20,6 +20,7 @@ #include "mlx/transforms_impl.h" #include "mlx/utils.h" #include "python/src/mlx_func.h" +#include "python/src/small_vector.h" #include "python/src/trees.h" namespace mx = mlx::core;