Use SmallVector for shapes and strides (#2454)

* Use SmallVector for shapes and strides

* Convert SmallVector to tuple
Authored by Cheng on 2025-08-05 09:41:03 +09:00, committed by GitHub
parent 7d86a5c108
commit 828c5f1137
30 changed files with 738 additions and 102 deletions

@ -10,6 +10,7 @@
#include "mlx/allocator.h"
#include "mlx/dtype.h"
#include "mlx/event.h"
#include "mlx/small_vector.h"
namespace mlx::core {
@ -18,8 +19,8 @@ class Primitive;
using Deleter = std::function<void(allocator::Buffer)>;
using ShapeElem = int32_t;
using Shape = std::vector<ShapeElem>;
using Strides = std::vector<int64_t>;
using Shape = SmallVector<ShapeElem>;
using Strides = SmallVector<int64_t>;
class array {
/* An array is really a node in a graph. It contains a shared ArrayDesc
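
The alias swap above is the heart of the change: shapes and strides rarely exceed a handful of dimensions, so SmallVector keeps them in inline storage and skips the heap allocation std::vector always pays for. A minimal sketch of the effect (illustration only, not part of the commit; it assumes the default inline capacity of 10 defined in mlx/small_vector.h later in this diff):

#include "mlx/array.h"
using namespace mlx::core;

void shape_sketch() {
  Shape shape = {2, 3, 4};  // three elements fit in the inline buffer: no heap
  shape.push_back(5);       // still inline (4 <= 10)
  Strides strides(16, 1);   // 16 > 10: overflows to heap storage via grow()
}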

@ -200,7 +200,7 @@ void shared_buffer_reshape(
array swapaxes_in_eval(const array& x, int axis1, int axis2);
template <typename T>
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
vec.erase(std::next(vec.begin(), index));
return vec;
}

@ -288,6 +288,14 @@ void Compiled::eval_cpu(
auto [contiguous, shape, strides] =
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
// Force allocating shape/strides on heap so we can take their data() first
// and then std::move them.
// TODO: Refactor code to avoid heap allocation.
shape.grow();
for (auto& s : strides) {
s.grow();
}
// Collect function input arguments.
std::vector<void*> args;
int strides_index = 1;
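
The grow() calls above work around an inline-storage pitfall: data() on a small SmallVector points into the object itself, so the pointer dangles as soon as the vector is moved. Forcing heap storage first makes the pointer survive the move. A sketch of both behaviors (illustration only, based on the SmallVector semantics documented later in this diff):

#include <utility>
#include "mlx/small_vector.h"
using mlx::core::SmallVector;

void move_sketch() {
  SmallVector<int> v = {1, 2, 3};
  int* p = v.data();                  // points into v's inline buffer
  SmallVector<int> w = std::move(v);  // elements move into w's inline buffer
  // p now dangles.

  SmallVector<int> u = {1, 2, 3};
  u.grow();                           // force heap storage, as the code above does
  int* q = u.data();                  // heap address
  SmallVector<int> z = std::move(u);  // the heap buffer transfers wholesale
  // q stays valid for z's lifetime.
}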

@ -8,7 +8,7 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
namespace mlx::core {
@ -333,47 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
int axis = axis_;
if (axis < 0) {
axis += in.ndim();
}
// Copy input to output
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
? CopyType::Vector
: CopyType::General;
copy_cpu(in, out, ctype, stream());
auto& encoder = cpu::get_command_encoder(stream());
encoder.set_output_array(out);
encoder.dispatch(
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
switch (out.dtype()) {
case bool_:
return sort<bool>(out, axis_);
case uint8:
return sort<uint8_t>(out, axis_);
case uint16:
return sort<uint16_t>(out, axis_);
case uint32:
return sort<uint32_t>(out, axis_);
case uint64:
return sort<uint64_t>(out, axis_);
case int8:
return sort<int8_t>(out, axis_);
case int16:
return sort<int16_t>(out, axis_);
case int32:
return sort<int32_t>(out, axis_);
case int64:
return sort<int64_t>(out, axis_);
case float32:
return sort<float>(out, axis_);
case float64:
return sort<double>(out, axis_);
case float16:
return sort<float16_t>(out, axis_);
case bfloat16:
return sort<bfloat16_t>(out, axis_);
case complex64:
return sort<complex64_t>(out, axis_);
}
});
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
dispatch_all_types(out.dtype(), [&](auto type_tag) {
sort<MLX_GET_TYPE(type_tag)>(out, axis);
});
});
}
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
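
The rewrite above folds the 40-line per-dtype switch into dispatch_all_types from mlx/dtype_utils.h (hence the new include at the top of this file), which invokes a generic lambda with a tag object carrying the C++ type; MLX_GET_TYPE recovers the type from the tag. A self-contained sketch of the pattern, with hypothetical names since the real dispatcher and macro may differ in detail:

#include <cstdint>
#include <type_traits>

enum class SketchDtype { int32, float32 };

template <typename T>
struct TypeTag { using type = T; };  // tag object carrying a C++ type

#define SKETCH_GET_TYPE(tag) typename std::decay_t<decltype(tag)>::type

template <typename F>
void sketch_dispatch_all_types(SketchDtype dt, F&& f) {
  switch (dt) {
    case SketchDtype::int32:   f(TypeTag<int32_t>{}); break;
    case SketchDtype::float32: f(TypeTag<float>{});   break;
  }
}

// Usage mirroring the new Sort::eval_cpu body:
//   sketch_dispatch_all_types(dt, [&](auto type_tag) {
//     sort<SKETCH_GET_TYPE(type_tag)>(out, axis);
//   });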

@ -55,13 +55,13 @@ auto& conv_cache() {
return cache;
}
template <typename T, typename U>
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
return std::vector<T>(vec.begin(), vec.end());
template <typename T, typename Vec>
inline SmallVector<T> convert_vector(const Vec& vec) {
return SmallVector<T>(vec.begin(), vec.end());
}
template <typename T>
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
template <typename T, template <typename U> class Vec>
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
if (vec.size() > MAX_NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
@ -78,7 +78,7 @@ auto nhwc_to_nchw(const array& x) {
auto strides = convert_vector<int64_t>(x.strides());
strides.insert(strides.begin() + 1, strides.back());
strides.erase(strides.end() - 1);
return std::make_tuple(shape, strides);
return std::make_tuple(std::move(shape), std::move(strides));
}
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
@ -298,10 +298,10 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
array& x,
array& w,
array& y,
const std::vector<int64_t>& stride,
const std::vector<int64_t>& padding_lo,
const std::vector<int64_t>& padding_hi,
const std::vector<int64_t>& dilation) {
const SmallVector<int64_t>& stride,
const SmallVector<int64_t>& padding_lo,
const SmallVector<int64_t>& padding_hi,
const SmallVector<int64_t>& dilation) {
try {
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
? CUDNN_DATA_FLOAT
@ -468,7 +468,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
// There is no reliable way to deduce the proper cuDNN backend for the
// convolution, so we make a best guess and then try.
std::vector<cudnnBackendDescriptorType_t> try_backends;
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
if (flip_) {
// When weight is flipped, we assume it is backward input convolution.
try_backends.push_back(CONV_BACKWARD_INPUT);
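
Note the second template argument in the new try_backends declaration: it overrides SmallVector's default inline capacity of 10. At most two backend candidates are ever pushed here, so a capacity of 2 keeps the stack object small without ever touching the heap. Illustration only, not from the commit:

#include "mlx/small_vector.h"
using mlx::core::SmallVector;

void capacity_sketch() {
  SmallVector<int, 2> candidates;
  candidates.push_back(1);  // inline
  candidates.push_back(2);  // inline, buffer now full
  candidates.push_back(3);  // overflows to the heap via grow()
}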

@ -29,12 +29,12 @@ void append_indices_arg(
const std::vector<array>& inputs,
int nidx,
int idx_ndim) {
std::vector<const void*> indices(nidx);
SmallVector<const void*> indices(nidx);
for (int i = 0; i < nidx; ++i) {
indices[i] = inputs[i + 1].data<void>();
}
args.append(std::move(indices));
std::vector<int32_t> indices_shape(nidx * idx_ndim);
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].shape().begin(),
@ -42,7 +42,7 @@ void append_indices_arg(
indices_shape.data() + i * idx_ndim);
}
args.append(std::move(indices_shape));
std::vector<int64_t> indices_strides(nidx * idx_ndim);
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
for (int i = 0; i < nidx; ++i) {
std::copy_n(
inputs[i + 1].strides().begin(),
@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append<int32_t>(src.ndim());
args.append_ndim(slice_sizes_);
args.append(slice_size);
args.append(axes_);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(
@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
args.append_ndim(out.shape());
args.append_ndim(out.strides());
args.append<int32_t>(out.ndim());
args.append(axes_);
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
append_indices_arg(args, inputs, nidx, idx_ndim);
std::string kernel_name = fmt::format(

@ -40,19 +40,14 @@ struct KernelArgs {
}
template <typename T>
void append(std::vector<T> vec) {
if (vec.empty()) {
// The nullptr can not be used as arg, pass something not null.
append(std::monostate{});
} else {
append_ptr(vec.data());
storage_.emplace_back(std::move(vec));
}
void append(SmallVector<T> vec) {
storage_.emplace_back(std::move(vec));
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
}
// Make sure the arg is copied to an array with size of NDIM.
template <size_t NDIM = MAX_NDIM, typename T>
void append_ndim(std::vector<T> vec) {
void append_ndim(SmallVector<T> vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));
@ -76,9 +71,9 @@ struct KernelArgs {
int32_t,
uint32_t,
int64_t,
std::vector<const void*>,
std::vector<int32_t>,
std::vector<int64_t>>;
SmallVector<const void*>,
SmallVector<int32_t>,
SmallVector<int64_t>>;
std::deque<Arg> storage_;
};
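
The deleted std::monostate branch existed because data() on an empty std::vector may return nullptr, which cannot be passed as a kernel argument. SmallVector documents that data() always returns a real address (the inline buffer), so every vector can now be appended uniformly; storage_ keeps each vector alive so the pointer recorded by append_ptr stays valid. A one-line sketch of the guarantee this relies on:

#include <cassert>
#include <cstdint>
#include "mlx/small_vector.h"

void empty_data_sketch() {
  mlx::core::SmallVector<int32_t> empty;
  assert(empty.data() != nullptr);  // inline-storage address even when empty
}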

@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
// Utility to copy data from vector to array in host.
template <int NDIM = MAX_NDIM, typename T = int32_t>
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
if (vec.size() > NDIM) {
throw std::runtime_error(
fmt::format("ndim can not be larger than {}.", NDIM));

@ -28,15 +28,20 @@ inline array ensure_row_contiguous_matrix(
const array& x,
cu::CommandEncoder& enc,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
} else {
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
}
array x_copy = contiguous_copy_gpu(x, s);
enc.add_temporary(x_copy);
return x_copy;
}
} // namespace
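
A note on the restructuring above: the previous version unconditionally read x.strides()[x.ndim() - 2], which indexes out of bounds for a 1-D input. The rewrite guards on ndim() < 2 and, for vectors, takes the fast path only when the single stride is 1; every other case falls through to the contiguous copy. The Metal version of this helper below receives the identical fix.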

@ -60,6 +60,16 @@ struct CommandEncoder {
enc_->updateFence(fence);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
}
template <typename T>
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
return set_vector_bytes(vec, vec.size(), idx);
}
// TODO: Code is duplicated but they should be deleted soon.
template <typename T>
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);

@ -32,15 +32,20 @@ inline array ensure_row_contiguous_matrix(
const array& x,
metal::Device& d,
const Stream& s) {
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
if (x.ndim() < 2) {
if (x.strides()[0] == 1) {
return x;
}
} else {
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
auto stride_0 = x.strides()[x.ndim() - 2];
auto stride_1 = x.strides()[x.ndim() - 1];
if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x;
}
}
array x_copy = contiguous_copy_gpu(x, s);
d.add_temporary(x_copy, s.index);
return x_copy;
}
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {

@ -179,7 +179,7 @@ void serialize(Writer& os, const array& arr) {
}
template <>
array deserialize(Reader& is) {
auto shape = deserialize<std::vector<int>>(is);
auto shape = deserialize<Shape>(is);
auto type = deserialize<Dtype>(is);
return array(std::move(shape), type, nullptr, std::vector<array>{});
}
@ -640,7 +640,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
auto outputs = arr.outputs();
serialize(os, arrays_to_ids(outputs));
std::vector<std::vector<int>> shapes;
std::vector<Shape> shapes;
std::vector<Dtype> dtypes;
for (auto& o : outputs) {
shapes.push_back(o.shape());
@ -813,14 +813,14 @@ ImportedFunction::ImportedFunction(const std::string& file)
std::shared_ptr<Primitive> prim = factory.load(is);
auto num_siblings = deserialize<uint64_t>(is);
if (num_siblings == 0) {
auto shape = deserialize<std::vector<int>>(is);
auto shape = deserialize<Shape>(is);
auto type = deserialize<Dtype>(is);
tape.emplace_back(
std::move(shape), type, std::move(prim), std::move(inputs));
array_map.emplace(id, tape.back());
} else {
auto ids = deserialize<std::vector<uint64_t>>(is);
auto shapes = deserialize<std::vector<std::vector<int>>>(is);
auto shapes = deserialize<std::vector<Shape>>(is);
auto types = deserialize<std::vector<Dtype>>(is);
auto arrays = array::make_arrays(
std::move(shapes),
@ -841,7 +841,7 @@ ImportedFunction::ImportedFunction(const std::string& file)
if (auto it = constants.find(id); it != constants.end()) {
tape.push_back(it->second);
} else {
auto shape = deserialize<std::vector<int>>(is);
auto shape = deserialize<Shape>(is);
auto type = deserialize<Dtype>(is);
size_t offset = is.tell();
tape.push_back(array(

@ -3027,9 +3027,9 @@ array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {
}
int ndim = std::max(a.ndim(), b.ndim());
std::vector<int> a_shape(2 * ndim, 1);
std::vector<int> b_shape(2 * ndim, 1);
std::vector<int> out_shape(ndim, 1);
Shape a_shape(2 * ndim, 1);
Shape b_shape(2 * ndim, 1);
Shape out_shape(ndim, 1);
for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {
a_shape[2 * i] = a.shape(j);

@ -1205,8 +1205,8 @@ array conv_weight_backward_patches(
auto in_padded =
pad(in,
padded_axes,
Shape(padding_lo),
Shape(padding_hi),
Shape(padding_lo.begin(), padding_lo.end()),
Shape(padding_hi.begin(), padding_hi.end()),
array(0, in.dtype()),
"constant",
s);

@ -592,7 +592,7 @@ class Broadcast : public UnaryPrimitive {
static Shape output_shape(const std::vector<array>& inputs);
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
Shape state() const {
return shape_;
};
@ -1146,7 +1146,7 @@ class Gather : public UnaryPrimitive {
DEFINE_NAME(Gather)
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
std::pair<std::vector<int>, std::vector<int>> state() const {
std::pair<std::vector<int>, Shape> state() const {
return {axes_, slice_sizes_};
}
@ -1668,7 +1668,7 @@ class RandomBits : public UnaryPrimitive {
DEFINE_VMAP()
DEFINE_NAME(RandomBits)
bool is_equivalent(const Primitive& other) const override;
std::pair<std::vector<int>, int> state() const {
std::pair<Shape, int> state() const {
return {shape_, width_};
};
@ -1703,7 +1703,7 @@ class Reshape : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_NAME(Reshape)
bool is_equivalent(const Primitive& other) const override;
std::vector<int> state() const {
Shape state() const {
return shape_;
};
static Shape output_shape(const array& input, Shape shape);
@ -2121,7 +2121,7 @@ class Split : public Primitive {
DEFINE_GRADS()
DEFINE_NAME(Split)
bool is_equivalent(const Primitive& other) const override;
std::pair<std::vector<int>, int> state() const {
std::pair<Shape, int> state() const {
return {indices_, axis_};
};

mlx/small_vector.h (new file, 528 lines)

@ -0,0 +1,528 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2018 the V8 project authors.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <algorithm>
#include <cassert>
#include <type_traits>
#include <utility>
namespace mlx::core {
#if defined(__has_builtin)
#define MLX_HAS_BUILTIN(x) __has_builtin(x)
#else
#define MLX_HAS_BUILTIN(x) 0
#endif
#if defined(__has_attribute)
#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x)
#else
#define MLX_HAS_ATTRIBUTE(x) 0
#endif
#if MLX_HAS_BUILTIN(__builtin_expect)
#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1))
#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
#else
#define MLX_LIKELY(condition) (condition)
#define MLX_UNLIKELY(condition) (condition)
#endif
#if MLX_HAS_ATTRIBUTE(noinline)
#define MLX_NOINLINE __attribute__((noinline))
#else
#define MLX_NOINLINE
#endif
template <typename T, typename = void>
struct is_iterator : std::false_type {};
template <typename T>
struct is_iterator<
T,
std::void_t<
typename std::iterator_traits<T>::difference_type,
typename std::iterator_traits<T>::iterator_category,
typename std::iterator_traits<T>::pointer,
typename std::iterator_traits<T>::reference,
typename std::iterator_traits<T>::value_type>> : std::true_type {};
template <typename T>
constexpr bool is_iterator_v = is_iterator<T>::value;
// Minimal SmallVector implementation. Uses inline storage first, switches to
// dynamic storage when it overflows.
//
// Notes:
// * The default inline storage size is MAX_NDIM, as it is mainly used for
// shapes and strides, users should choose a better size for other cases.
// * The data() returns real address even for empty vector.
// * The pointer returned by data() will change after moving the vector as it
// points to the inline storage.
// * For trivial elements the storage will not be default constructed,
// i.e. SmallVector<int>(10) will not be filled with 0 by default.
template <typename T, size_t kSize = 10, typename Allocator = std::allocator<T>>
class SmallVector {
public:
using value_type = T;
using reference = T&;
using const_reference = const T&;
using iterator = T*;
using const_iterator = const T*;
using difference_type = std::ptrdiff_t;
using size_type = std::size_t;
SmallVector() = default;
explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {}
explicit SmallVector(size_t size, const Allocator& allocator = Allocator())
: allocator_(allocator) {
resize(size);
}
SmallVector(
size_t size,
const T& initial_value,
const Allocator& allocator = Allocator())
: allocator_(allocator) {
resize(size, initial_value);
}
SmallVector(
std::initializer_list<T> init,
const Allocator& allocator = Allocator())
: allocator_(allocator) {
if (init.size() > capacity()) {
grow(init.size());
}
assert(capacity() >= init.size()); // sanity check
std::uninitialized_move(init.begin(), init.end(), begin_);
end_ = begin_ + init.size();
}
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
: allocator_(allocator) {
size_t size = std::distance(begin, end);
if (size > capacity()) {
grow(size);
}
assert(capacity() >= size); // sanity check
std::uninitialized_copy(begin, end, begin_);
end_ = begin_ + size;
}
SmallVector(const SmallVector& other) : allocator_(other.allocator_) {
*this = other;
}
SmallVector(const SmallVector& other, const Allocator& allocator)
: allocator_(allocator) {
*this = other;
}
SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) {
*this = std::move(other);
}
SmallVector(SmallVector&& other, const Allocator& allocator)
: allocator_(allocator) {
*this = std::move(other);
}
~SmallVector() {
free_storage();
}
SmallVector& operator=(const SmallVector& other) {
if (this == &other) {
return *this;
}
size_t other_size = other.size();
if (capacity() < other_size) {
// Create large-enough heap-allocated storage.
free_storage();
begin_ = allocator_.allocate(other_size);
end_of_storage_ = begin_ + other_size;
std::uninitialized_copy(other.begin_, other.end_, begin_);
} else if constexpr (kHasTrivialElement) {
std::copy(other.begin_, other.end_, begin_);
} else {
ptrdiff_t to_copy =
std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);
std::copy(other.begin_, other.begin_ + to_copy, begin_);
if (other.begin_ + to_copy < other.end_) {
std::uninitialized_copy(
other.begin_ + to_copy, other.end_, begin_ + to_copy);
} else {
std::destroy_n(begin_ + to_copy, size() - to_copy);
}
}
end_ = begin_ + other_size;
return *this;
}
SmallVector& operator=(SmallVector&& other) {
if (this == &other) {
return *this;
}
if (other.is_big()) {
free_storage();
begin_ = other.begin_;
end_ = other.end_;
end_of_storage_ = other.end_of_storage_;
} else {
assert(capacity() >= other.size()); // sanity check
size_t other_size = other.size();
if constexpr (kHasTrivialElement) {
std::move(other.begin_, other.end_, begin_);
} else {
ptrdiff_t to_move =
std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);
std::move(other.begin_, other.begin_ + to_move, begin_);
if (other.begin_ + to_move < other.end_) {
std::uninitialized_move(
other.begin_ + to_move, other.end_, begin_ + to_move);
} else {
std::destroy_n(begin_ + to_move, size() - to_move);
}
}
end_ = begin_ + other_size;
}
other.reset_to_inline_storage();
return *this;
}
bool operator==(const SmallVector& other) const {
if (size() != other.size()) {
return false;
}
return std::equal(begin_, end_, other.begin_);
}
bool operator!=(const SmallVector& other) const {
return !(*this == other);
}
T* data() {
return begin_;
}
const T* data() const {
return begin_;
}
iterator begin() {
return begin_;
}
const_iterator begin() const {
return begin_;
}
iterator end() {
return end_;
}
const_iterator end() const {
return end_;
}
const_iterator cbegin() const {
return begin_;
}
const_iterator cend() const {
return end_;
}
auto rbegin() {
return std::make_reverse_iterator(end_);
}
auto rbegin() const {
return std::make_reverse_iterator(end_);
}
auto rend() {
return std::make_reverse_iterator(begin_);
}
auto rend() const {
return std::make_reverse_iterator(begin_);
}
size_t size() const {
return end_ - begin_;
}
bool empty() const {
return end_ == begin_;
}
size_t capacity() const {
return end_of_storage_ - begin_;
}
T& front() {
assert(size() != 0);
return begin_[0];
}
const T& front() const {
assert(size() != 0);
return begin_[0];
}
T& back() {
assert(size() != 0);
return end_[-1];
}
const T& back() const {
assert(size() != 0);
return end_[-1];
}
T& at(size_t index) {
if (index >= size()) {
throw std::out_of_range("SmallVector out of range.");
}
return begin_[index];
}
const T& at(size_t index) const {
return const_cast<SmallVector*>(this)->at(index);
}
T& operator[](size_t index) {
assert(size() > index);
return begin_[index];
}
const T& operator[](size_t index) const {
return const_cast<SmallVector*>(this)->operator[](index);
}
template <typename... Args>
void emplace_back(Args&&... args) {
if (MLX_UNLIKELY(end_ == end_of_storage_)) {
grow();
}
void* storage = end_;
end_ += 1;
new (storage) T(std::forward<Args>(args)...);
}
void push_back(T x) {
emplace_back(std::move(x));
}
void pop_back(size_t count = 1) {
assert(size() >= count);
end_ -= count;
std::destroy_n(end_, count);
}
iterator insert(iterator pos, T value) {
return insert(pos, static_cast<size_t>(1), std::move(value));
}
iterator insert(iterator pos, size_t count, T value) {
assert(pos <= end_);
size_t offset = pos - begin_;
size_t old_size = size();
resize(old_size + count);
pos = begin_ + offset;
iterator old_end = begin_ + old_size;
assert(old_end <= end_);
std::move_backward(pos, old_end, end_);
if constexpr (kHasTrivialElement) {
std::fill_n(pos, count, value);
} else {
std::fill_n(pos + 1, count - 1, value);
*pos = std::move(value);
}
return pos;
}
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
iterator insert(iterator pos, Iter begin, Iter end) {
if constexpr (std::is_same_v<std::decay_t<Iter>, iterator>) {
// The implementation can not take overlapping range.
assert(!(begin >= pos && begin < pos + std::distance(begin, end)));
assert(!(end > pos && end <= pos + std::distance(begin, end)));
}
assert(pos <= end_);
size_t offset = pos - begin_;
size_t count = std::distance(begin, end);
size_t old_size = size();
resize(old_size + count);
pos = begin_ + offset;
iterator old_end = begin_ + old_size;
assert(old_end <= end_);
std::move_backward(pos, old_end, end_);
std::copy(begin, end, pos);
return pos;
}
iterator insert(iterator pos, std::initializer_list<const T> values) {
return insert(pos, values.begin(), values.end());
}
iterator erase(iterator erase_start, iterator erase_end) {
assert(erase_start >= begin_);
assert(erase_start <= erase_end);
assert(erase_end <= end_);
iterator new_end = std::move(erase_end, end_, erase_start);
std::destroy_n(new_end, std::distance(new_end, end_));
end_ = new_end;
return erase_start;
}
iterator erase(iterator pos) {
return erase(pos, pos + 1);
}
void resize(size_t new_size) {
if (new_size > capacity()) {
grow(new_size);
}
T* new_end = begin_ + new_size;
if constexpr (!kHasTrivialElement) {
if (new_end > end_) {
std::uninitialized_default_construct(end_, new_end);
} else {
std::destroy_n(new_end, end_ - new_end);
}
}
end_ = new_end;
}
void resize(size_t new_size, const T& initial_value) {
if (new_size > capacity()) {
grow(new_size);
}
T* new_end = begin_ + new_size;
if (new_end > end_) {
std::uninitialized_fill(end_, new_end, initial_value);
} else {
std::destroy_n(new_end, end_ - new_end);
}
end_ = new_end;
}
void reserve(size_t new_capacity) {
if (new_capacity > capacity()) {
grow(new_capacity);
}
}
// Clear without reverting back to inline storage.
void clear() {
std::destroy_n(begin_, end_ - begin_);
end_ = begin_;
}
// Grows the backing store by a factor of two, and at least to {min_capacity}.
// TODO: Move to private after removing external code using this method.
MLX_NOINLINE void grow(size_t min_capacity = 0) {
size_t new_capacity = std::max(min_capacity, 2 * capacity());
// Round up to power of 2.
new_capacity--;
new_capacity |= new_capacity >> 1;
new_capacity |= new_capacity >> 2;
new_capacity |= new_capacity >> 4;
new_capacity |= new_capacity >> 8;
new_capacity |= new_capacity >> 16;
if constexpr (sizeof(size_t) == sizeof(uint64_t)) {
new_capacity |= new_capacity >> 32;
}
new_capacity++;
T* new_storage = allocator_.allocate(new_capacity);
if (new_storage == nullptr) {
throw std::bad_alloc();
}
size_t in_use = end_ - begin_;
std::uninitialized_move(begin_, end_, new_storage);
free_storage();
begin_ = new_storage;
end_ = new_storage + in_use;
end_of_storage_ = new_storage + new_capacity;
}
private:
MLX_NOINLINE void free_storage() {
std::destroy_n(begin_, end_ - begin_);
if (is_big()) {
allocator_.deallocate(begin_, end_of_storage_ - begin_);
}
}
// Clear and go back to inline storage. Dynamic storage is *not* freed. For
// internal use only.
void reset_to_inline_storage() {
if constexpr (!kHasTrivialElement) {
if (!is_big())
std::destroy_n(begin_, end_ - begin_);
}
begin_ = inline_storage_begin();
end_ = begin_;
end_of_storage_ = begin_ + kSize;
}
bool is_big() const {
return begin_ != inline_storage_begin();
}
T* inline_storage_begin() {
return reinterpret_cast<T*>(inline_storage_);
}
const T* inline_storage_begin() const {
return reinterpret_cast<const T*>(inline_storage_);
}
Allocator allocator_;
// Invariants:
// 1. The elements in the range between `begin_` (included) and `end_` (not
// included) will be initialized at all times.
// 2. All other elements outside the range, both in the inline storage and in
// the dynamic storage (if it exists), will be uninitialized at all times.
T* begin_ = inline_storage_begin();
T* end_ = begin_;
T* end_of_storage_ = begin_ + kSize;
alignas(T) char inline_storage_[sizeof(T) * kSize];
static constexpr bool kHasTrivialElement =
std::is_trivially_copyable<T>::value &&
std::is_trivially_destructible<T>::value;
};
#undef MLX_HAS_BUILTIN
#undef MLX_HAS_ATTRIBUTE
#undef MLX_LIKELY
#undef MLX_UNLIKELY
#undef MLX_NOINLINE
} // namespace mlx::core
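
To make the documented caveats concrete, a short usage sketch (illustration only, not part of the commit):

#include "mlx/small_vector.h"
using mlx::core::SmallVector;

void caveat_sketch() {
  SmallVector<int> a(4);  // trivial elements are NOT zero-initialized
  a.resize(8, 0);         // pass a fill value explicitly when zeros matter
  SmallVector<int> b;
  int* p = b.data();      // real address even though b is empty, never nullptr
  b = {1, 2, 3};          // elements land in b's inline buffer
}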

@ -259,6 +259,25 @@ std::ostream& operator<<(std::ostream& os, array a) {
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {
os << v[i] << ((i == v.size() - 1) ? "" : ",");
}
os << ")";
return os;
}
// TODO: Code is duplicated but they should be deleted soon.
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
os << "(";
for (int i = 0; i < v.size(); ++i) {

@ -100,6 +100,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
std::ostream& operator<<(std::ostream& os, const Dtype& d);
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
std::ostream& operator<<(std::ostream& os, array a);
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {

@ -15,6 +15,7 @@
#include "python/src/buffer.h"
#include "python/src/convert.h"
#include "python/src/indexing.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
#include "mlx/mlx.h"
@ -303,7 +304,7 @@ void init_array(nb::module_& m) {
R"pbdoc(The number of bytes in the array.)pbdoc")
.def_prop_ro(
"shape",
[](const mx::array& a) { return nb::tuple(nb::cast(a.shape())); },
[](const mx::array& a) { return nb::cast(a.shape()); },
R"pbdoc(
The shape of the array as a Python tuple.

@ -9,7 +9,7 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;

@ -10,6 +10,7 @@
#include "mlx/array.h"
#include "mlx/export.h"
#include "mlx/graph_utils.h"
#include "python/src/small_vector.h"
#include "python/src/trees.h"
namespace mx = mlx::core;

@ -8,10 +8,10 @@
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "python/src/utils.h"
#include "mlx/fast.h"
#include "mlx/ops.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;

@ -8,6 +8,7 @@
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "python/src/small_vector.h"
namespace mx = mlx::core;
namespace nb = nanobind;

@ -9,6 +9,7 @@
#include <nanobind/stl/vector.h>
#include "mlx/linalg.h"
#include "python/src/small_vector.h"
namespace mx = mlx::core;
namespace nb = nanobind;

@ -12,6 +12,7 @@
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;

@ -7,8 +7,10 @@
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/backend/metal/metal.h"
#include "mlx/memory.h"
#include "python/src/small_vector.h"
namespace mx = mlx::core;
namespace nb = nanobind;

@ -16,6 +16,7 @@
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;

@ -7,10 +7,10 @@
#include <chrono>
#include "python/src/utils.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;

python/src/small_vector.h (new file, 77 lines)

@ -0,0 +1,77 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/small_vector.h"
#include <nanobind/stl/detail/nb_list.h>
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename Type, size_t Size, typename Alloc>
struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
using List = mlx::core::SmallVector<Type, Size, Alloc>;
using Caster = make_caster<Type>;
NB_TYPE_CASTER(
List,
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
const_name(", ...]"))
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
size_t size;
PyObject* temp;
// Will initialize 'size' and 'temp'. All return values and
// return parameters are zero/NULL in the case of a failure.
PyObject** o = seq_get(src.ptr(), &size, &temp);
value.clear();
value.reserve(size);
Caster caster;
bool success = o != nullptr;
flags = flags_for_local_caster<Type>(flags);
for (size_t i = 0; i < size; ++i) {
if (!caster.from_python(o[i], flags, cleanup) ||
!caster.template can_cast<Type>()) {
success = false;
break;
}
value.push_back(caster.operator cast_t<Type>());
}
Py_XDECREF(temp);
return success;
}
template <typename T>
static handle from_cpp(T&& src, rv_policy policy, cleanup_list* cleanup) {
object ret = steal(PyTuple_New(src.size()));
if (ret.is_valid()) {
Py_ssize_t index = 0;
for (auto&& value : src) {
handle h = Caster::from_cpp(forward_like_<T>(value), policy, cleanup);
if (!h.is_valid()) {
ret.reset();
break;
}
NB_TUPLE_SET_ITEM(ret.ptr(), index++, h.ptr());
}
}
return ret.release();
}
};
NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)
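
This caster is what turns the C++ alias change into the "shape as tuple" behavior from the commit message: any bound function returning a SmallVector now surfaces as a Python tuple (note the NB_TYPING_TUPLE annotation above), and Python sequences convert back on the way in. A hypothetical binding to illustrate; "m" is assumed to be a nb::module_ inside an init function:

// With python/src/small_vector.h included, a.shape() (a SmallVector) casts
// straight to a Python tuple; no manual nb::tuple wrapping is needed, which
// is exactly the simplification made in the "shape" property earlier.
m.def("get_shape", [](const mx::array& a) { return a.shape(); });
// >>> get_shape(x)
// (2, 3, 4)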

@ -20,6 +20,7 @@
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include "python/src/mlx_func.h"
#include "python/src/small_vector.h"
#include "python/src/trees.h"
namespace mx = mlx::core;