mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 12:16:43 +08:00
Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides * Convert SmallVector to tuple
This commit is contained in:
parent
7d86a5c108
commit
828c5f1137
@ -10,6 +10,7 @@
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/small_vector.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -18,8 +19,8 @@ class Primitive;
|
||||
|
||||
using Deleter = std::function<void(allocator::Buffer)>;
|
||||
using ShapeElem = int32_t;
|
||||
using Shape = std::vector<ShapeElem>;
|
||||
using Strides = std::vector<int64_t>;
|
||||
using Shape = SmallVector<ShapeElem>;
|
||||
using Strides = SmallVector<int64_t>;
|
||||
|
||||
class array {
|
||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||
|
@ -200,7 +200,7 @@ void shared_buffer_reshape(
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||
vec.erase(std::next(vec.begin(), index));
|
||||
return vec;
|
||||
}
|
||||
|
@ -288,6 +288,14 @@ void Compiled::eval_cpu(
|
||||
auto [contiguous, shape, strides] =
|
||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||
|
||||
// Force allocating shape/strides on heap so we can take their data() first
|
||||
// and then std::move them.
|
||||
// TODO: Refactor code to avoid heap allocation.
|
||||
shape.grow();
|
||||
for (auto& s : strides) {
|
||||
s.grow();
|
||||
}
|
||||
|
||||
// Collect function input arguments.
|
||||
std::vector<void*> args;
|
||||
int strides_index = 1;
|
||||
|
@ -8,7 +8,7 @@
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/cpu/copy.h"
|
||||
#include "mlx/backend/cpu/encoder.h"
|
||||
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
@ -333,46 +333,23 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
|
||||
int axis = axis_;
|
||||
if (axis < 0) {
|
||||
axis += in.ndim();
|
||||
}
|
||||
|
||||
// Copy input to output
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
|
||||
? CopyType::Vector
|
||||
: CopyType::General;
|
||||
copy_cpu(in, out, ctype, stream());
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch(
|
||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
||||
switch (out.dtype()) {
|
||||
case bool_:
|
||||
return sort<bool>(out, axis_);
|
||||
case uint8:
|
||||
return sort<uint8_t>(out, axis_);
|
||||
case uint16:
|
||||
return sort<uint16_t>(out, axis_);
|
||||
case uint32:
|
||||
return sort<uint32_t>(out, axis_);
|
||||
case uint64:
|
||||
return sort<uint64_t>(out, axis_);
|
||||
case int8:
|
||||
return sort<int8_t>(out, axis_);
|
||||
case int16:
|
||||
return sort<int16_t>(out, axis_);
|
||||
case int32:
|
||||
return sort<int32_t>(out, axis_);
|
||||
case int64:
|
||||
return sort<int64_t>(out, axis_);
|
||||
case float32:
|
||||
return sort<float>(out, axis_);
|
||||
case float64:
|
||||
return sort<double>(out, axis_);
|
||||
case float16:
|
||||
return sort<float16_t>(out, axis_);
|
||||
case bfloat16:
|
||||
return sort<bfloat16_t>(out, axis_);
|
||||
case complex64:
|
||||
return sort<complex64_t>(out, axis_);
|
||||
}
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
|
||||
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||
sort<MLX_GET_TYPE(type_tag)>(out, axis);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -55,13 +55,13 @@ auto& conv_cache() {
|
||||
return cache;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
|
||||
return std::vector<T>(vec.begin(), vec.end());
|
||||
template <typename T, typename Vec>
|
||||
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||
return SmallVector<T>(vec.begin(), vec.end());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
|
||||
template <typename T, template <typename U> class Vec>
|
||||
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
|
||||
if (vec.size() > MAX_NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||
@ -78,7 +78,7 @@ auto nhwc_to_nchw(const array& x) {
|
||||
auto strides = convert_vector<int64_t>(x.strides());
|
||||
strides.insert(strides.begin() + 1, strides.back());
|
||||
strides.erase(strides.end() - 1);
|
||||
return std::make_tuple(shape, strides);
|
||||
return std::make_tuple(std::move(shape), std::move(strides));
|
||||
}
|
||||
|
||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||
@ -298,10 +298,10 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
||||
array& x,
|
||||
array& w,
|
||||
array& y,
|
||||
const std::vector<int64_t>& stride,
|
||||
const std::vector<int64_t>& padding_lo,
|
||||
const std::vector<int64_t>& padding_hi,
|
||||
const std::vector<int64_t>& dilation) {
|
||||
const SmallVector<int64_t>& stride,
|
||||
const SmallVector<int64_t>& padding_lo,
|
||||
const SmallVector<int64_t>& padding_hi,
|
||||
const SmallVector<int64_t>& dilation) {
|
||||
try {
|
||||
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||
? CUDNN_DATA_FLOAT
|
||||
@ -468,7 +468,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
||||
|
||||
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||
// convolution, so we make a best guess and then try.
|
||||
std::vector<cudnnBackendDescriptorType_t> try_backends;
|
||||
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
|
||||
if (flip_) {
|
||||
// When weight is flipped, we assume it is backward input convolution.
|
||||
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||
|
@ -29,12 +29,12 @@ void append_indices_arg(
|
||||
const std::vector<array>& inputs,
|
||||
int nidx,
|
||||
int idx_ndim) {
|
||||
std::vector<const void*> indices(nidx);
|
||||
SmallVector<const void*> indices(nidx);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
indices[i] = inputs[i + 1].data<void>();
|
||||
}
|
||||
args.append(std::move(indices));
|
||||
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].shape().begin(),
|
||||
@ -42,7 +42,7 @@ void append_indices_arg(
|
||||
indices_shape.data() + i * idx_ndim);
|
||||
}
|
||||
args.append(std::move(indices_shape));
|
||||
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
|
||||
for (int i = 0; i < nidx; ++i) {
|
||||
std::copy_n(
|
||||
inputs[i + 1].strides().begin(),
|
||||
@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
args.append<int32_t>(src.ndim());
|
||||
args.append_ndim(slice_sizes_);
|
||||
args.append(slice_size);
|
||||
args.append(axes_);
|
||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
args.append_ndim(out.shape());
|
||||
args.append_ndim(out.strides());
|
||||
args.append<int32_t>(out.ndim());
|
||||
args.append(axes_);
|
||||
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||
|
||||
std::string kernel_name = fmt::format(
|
||||
|
@ -40,19 +40,14 @@ struct KernelArgs {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append(std::vector<T> vec) {
|
||||
if (vec.empty()) {
|
||||
// The nullptr can not be used as arg, pass something not null.
|
||||
append(std::monostate{});
|
||||
} else {
|
||||
append_ptr(vec.data());
|
||||
void append(SmallVector<T> vec) {
|
||||
storage_.emplace_back(std::move(vec));
|
||||
}
|
||||
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
|
||||
}
|
||||
|
||||
// Make sure the arg is copied to an array with size of NDIM.
|
||||
template <size_t NDIM = MAX_NDIM, typename T>
|
||||
void append_ndim(std::vector<T> vec) {
|
||||
void append_ndim(SmallVector<T> vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
@ -76,9 +71,9 @@ struct KernelArgs {
|
||||
int32_t,
|
||||
uint32_t,
|
||||
int64_t,
|
||||
std::vector<const void*>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int64_t>>;
|
||||
SmallVector<const void*>,
|
||||
SmallVector<int32_t>,
|
||||
SmallVector<int64_t>>;
|
||||
std::deque<Arg> storage_;
|
||||
};
|
||||
|
||||
|
@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
|
||||
|
||||
// Utility to copy data from vector to array in host.
|
||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
||||
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
|
||||
if (vec.size() > NDIM) {
|
||||
throw std::runtime_error(
|
||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||
|
@ -28,15 +28,20 @@ inline array ensure_row_contiguous_matrix(
|
||||
const array& x,
|
||||
cu::CommandEncoder& enc,
|
||||
const Stream& s) {
|
||||
if (x.ndim() < 2) {
|
||||
if (x.strides()[0] == 1) {
|
||||
return x;
|
||||
}
|
||||
} else {
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
} else {
|
||||
}
|
||||
}
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
enc.add_temporary(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -60,6 +60,16 @@ struct CommandEncoder {
|
||||
enc_->updateFence(fence);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
template <typename T>
|
||||
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
|
||||
return set_vector_bytes(vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
// TODO: Code is duplicated but they should be deleted soon.
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
|
@ -32,15 +32,20 @@ inline array ensure_row_contiguous_matrix(
|
||||
const array& x,
|
||||
metal::Device& d,
|
||||
const Stream& s) {
|
||||
if (x.ndim() < 2) {
|
||||
if (x.strides()[0] == 1) {
|
||||
return x;
|
||||
}
|
||||
} else {
|
||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
} else {
|
||||
}
|
||||
}
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
}
|
||||
|
||||
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
|
||||
|
@ -179,7 +179,7 @@ void serialize(Writer& os, const array& arr) {
|
||||
}
|
||||
template <>
|
||||
array deserialize(Reader& is) {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto shape = deserialize<Shape>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
return array(std::move(shape), type, nullptr, std::vector<array>{});
|
||||
}
|
||||
@ -640,7 +640,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
||||
auto outputs = arr.outputs();
|
||||
serialize(os, arrays_to_ids(outputs));
|
||||
|
||||
std::vector<std::vector<int>> shapes;
|
||||
std::vector<Shape> shapes;
|
||||
std::vector<Dtype> dtypes;
|
||||
for (auto& o : outputs) {
|
||||
shapes.push_back(o.shape());
|
||||
@ -813,14 +813,14 @@ ImportedFunction::ImportedFunction(const std::string& file)
|
||||
std::shared_ptr<Primitive> prim = factory.load(is);
|
||||
auto num_siblings = deserialize<uint64_t>(is);
|
||||
if (num_siblings == 0) {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto shape = deserialize<Shape>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
tape.emplace_back(
|
||||
std::move(shape), type, std::move(prim), std::move(inputs));
|
||||
array_map.emplace(id, tape.back());
|
||||
} else {
|
||||
auto ids = deserialize<std::vector<uint64_t>>(is);
|
||||
auto shapes = deserialize<std::vector<std::vector<int>>>(is);
|
||||
auto shapes = deserialize<std::vector<Shape>>(is);
|
||||
auto types = deserialize<std::vector<Dtype>>(is);
|
||||
auto arrays = array::make_arrays(
|
||||
std::move(shapes),
|
||||
@ -841,7 +841,7 @@ ImportedFunction::ImportedFunction(const std::string& file)
|
||||
if (auto it = constants.find(id); it != constants.end()) {
|
||||
tape.push_back(it->second);
|
||||
} else {
|
||||
auto shape = deserialize<std::vector<int>>(is);
|
||||
auto shape = deserialize<Shape>(is);
|
||||
auto type = deserialize<Dtype>(is);
|
||||
size_t offset = is.tell();
|
||||
tape.push_back(array(
|
||||
|
@ -3027,9 +3027,9 @@ array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
int ndim = std::max(a.ndim(), b.ndim());
|
||||
std::vector<int> a_shape(2 * ndim, 1);
|
||||
std::vector<int> b_shape(2 * ndim, 1);
|
||||
std::vector<int> out_shape(ndim, 1);
|
||||
Shape a_shape(2 * ndim, 1);
|
||||
Shape b_shape(2 * ndim, 1);
|
||||
Shape out_shape(ndim, 1);
|
||||
|
||||
for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {
|
||||
a_shape[2 * i] = a.shape(j);
|
||||
|
@ -1205,8 +1205,8 @@ array conv_weight_backward_patches(
|
||||
auto in_padded =
|
||||
pad(in,
|
||||
padded_axes,
|
||||
Shape(padding_lo),
|
||||
Shape(padding_hi),
|
||||
Shape(padding_lo.begin(), padding_lo.end()),
|
||||
Shape(padding_hi.begin(), padding_hi.end()),
|
||||
array(0, in.dtype()),
|
||||
"constant",
|
||||
s);
|
||||
|
@ -592,7 +592,7 @@ class Broadcast : public UnaryPrimitive {
|
||||
static Shape output_shape(const std::vector<array>& inputs);
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
Shape state() const {
|
||||
return shape_;
|
||||
};
|
||||
|
||||
@ -1146,7 +1146,7 @@ class Gather : public UnaryPrimitive {
|
||||
DEFINE_NAME(Gather)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||
std::pair<std::vector<int>, std::vector<int>> state() const {
|
||||
std::pair<std::vector<int>, Shape> state() const {
|
||||
return {axes_, slice_sizes_};
|
||||
}
|
||||
|
||||
@ -1668,7 +1668,7 @@ class RandomBits : public UnaryPrimitive {
|
||||
DEFINE_VMAP()
|
||||
DEFINE_NAME(RandomBits)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<std::vector<int>, int> state() const {
|
||||
std::pair<Shape, int> state() const {
|
||||
return {shape_, width_};
|
||||
};
|
||||
|
||||
@ -1703,7 +1703,7 @@ class Reshape : public UnaryPrimitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_NAME(Reshape)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::vector<int> state() const {
|
||||
Shape state() const {
|
||||
return shape_;
|
||||
};
|
||||
static Shape output_shape(const array& input, Shape shape);
|
||||
@ -2121,7 +2121,7 @@ class Split : public Primitive {
|
||||
DEFINE_GRADS()
|
||||
DEFINE_NAME(Split)
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
std::pair<std::vector<int>, int> state() const {
|
||||
std::pair<Shape, int> state() const {
|
||||
return {indices_, axis_};
|
||||
};
|
||||
|
||||
|
528
mlx/small_vector.h
Normal file
528
mlx/small_vector.h
Normal file
@ -0,0 +1,528 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
// Copyright © 2018 the V8 project authors.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following
|
||||
// disclaimer in the documentation and/or other materials provided
|
||||
// with the distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived
|
||||
// from this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
#if defined(__has_builtin)
|
||||
#define MLX_HAS_BUILTIN(x) __has_builtin(x)
|
||||
#else
|
||||
#define MLX_HAS_BUILTIN(x) 0
|
||||
#endif
|
||||
|
||||
#if defined(__has_attribute)
|
||||
#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x)
|
||||
#else
|
||||
#define MLX_HAS_ATTRIBUTE(x) 0
|
||||
#endif
|
||||
|
||||
#if MLX_HAS_BUILTIN(__builtin_expect)
|
||||
#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1))
|
||||
#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
|
||||
#else
|
||||
#define MLX_LIKELY(condition) (condition)
|
||||
#define MLX_UNLIKELY(condition) (condition)
|
||||
#endif
|
||||
|
||||
#if MLX_HAS_ATTRIBUTE(noinline)
|
||||
#define MLX_NOINLINE __attribute__((noinline))
|
||||
#else
|
||||
#define MLX_NOINLINE
|
||||
#endif
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct is_iterator : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct is_iterator<
|
||||
T,
|
||||
std::void_t<
|
||||
typename std::iterator_traits<T>::difference_type,
|
||||
typename std::iterator_traits<T>::iterator_category,
|
||||
typename std::iterator_traits<T>::pointer,
|
||||
typename std::iterator_traits<T>::reference,
|
||||
typename std::iterator_traits<T>::value_type>> : std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
constexpr bool is_iterator_v = is_iterator<T>::value;
|
||||
|
||||
// Minimal SmallVector implementation. Uses inline storage first, switches to
|
||||
// dynamic storage when it overflows.
|
||||
//
|
||||
// Notes:
|
||||
// * The default inline storage size is MAX_NDIM, as it is mainly used for
|
||||
// shapes and strides, users should choose a better size for other cases.
|
||||
// * The data() returns real address even for empty vector.
|
||||
// * The pointer returned by data() will change after moving the vector as it
|
||||
// points to the inline storage.
|
||||
// * For trivial elements the storage will not be default constructed,
|
||||
// i.e. SmallVector<int>(10) will not be filled with 0 by default.
|
||||
template <typename T, size_t kSize = 10, typename Allocator = std::allocator<T>>
|
||||
class SmallVector {
|
||||
public:
|
||||
using value_type = T;
|
||||
using reference = T&;
|
||||
using const_reference = const T&;
|
||||
using iterator = T*;
|
||||
using const_iterator = const T*;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using size_type = std::size_t;
|
||||
|
||||
SmallVector() = default;
|
||||
|
||||
explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {}
|
||||
|
||||
explicit SmallVector(size_t size, const Allocator& allocator = Allocator())
|
||||
: allocator_(allocator) {
|
||||
resize(size);
|
||||
}
|
||||
|
||||
SmallVector(
|
||||
size_t size,
|
||||
const T& initial_value,
|
||||
const Allocator& allocator = Allocator())
|
||||
: allocator_(allocator) {
|
||||
resize(size, initial_value);
|
||||
}
|
||||
|
||||
SmallVector(
|
||||
std::initializer_list<T> init,
|
||||
const Allocator& allocator = Allocator())
|
||||
: allocator_(allocator) {
|
||||
if (init.size() > capacity()) {
|
||||
grow(init.size());
|
||||
}
|
||||
assert(capacity() >= init.size()); // sanity check
|
||||
std::uninitialized_move(init.begin(), init.end(), begin_);
|
||||
end_ = begin_ + init.size();
|
||||
}
|
||||
|
||||
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
|
||||
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
|
||||
: allocator_(allocator) {
|
||||
size_t size = std::distance(begin, end);
|
||||
if (size > capacity()) {
|
||||
grow(size);
|
||||
}
|
||||
assert(capacity() >= size); // sanity check
|
||||
std::uninitialized_copy(begin, end, begin_);
|
||||
end_ = begin_ + size;
|
||||
}
|
||||
|
||||
SmallVector(const SmallVector& other) : allocator_(other.allocator_) {
|
||||
*this = other;
|
||||
}
|
||||
SmallVector(const SmallVector& other, const Allocator& allocator)
|
||||
: allocator_(allocator) {
|
||||
*this = other;
|
||||
}
|
||||
SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) {
|
||||
*this = std::move(other);
|
||||
}
|
||||
SmallVector(SmallVector&& other, const Allocator& allocator)
|
||||
: allocator_(allocator) {
|
||||
*this = std::move(other);
|
||||
}
|
||||
|
||||
~SmallVector() {
|
||||
free_storage();
|
||||
}
|
||||
|
||||
SmallVector& operator=(const SmallVector& other) {
|
||||
if (this == &other) {
|
||||
return *this;
|
||||
}
|
||||
size_t other_size = other.size();
|
||||
if (capacity() < other_size) {
|
||||
// Create large-enough heap-allocated storage.
|
||||
free_storage();
|
||||
begin_ = allocator_.allocate(other_size);
|
||||
end_of_storage_ = begin_ + other_size;
|
||||
std::uninitialized_copy(other.begin_, other.end_, begin_);
|
||||
} else if constexpr (kHasTrivialElement) {
|
||||
std::copy(other.begin_, other.end_, begin_);
|
||||
} else {
|
||||
ptrdiff_t to_copy =
|
||||
std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);
|
||||
std::copy(other.begin_, other.begin_ + to_copy, begin_);
|
||||
if (other.begin_ + to_copy < other.end_) {
|
||||
std::uninitialized_copy(
|
||||
other.begin_ + to_copy, other.end_, begin_ + to_copy);
|
||||
} else {
|
||||
std::destroy_n(begin_ + to_copy, size() - to_copy);
|
||||
}
|
||||
}
|
||||
end_ = begin_ + other_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
SmallVector& operator=(SmallVector&& other) {
|
||||
if (this == &other) {
|
||||
return *this;
|
||||
}
|
||||
if (other.is_big()) {
|
||||
free_storage();
|
||||
begin_ = other.begin_;
|
||||
end_ = other.end_;
|
||||
end_of_storage_ = other.end_of_storage_;
|
||||
} else {
|
||||
assert(capacity() >= other.size()); // sanity check
|
||||
size_t other_size = other.size();
|
||||
if constexpr (kHasTrivialElement) {
|
||||
std::move(other.begin_, other.end_, begin_);
|
||||
} else {
|
||||
ptrdiff_t to_move =
|
||||
std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);
|
||||
std::move(other.begin_, other.begin_ + to_move, begin_);
|
||||
if (other.begin_ + to_move < other.end_) {
|
||||
std::uninitialized_move(
|
||||
other.begin_ + to_move, other.end_, begin_ + to_move);
|
||||
} else {
|
||||
std::destroy_n(begin_ + to_move, size() - to_move);
|
||||
}
|
||||
}
|
||||
end_ = begin_ + other_size;
|
||||
}
|
||||
other.reset_to_inline_storage();
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const SmallVector& other) const {
|
||||
if (size() != other.size()) {
|
||||
return false;
|
||||
}
|
||||
return std::equal(begin_, end_, other.begin_);
|
||||
}
|
||||
|
||||
bool operator!=(const SmallVector& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
T* data() {
|
||||
return begin_;
|
||||
}
|
||||
const T* data() const {
|
||||
return begin_;
|
||||
}
|
||||
|
||||
iterator begin() {
|
||||
return begin_;
|
||||
}
|
||||
const_iterator begin() const {
|
||||
return begin_;
|
||||
}
|
||||
|
||||
iterator end() {
|
||||
return end_;
|
||||
}
|
||||
const_iterator end() const {
|
||||
return end_;
|
||||
}
|
||||
|
||||
const_iterator cbegin() const {
|
||||
return begin_;
|
||||
}
|
||||
|
||||
const_iterator cend() const {
|
||||
return end_;
|
||||
}
|
||||
|
||||
auto rbegin() {
|
||||
return std::make_reverse_iterator(end_);
|
||||
}
|
||||
auto rbegin() const {
|
||||
return std::make_reverse_iterator(end_);
|
||||
}
|
||||
|
||||
auto rend() {
|
||||
return std::make_reverse_iterator(begin_);
|
||||
}
|
||||
auto rend() const {
|
||||
return std::make_reverse_iterator(begin_);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return end_ - begin_;
|
||||
}
|
||||
bool empty() const {
|
||||
return end_ == begin_;
|
||||
}
|
||||
size_t capacity() const {
|
||||
return end_of_storage_ - begin_;
|
||||
}
|
||||
|
||||
T& front() {
|
||||
assert(size() != 0);
|
||||
return begin_[0];
|
||||
}
|
||||
const T& front() const {
|
||||
assert(size() != 0);
|
||||
return begin_[0];
|
||||
}
|
||||
|
||||
T& back() {
|
||||
assert(size() != 0);
|
||||
return end_[-1];
|
||||
}
|
||||
const T& back() const {
|
||||
assert(size() != 0);
|
||||
return end_[-1];
|
||||
}
|
||||
|
||||
T& at(size_t index) {
|
||||
if (index >= size()) {
|
||||
throw std::out_of_range("SmallVector out of range.");
|
||||
}
|
||||
return begin_[index];
|
||||
}
|
||||
const T& at(size_t index) const {
|
||||
return const_cast<SmallVector*>(this)->at(index);
|
||||
}
|
||||
|
||||
T& operator[](size_t index) {
|
||||
assert(size() > index);
|
||||
return begin_[index];
|
||||
}
|
||||
const T& operator[](size_t index) const {
|
||||
return const_cast<SmallVector*>(this)->operator[](index);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void emplace_back(Args&&... args) {
|
||||
if (MLX_UNLIKELY(end_ == end_of_storage_)) {
|
||||
grow();
|
||||
}
|
||||
void* storage = end_;
|
||||
end_ += 1;
|
||||
new (storage) T(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
void push_back(T x) {
|
||||
emplace_back(std::move(x));
|
||||
}
|
||||
|
||||
void pop_back(size_t count = 1) {
|
||||
assert(size() >= count);
|
||||
end_ -= count;
|
||||
std::destroy_n(end_, count);
|
||||
}
|
||||
|
||||
iterator insert(iterator pos, T value) {
|
||||
return insert(pos, static_cast<size_t>(1), std::move(value));
|
||||
}
|
||||
|
||||
iterator insert(iterator pos, size_t count, T value) {
|
||||
assert(pos <= end_);
|
||||
size_t offset = pos - begin_;
|
||||
size_t old_size = size();
|
||||
resize(old_size + count);
|
||||
pos = begin_ + offset;
|
||||
iterator old_end = begin_ + old_size;
|
||||
assert(old_end <= end_);
|
||||
std::move_backward(pos, old_end, end_);
|
||||
if constexpr (kHasTrivialElement) {
|
||||
std::fill_n(pos, count, value);
|
||||
} else {
|
||||
std::fill_n(pos + 1, count - 1, value);
|
||||
*pos = std::move(value);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
|
||||
iterator insert(iterator pos, Iter begin, Iter end) {
|
||||
if constexpr (std::is_same_v<std::decay_t<Iter>, iterator>) {
|
||||
// The implementation can not take overlapping range.
|
||||
assert(!(begin >= pos && begin < pos + std::distance(begin, end)));
|
||||
assert(!(end > pos && end <= pos + std::distance(begin, end)));
|
||||
}
|
||||
|
||||
assert(pos <= end_);
|
||||
size_t offset = pos - begin_;
|
||||
size_t count = std::distance(begin, end);
|
||||
size_t old_size = size();
|
||||
resize(old_size + count);
|
||||
pos = begin_ + offset;
|
||||
iterator old_end = begin_ + old_size;
|
||||
assert(old_end <= end_);
|
||||
std::move_backward(pos, old_end, end_);
|
||||
std::copy(begin, end, pos);
|
||||
return pos;
|
||||
}
|
||||
|
||||
iterator insert(iterator pos, std::initializer_list<const T> values) {
|
||||
return insert(pos, values.begin(), values.end());
|
||||
}
|
||||
|
||||
iterator erase(iterator erase_start, iterator erase_end) {
|
||||
assert(erase_start >= begin_);
|
||||
assert(erase_start <= erase_end);
|
||||
assert(erase_end <= end_);
|
||||
iterator new_end = std::move(erase_end, end_, erase_start);
|
||||
std::destroy_n(new_end, std::distance(new_end, end_));
|
||||
end_ = new_end;
|
||||
return erase_start;
|
||||
}
|
||||
|
||||
iterator erase(iterator pos) {
|
||||
return erase(pos, pos + 1);
|
||||
}
|
||||
|
||||
void resize(size_t new_size) {
|
||||
if (new_size > capacity()) {
|
||||
grow(new_size);
|
||||
}
|
||||
T* new_end = begin_ + new_size;
|
||||
if constexpr (!kHasTrivialElement) {
|
||||
if (new_end > end_) {
|
||||
std::uninitialized_default_construct(end_, new_end);
|
||||
} else {
|
||||
std::destroy_n(new_end, end_ - new_end);
|
||||
}
|
||||
}
|
||||
end_ = new_end;
|
||||
}
|
||||
|
||||
void resize(size_t new_size, const T& initial_value) {
|
||||
if (new_size > capacity()) {
|
||||
grow(new_size);
|
||||
}
|
||||
T* new_end = begin_ + new_size;
|
||||
if (new_end > end_) {
|
||||
std::uninitialized_fill(end_, new_end, initial_value);
|
||||
} else {
|
||||
std::destroy_n(new_end, end_ - new_end);
|
||||
}
|
||||
end_ = new_end;
|
||||
}
|
||||
|
||||
void reserve(size_t new_capacity) {
|
||||
if (new_capacity > capacity()) {
|
||||
grow(new_capacity);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear without reverting back to inline storage.
|
||||
void clear() {
|
||||
std::destroy_n(begin_, end_ - begin_);
|
||||
end_ = begin_;
|
||||
}
|
||||
|
||||
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
||||
// TODO: Move to private after removing external code using this method.
|
||||
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
||||
size_t new_capacity = std::max(min_capacity, 2 * capacity());
|
||||
// Round up to power of 2.
|
||||
new_capacity--;
|
||||
new_capacity |= new_capacity >> 1;
|
||||
new_capacity |= new_capacity >> 2;
|
||||
new_capacity |= new_capacity >> 4;
|
||||
new_capacity |= new_capacity >> 8;
|
||||
new_capacity |= new_capacity >> 16;
|
||||
if constexpr (sizeof(size_t) == sizeof(uint64_t)) {
|
||||
new_capacity |= new_capacity >> 32;
|
||||
}
|
||||
new_capacity++;
|
||||
|
||||
T* new_storage = allocator_.allocate(new_capacity);
|
||||
if (new_storage == nullptr) {
|
||||
throw std::bad_alloc();
|
||||
}
|
||||
|
||||
size_t in_use = end_ - begin_;
|
||||
std::uninitialized_move(begin_, end_, new_storage);
|
||||
free_storage();
|
||||
begin_ = new_storage;
|
||||
end_ = new_storage + in_use;
|
||||
end_of_storage_ = new_storage + new_capacity;
|
||||
}
|
||||
|
||||
private:
|
||||
MLX_NOINLINE void free_storage() {
|
||||
std::destroy_n(begin_, end_ - begin_);
|
||||
if (is_big()) {
|
||||
allocator_.deallocate(begin_, end_of_storage_ - begin_);
|
||||
}
|
||||
}
|
||||
|
||||
// Clear and go back to inline storage. Dynamic storage is *not* freed. For
|
||||
// internal use only.
|
||||
void reset_to_inline_storage() {
|
||||
if constexpr (!kHasTrivialElement) {
|
||||
if (!is_big())
|
||||
std::destroy_n(begin_, end_ - begin_);
|
||||
}
|
||||
begin_ = inline_storage_begin();
|
||||
end_ = begin_;
|
||||
end_of_storage_ = begin_ + kSize;
|
||||
}
|
||||
|
||||
bool is_big() const {
|
||||
return begin_ != inline_storage_begin();
|
||||
}
|
||||
|
||||
T* inline_storage_begin() {
|
||||
return reinterpret_cast<T*>(inline_storage_);
|
||||
}
|
||||
const T* inline_storage_begin() const {
|
||||
return reinterpret_cast<const T*>(inline_storage_);
|
||||
}
|
||||
|
||||
Allocator allocator_;
|
||||
|
||||
// Invariants:
|
||||
// 1. The elements in the range between `begin_` (included) and `end_` (not
|
||||
// included) will be initialized at all times.
|
||||
// 2. All other elements outside the range, both in the inline storage and in
|
||||
// the dynamic storage (if it exists), will be uninitialized at all times.
|
||||
|
||||
T* begin_ = inline_storage_begin();
|
||||
T* end_ = begin_;
|
||||
T* end_of_storage_ = begin_ + kSize;
|
||||
|
||||
alignas(T) char inline_storage_[sizeof(T) * kSize];
|
||||
|
||||
static constexpr bool kHasTrivialElement =
|
||||
std::is_trivially_copyable<T>::value &&
|
||||
std::is_trivially_destructible<T>::value;
|
||||
};
|
||||
|
||||
#undef MLX_HAS_BUILTIN
|
||||
#undef MLX_HAS_ATTRIBUTE
|
||||
#undef MLX_LIKELY
|
||||
#undef MLX_UNLIKELY
|
||||
#undef MLX_NOINLINE
|
||||
|
||||
} // namespace mlx::core
|
@ -259,6 +259,25 @@ std::ostream& operator<<(std::ostream& os, array a) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
// TODO: Code is duplicated but they should be deleted soon.
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
|
@ -100,6 +100,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
||||
std::ostream& operator<<(std::ostream& os, array a);
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
|
||||
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "python/src/buffer.h"
|
||||
#include "python/src/convert.h"
|
||||
#include "python/src/indexing.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
@ -303,7 +304,7 @@ void init_array(nb::module_& m) {
|
||||
R"pbdoc(The number of bytes in the array.)pbdoc")
|
||||
.def_prop_ro(
|
||||
"shape",
|
||||
[](const mx::array& a) { return nb::tuple(nb::cast(a.shape())); },
|
||||
[](const mx::array& a) { return nb::cast(a.shape()); },
|
||||
R"pbdoc(
|
||||
The shape of the array as a Python tuple.
|
||||
|
||||
|
@ -9,7 +9,7 @@
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
@ -8,10 +8,10 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "python/src/small_vector.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
#include "python/src/small_vector.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
@ -7,8 +7,10 @@
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "python/src/small_vector.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
@ -16,6 +16,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
@ -7,10 +7,10 @@
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
77
python/src/small_vector.h
Normal file
77
python/src/small_vector.h
Normal file
@ -0,0 +1,77 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/small_vector.h"
|
||||
|
||||
#include <nanobind/stl/detail/nb_list.h>
|
||||
|
||||
NAMESPACE_BEGIN(NB_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename Type, size_t Size, typename Alloc>
|
||||
struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
|
||||
using List = mlx::core::SmallVector<Type, Size, Alloc>;
|
||||
using Caster = make_caster<Type>;
|
||||
|
||||
NB_TYPE_CASTER(
|
||||
List,
|
||||
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
|
||||
const_name(", ...]"))
|
||||
|
||||
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
|
||||
size_t size;
|
||||
PyObject* temp;
|
||||
|
||||
// Will initialize 'size' and 'temp'. All return values and
|
||||
// return parameters are zero/NULL in the case of a failure.
|
||||
PyObject** o = seq_get(src.ptr(), &size, &temp);
|
||||
|
||||
value.clear();
|
||||
value.reserve(size);
|
||||
|
||||
Caster caster;
|
||||
bool success = o != nullptr;
|
||||
|
||||
flags = flags_for_local_caster<Type>(flags);
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
if (!caster.from_python(o[i], flags, cleanup) ||
|
||||
!caster.template can_cast<Type>()) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
|
||||
value.push_back(caster.operator cast_t<Type>());
|
||||
}
|
||||
|
||||
Py_XDECREF(temp);
|
||||
|
||||
return success;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static handle from_cpp(T&& src, rv_policy policy, cleanup_list* cleanup) {
|
||||
object ret = steal(PyTuple_New(src.size()));
|
||||
|
||||
if (ret.is_valid()) {
|
||||
Py_ssize_t index = 0;
|
||||
|
||||
for (auto&& value : src) {
|
||||
handle h = Caster::from_cpp(forward_like_<T>(value), policy, cleanup);
|
||||
|
||||
if (!h.is_valid()) {
|
||||
ret.reset();
|
||||
break;
|
||||
}
|
||||
|
||||
NB_TUPLE_SET_ITEM(ret.ptr(), index++, h.ptr());
|
||||
}
|
||||
}
|
||||
|
||||
return ret.release();
|
||||
}
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(NB_NAMESPACE)
|
@ -20,6 +20,7 @@
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/mlx_func.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
Loading…
Reference in New Issue
Block a user