mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-12 12:16:43 +08:00
Use SmallVector for shapes and strides (#2454)
* Use SmallVector for shapes and strides * Convert SmallVector to tuple
This commit is contained in:
parent
7d86a5c108
commit
828c5f1137
@ -10,6 +10,7 @@
|
|||||||
#include "mlx/allocator.h"
|
#include "mlx/allocator.h"
|
||||||
#include "mlx/dtype.h"
|
#include "mlx/dtype.h"
|
||||||
#include "mlx/event.h"
|
#include "mlx/event.h"
|
||||||
|
#include "mlx/small_vector.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -18,8 +19,8 @@ class Primitive;
|
|||||||
|
|
||||||
using Deleter = std::function<void(allocator::Buffer)>;
|
using Deleter = std::function<void(allocator::Buffer)>;
|
||||||
using ShapeElem = int32_t;
|
using ShapeElem = int32_t;
|
||||||
using Shape = std::vector<ShapeElem>;
|
using Shape = SmallVector<ShapeElem>;
|
||||||
using Strides = std::vector<int64_t>;
|
using Strides = SmallVector<int64_t>;
|
||||||
|
|
||||||
class array {
|
class array {
|
||||||
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
/* An array is really a node in a graph. It contains a shared ArrayDesc
|
||||||
|
@ -200,7 +200,7 @@ void shared_buffer_reshape(
|
|||||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline std::vector<T> remove_index(std::vector<T> vec, size_t index) {
|
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||||
vec.erase(std::next(vec.begin(), index));
|
vec.erase(std::next(vec.begin(), index));
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
@ -288,6 +288,14 @@ void Compiled::eval_cpu(
|
|||||||
auto [contiguous, shape, strides] =
|
auto [contiguous, shape, strides] =
|
||||||
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
|
||||||
|
|
||||||
|
// Force allocating shape/strides on heap so we can take their data() first
|
||||||
|
// and then std::move them.
|
||||||
|
// TODO: Refactor code to avoid heap allocation.
|
||||||
|
shape.grow();
|
||||||
|
for (auto& s : strides) {
|
||||||
|
s.grow();
|
||||||
|
}
|
||||||
|
|
||||||
// Collect function input arguments.
|
// Collect function input arguments.
|
||||||
std::vector<void*> args;
|
std::vector<void*> args;
|
||||||
int strides_index = 1;
|
int strides_index = 1;
|
||||||
|
@ -8,7 +8,7 @@
|
|||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/cpu/copy.h"
|
#include "mlx/backend/cpu/copy.h"
|
||||||
#include "mlx/backend/cpu/encoder.h"
|
#include "mlx/backend/cpu/encoder.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@ -333,47 +333,24 @@ void Sort::eval_cpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
int axis = axis_;
|
||||||
|
if (axis < 0) {
|
||||||
|
axis += in.ndim();
|
||||||
|
}
|
||||||
|
|
||||||
// Copy input to output
|
// Copy input to output
|
||||||
CopyType ctype = (in.flags().contiguous && in.strides()[axis_] != 0)
|
CopyType ctype = (in.flags().contiguous && in.strides()[axis] != 0)
|
||||||
? CopyType::Vector
|
? CopyType::Vector
|
||||||
: CopyType::General;
|
: CopyType::General;
|
||||||
copy_cpu(in, out, ctype, stream());
|
copy_cpu(in, out, ctype, stream());
|
||||||
|
|
||||||
auto& encoder = cpu::get_command_encoder(stream());
|
auto& encoder = cpu::get_command_encoder(stream());
|
||||||
encoder.set_output_array(out);
|
encoder.set_output_array(out);
|
||||||
encoder.dispatch(
|
encoder.dispatch([out = array::unsafe_weak_copy(out), axis]() mutable {
|
||||||
[out = array::unsafe_weak_copy(out), axis_ = axis_]() mutable {
|
dispatch_all_types(out.dtype(), [&](auto type_tag) {
|
||||||
switch (out.dtype()) {
|
sort<MLX_GET_TYPE(type_tag)>(out, axis);
|
||||||
case bool_:
|
});
|
||||||
return sort<bool>(out, axis_);
|
});
|
||||||
case uint8:
|
|
||||||
return sort<uint8_t>(out, axis_);
|
|
||||||
case uint16:
|
|
||||||
return sort<uint16_t>(out, axis_);
|
|
||||||
case uint32:
|
|
||||||
return sort<uint32_t>(out, axis_);
|
|
||||||
case uint64:
|
|
||||||
return sort<uint64_t>(out, axis_);
|
|
||||||
case int8:
|
|
||||||
return sort<int8_t>(out, axis_);
|
|
||||||
case int16:
|
|
||||||
return sort<int16_t>(out, axis_);
|
|
||||||
case int32:
|
|
||||||
return sort<int32_t>(out, axis_);
|
|
||||||
case int64:
|
|
||||||
return sort<int64_t>(out, axis_);
|
|
||||||
case float32:
|
|
||||||
return sort<float>(out, axis_);
|
|
||||||
case float64:
|
|
||||||
return sort<double>(out, axis_);
|
|
||||||
case float16:
|
|
||||||
return sort<float16_t>(out, axis_);
|
|
||||||
case bfloat16:
|
|
||||||
return sort<bfloat16_t>(out, axis_);
|
|
||||||
case complex64:
|
|
||||||
return sort<complex64_t>(out, axis_);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
@ -55,13 +55,13 @@ auto& conv_cache() {
|
|||||||
return cache;
|
return cache;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename U>
|
template <typename T, typename Vec>
|
||||||
inline std::vector<T> convert_vector(const std::vector<U>& vec) {
|
inline SmallVector<T> convert_vector(const Vec& vec) {
|
||||||
return std::vector<T>(vec.begin(), vec.end());
|
return SmallVector<T>(vec.begin(), vec.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, template <typename U> class Vec>
|
||||||
inline std::array<T, MAX_NDIM> fixed_vector(const std::vector<T>& vec) {
|
inline std::array<T, MAX_NDIM> fixed_vector(const Vec<T>& vec) {
|
||||||
if (vec.size() > MAX_NDIM) {
|
if (vec.size() > MAX_NDIM) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
fmt::format("ndim can not be larger than {}.", MAX_NDIM));
|
||||||
@ -78,7 +78,7 @@ auto nhwc_to_nchw(const array& x) {
|
|||||||
auto strides = convert_vector<int64_t>(x.strides());
|
auto strides = convert_vector<int64_t>(x.strides());
|
||||||
strides.insert(strides.begin() + 1, strides.back());
|
strides.insert(strides.begin() + 1, strides.back());
|
||||||
strides.erase(strides.end() - 1);
|
strides.erase(strides.end() - 1);
|
||||||
return std::make_tuple(shape, strides);
|
return std::make_tuple(std::move(shape), std::move(strides));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) {
|
||||||
@ -298,10 +298,10 @@ std::optional<cudnn_frontend::OperationGraph> build_op_graph(
|
|||||||
array& x,
|
array& x,
|
||||||
array& w,
|
array& w,
|
||||||
array& y,
|
array& y,
|
||||||
const std::vector<int64_t>& stride,
|
const SmallVector<int64_t>& stride,
|
||||||
const std::vector<int64_t>& padding_lo,
|
const SmallVector<int64_t>& padding_lo,
|
||||||
const std::vector<int64_t>& padding_hi,
|
const SmallVector<int64_t>& padding_hi,
|
||||||
const std::vector<int64_t>& dilation) {
|
const SmallVector<int64_t>& dilation) {
|
||||||
try {
|
try {
|
||||||
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
auto compute_dtype = (dtype == float16 || dtype == bfloat16)
|
||||||
? CUDNN_DATA_FLOAT
|
? CUDNN_DATA_FLOAT
|
||||||
@ -468,7 +468,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out_) {
|
|||||||
|
|
||||||
// There is no reliable way to deduce the proper cuDNN backend for the
|
// There is no reliable way to deduce the proper cuDNN backend for the
|
||||||
// convolution, so we make a best guess and then try.
|
// convolution, so we make a best guess and then try.
|
||||||
std::vector<cudnnBackendDescriptorType_t> try_backends;
|
SmallVector<cudnnBackendDescriptorType_t, 2> try_backends;
|
||||||
if (flip_) {
|
if (flip_) {
|
||||||
// When weight is flipped, we assume it is backward input convolution.
|
// When weight is flipped, we assume it is backward input convolution.
|
||||||
try_backends.push_back(CONV_BACKWARD_INPUT);
|
try_backends.push_back(CONV_BACKWARD_INPUT);
|
||||||
|
@ -29,12 +29,12 @@ void append_indices_arg(
|
|||||||
const std::vector<array>& inputs,
|
const std::vector<array>& inputs,
|
||||||
int nidx,
|
int nidx,
|
||||||
int idx_ndim) {
|
int idx_ndim) {
|
||||||
std::vector<const void*> indices(nidx);
|
SmallVector<const void*> indices(nidx);
|
||||||
for (int i = 0; i < nidx; ++i) {
|
for (int i = 0; i < nidx; ++i) {
|
||||||
indices[i] = inputs[i + 1].data<void>();
|
indices[i] = inputs[i + 1].data<void>();
|
||||||
}
|
}
|
||||||
args.append(std::move(indices));
|
args.append(std::move(indices));
|
||||||
std::vector<int32_t> indices_shape(nidx * idx_ndim);
|
SmallVector<int32_t> indices_shape(nidx * idx_ndim);
|
||||||
for (int i = 0; i < nidx; ++i) {
|
for (int i = 0; i < nidx; ++i) {
|
||||||
std::copy_n(
|
std::copy_n(
|
||||||
inputs[i + 1].shape().begin(),
|
inputs[i + 1].shape().begin(),
|
||||||
@ -42,7 +42,7 @@ void append_indices_arg(
|
|||||||
indices_shape.data() + i * idx_ndim);
|
indices_shape.data() + i * idx_ndim);
|
||||||
}
|
}
|
||||||
args.append(std::move(indices_shape));
|
args.append(std::move(indices_shape));
|
||||||
std::vector<int64_t> indices_strides(nidx * idx_ndim);
|
SmallVector<int64_t> indices_strides(nidx * idx_ndim);
|
||||||
for (int i = 0; i < nidx; ++i) {
|
for (int i = 0; i < nidx; ++i) {
|
||||||
std::copy_n(
|
std::copy_n(
|
||||||
inputs[i + 1].strides().begin(),
|
inputs[i + 1].strides().begin(),
|
||||||
@ -110,7 +110,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
args.append<int32_t>(src.ndim());
|
args.append<int32_t>(src.ndim());
|
||||||
args.append_ndim(slice_sizes_);
|
args.append_ndim(slice_sizes_);
|
||||||
args.append(slice_size);
|
args.append(slice_size);
|
||||||
args.append(axes_);
|
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||||
|
|
||||||
std::string kernel_name = fmt::format(
|
std::string kernel_name = fmt::format(
|
||||||
@ -211,7 +211,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
args.append_ndim(out.shape());
|
args.append_ndim(out.shape());
|
||||||
args.append_ndim(out.strides());
|
args.append_ndim(out.strides());
|
||||||
args.append<int32_t>(out.ndim());
|
args.append<int32_t>(out.ndim());
|
||||||
args.append(axes_);
|
args.append(SmallVector<int32_t>(axes_.begin(), axes_.end()));
|
||||||
append_indices_arg(args, inputs, nidx, idx_ndim);
|
append_indices_arg(args, inputs, nidx, idx_ndim);
|
||||||
|
|
||||||
std::string kernel_name = fmt::format(
|
std::string kernel_name = fmt::format(
|
||||||
|
@ -40,19 +40,14 @@ struct KernelArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void append(std::vector<T> vec) {
|
void append(SmallVector<T> vec) {
|
||||||
if (vec.empty()) {
|
storage_.emplace_back(std::move(vec));
|
||||||
// The nullptr can not be used as arg, pass something not null.
|
append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
|
||||||
append(std::monostate{});
|
|
||||||
} else {
|
|
||||||
append_ptr(vec.data());
|
|
||||||
storage_.emplace_back(std::move(vec));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the arg is copied to an array with size of NDIM.
|
// Make sure the arg is copied to an array with size of NDIM.
|
||||||
template <size_t NDIM = MAX_NDIM, typename T>
|
template <size_t NDIM = MAX_NDIM, typename T>
|
||||||
void append_ndim(std::vector<T> vec) {
|
void append_ndim(SmallVector<T> vec) {
|
||||||
if (vec.size() > NDIM) {
|
if (vec.size() > NDIM) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||||
@ -76,9 +71,9 @@ struct KernelArgs {
|
|||||||
int32_t,
|
int32_t,
|
||||||
uint32_t,
|
uint32_t,
|
||||||
int64_t,
|
int64_t,
|
||||||
std::vector<const void*>,
|
SmallVector<const void*>,
|
||||||
std::vector<int32_t>,
|
SmallVector<int32_t>,
|
||||||
std::vector<int64_t>>;
|
SmallVector<int64_t>>;
|
||||||
std::deque<Arg> storage_;
|
std::deque<Arg> storage_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ inline constexpr bool is_inexact_v = is_floating_v<T> || is_complex_v<T>;
|
|||||||
|
|
||||||
// Utility to copy data from vector to array in host.
|
// Utility to copy data from vector to array in host.
|
||||||
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
template <int NDIM = MAX_NDIM, typename T = int32_t>
|
||||||
inline cuda::std::array<T, NDIM> const_param(const std::vector<T>& vec) {
|
inline cuda::std::array<T, NDIM> const_param(const SmallVector<T>& vec) {
|
||||||
if (vec.size() > NDIM) {
|
if (vec.size() > NDIM) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
fmt::format("ndim can not be larger than {}.", NDIM));
|
fmt::format("ndim can not be larger than {}.", NDIM));
|
||||||
|
@ -28,15 +28,20 @@ inline array ensure_row_contiguous_matrix(
|
|||||||
const array& x,
|
const array& x,
|
||||||
cu::CommandEncoder& enc,
|
cu::CommandEncoder& enc,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
if (x.ndim() < 2) {
|
||||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
if (x.strides()[0] == 1) {
|
||||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
return x;
|
||||||
return x;
|
}
|
||||||
} else {
|
} else {
|
||||||
array x_copy = contiguous_copy_gpu(x, s);
|
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||||
enc.add_temporary(x_copy);
|
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||||
return x_copy;
|
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
|
enc.add_temporary(x_copy);
|
||||||
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -60,6 +60,16 @@ struct CommandEncoder {
|
|||||||
enc_->updateFence(fence);
|
enc_->updateFence(fence);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
|
||||||
|
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
|
||||||
|
return set_vector_bytes(vec, vec.size(), idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Code is duplicated but they should be deleted soon.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
||||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||||
|
@ -32,15 +32,20 @@ inline array ensure_row_contiguous_matrix(
|
|||||||
const array& x,
|
const array& x,
|
||||||
metal::Device& d,
|
metal::Device& d,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
auto stride_0 = x.strides()[x.ndim() - 2];
|
if (x.ndim() < 2) {
|
||||||
auto stride_1 = x.strides()[x.ndim() - 1];
|
if (x.strides()[0] == 1) {
|
||||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
return x;
|
||||||
return x;
|
}
|
||||||
} else {
|
} else {
|
||||||
array x_copy = contiguous_copy_gpu(x, s);
|
auto stride_0 = x.strides()[x.ndim() - 2];
|
||||||
d.add_temporary(x_copy, s.index);
|
auto stride_1 = x.strides()[x.ndim() - 1];
|
||||||
return x_copy;
|
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
|
d.add_temporary(x_copy, s.index);
|
||||||
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
|
inline int get_qmv_batch_limit(int D, int O, metal::Device& d) {
|
||||||
|
@ -179,7 +179,7 @@ void serialize(Writer& os, const array& arr) {
|
|||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
array deserialize(Reader& is) {
|
array deserialize(Reader& is) {
|
||||||
auto shape = deserialize<std::vector<int>>(is);
|
auto shape = deserialize<Shape>(is);
|
||||||
auto type = deserialize<Dtype>(is);
|
auto type = deserialize<Dtype>(is);
|
||||||
return array(std::move(shape), type, nullptr, std::vector<array>{});
|
return array(std::move(shape), type, nullptr, std::vector<array>{});
|
||||||
}
|
}
|
||||||
@ -640,7 +640,7 @@ void FunctionExporter::export_function(const Args& args, const Kwargs& kwargs) {
|
|||||||
auto outputs = arr.outputs();
|
auto outputs = arr.outputs();
|
||||||
serialize(os, arrays_to_ids(outputs));
|
serialize(os, arrays_to_ids(outputs));
|
||||||
|
|
||||||
std::vector<std::vector<int>> shapes;
|
std::vector<Shape> shapes;
|
||||||
std::vector<Dtype> dtypes;
|
std::vector<Dtype> dtypes;
|
||||||
for (auto& o : outputs) {
|
for (auto& o : outputs) {
|
||||||
shapes.push_back(o.shape());
|
shapes.push_back(o.shape());
|
||||||
@ -813,14 +813,14 @@ ImportedFunction::ImportedFunction(const std::string& file)
|
|||||||
std::shared_ptr<Primitive> prim = factory.load(is);
|
std::shared_ptr<Primitive> prim = factory.load(is);
|
||||||
auto num_siblings = deserialize<uint64_t>(is);
|
auto num_siblings = deserialize<uint64_t>(is);
|
||||||
if (num_siblings == 0) {
|
if (num_siblings == 0) {
|
||||||
auto shape = deserialize<std::vector<int>>(is);
|
auto shape = deserialize<Shape>(is);
|
||||||
auto type = deserialize<Dtype>(is);
|
auto type = deserialize<Dtype>(is);
|
||||||
tape.emplace_back(
|
tape.emplace_back(
|
||||||
std::move(shape), type, std::move(prim), std::move(inputs));
|
std::move(shape), type, std::move(prim), std::move(inputs));
|
||||||
array_map.emplace(id, tape.back());
|
array_map.emplace(id, tape.back());
|
||||||
} else {
|
} else {
|
||||||
auto ids = deserialize<std::vector<uint64_t>>(is);
|
auto ids = deserialize<std::vector<uint64_t>>(is);
|
||||||
auto shapes = deserialize<std::vector<std::vector<int>>>(is);
|
auto shapes = deserialize<std::vector<Shape>>(is);
|
||||||
auto types = deserialize<std::vector<Dtype>>(is);
|
auto types = deserialize<std::vector<Dtype>>(is);
|
||||||
auto arrays = array::make_arrays(
|
auto arrays = array::make_arrays(
|
||||||
std::move(shapes),
|
std::move(shapes),
|
||||||
@ -841,7 +841,7 @@ ImportedFunction::ImportedFunction(const std::string& file)
|
|||||||
if (auto it = constants.find(id); it != constants.end()) {
|
if (auto it = constants.find(id); it != constants.end()) {
|
||||||
tape.push_back(it->second);
|
tape.push_back(it->second);
|
||||||
} else {
|
} else {
|
||||||
auto shape = deserialize<std::vector<int>>(is);
|
auto shape = deserialize<Shape>(is);
|
||||||
auto type = deserialize<Dtype>(is);
|
auto type = deserialize<Dtype>(is);
|
||||||
size_t offset = is.tell();
|
size_t offset = is.tell();
|
||||||
tape.push_back(array(
|
tape.push_back(array(
|
||||||
|
@ -3027,9 +3027,9 @@ array kron(const array& a, const array& b, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int ndim = std::max(a.ndim(), b.ndim());
|
int ndim = std::max(a.ndim(), b.ndim());
|
||||||
std::vector<int> a_shape(2 * ndim, 1);
|
Shape a_shape(2 * ndim, 1);
|
||||||
std::vector<int> b_shape(2 * ndim, 1);
|
Shape b_shape(2 * ndim, 1);
|
||||||
std::vector<int> out_shape(ndim, 1);
|
Shape out_shape(ndim, 1);
|
||||||
|
|
||||||
for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {
|
for (int i = ndim - 1, j = a.ndim() - 1; j >= 0; j--, i--) {
|
||||||
a_shape[2 * i] = a.shape(j);
|
a_shape[2 * i] = a.shape(j);
|
||||||
|
@ -1205,8 +1205,8 @@ array conv_weight_backward_patches(
|
|||||||
auto in_padded =
|
auto in_padded =
|
||||||
pad(in,
|
pad(in,
|
||||||
padded_axes,
|
padded_axes,
|
||||||
Shape(padding_lo),
|
Shape(padding_lo.begin(), padding_lo.end()),
|
||||||
Shape(padding_hi),
|
Shape(padding_hi.begin(), padding_hi.end()),
|
||||||
array(0, in.dtype()),
|
array(0, in.dtype()),
|
||||||
"constant",
|
"constant",
|
||||||
s);
|
s);
|
||||||
|
@ -592,7 +592,7 @@ class Broadcast : public UnaryPrimitive {
|
|||||||
static Shape output_shape(const std::vector<array>& inputs);
|
static Shape output_shape(const std::vector<array>& inputs);
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<int> state() const {
|
Shape state() const {
|
||||||
return shape_;
|
return shape_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1146,7 +1146,7 @@ class Gather : public UnaryPrimitive {
|
|||||||
DEFINE_NAME(Gather)
|
DEFINE_NAME(Gather)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
std::pair<std::vector<int>, std::vector<int>> state() const {
|
std::pair<std::vector<int>, Shape> state() const {
|
||||||
return {axes_, slice_sizes_};
|
return {axes_, slice_sizes_};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1668,7 +1668,7 @@ class RandomBits : public UnaryPrimitive {
|
|||||||
DEFINE_VMAP()
|
DEFINE_VMAP()
|
||||||
DEFINE_NAME(RandomBits)
|
DEFINE_NAME(RandomBits)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::pair<std::vector<int>, int> state() const {
|
std::pair<Shape, int> state() const {
|
||||||
return {shape_, width_};
|
return {shape_, width_};
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1703,7 +1703,7 @@ class Reshape : public UnaryPrimitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_NAME(Reshape)
|
DEFINE_NAME(Reshape)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<int> state() const {
|
Shape state() const {
|
||||||
return shape_;
|
return shape_;
|
||||||
};
|
};
|
||||||
static Shape output_shape(const array& input, Shape shape);
|
static Shape output_shape(const array& input, Shape shape);
|
||||||
@ -2121,7 +2121,7 @@ class Split : public Primitive {
|
|||||||
DEFINE_GRADS()
|
DEFINE_GRADS()
|
||||||
DEFINE_NAME(Split)
|
DEFINE_NAME(Split)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::pair<std::vector<int>, int> state() const {
|
std::pair<Shape, int> state() const {
|
||||||
return {indices_, axis_};
|
return {indices_, axis_};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
528
mlx/small_vector.h
Normal file
528
mlx/small_vector.h
Normal file
@ -0,0 +1,528 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
// Copyright © 2018 the V8 project authors.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without
|
||||||
|
// modification, are permitted provided that the following conditions are
|
||||||
|
// met:
|
||||||
|
//
|
||||||
|
// * Redistributions of source code must retain the above copyright
|
||||||
|
// notice, this list of conditions and the following disclaimer.
|
||||||
|
// * Redistributions in binary form must reproduce the above
|
||||||
|
// copyright notice, this list of conditions and the following
|
||||||
|
// disclaimer in the documentation and/or other materials provided
|
||||||
|
// with the distribution.
|
||||||
|
// * Neither the name of Google Inc. nor the names of its
|
||||||
|
// contributors may be used to endorse or promote products derived
|
||||||
|
// from this software without specific prior written permission.
|
||||||
|
//
|
||||||
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Feature-detection shims. Both evaluate to 0 on compilers that do not
// provide the corresponding __has_* preprocessor operator.
#ifdef __has_builtin
#define MLX_HAS_BUILTIN(x) __has_builtin(x)
#else
#define MLX_HAS_BUILTIN(x) 0
#endif

#ifdef __has_attribute
#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x)
#else
#define MLX_HAS_ATTRIBUTE(x) 0
#endif

// Branch-prediction hints. Without __builtin_expect they reduce to the plain
// condition, so they are always safe to use inside an if().
#if !MLX_HAS_BUILTIN(__builtin_expect)
#define MLX_LIKELY(condition) (condition)
#define MLX_UNLIKELY(condition) (condition)
#else
#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1))
#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
#endif

// Marks a function as never-inlined where the attribute is available, and
// expands to nothing elsewhere.
#if !MLX_HAS_ATTRIBUTE(noinline)
#define MLX_NOINLINE
#else
#define MLX_NOINLINE __attribute__((noinline))
#endif
|
||||||
|
|
||||||
|
// Type trait that detects iterator types.
//
// Primary template: anything that does not match the specialization below is
// not considered an iterator.
template <typename T, typename = void>
struct is_iterator : std::false_type {};

// Specialization selected when std::iterator_traits<T> exposes all five
// member types.
template <typename T>
struct is_iterator<
    T,
    std::void_t<
        typename std::iterator_traits<T>::difference_type,
        typename std::iterator_traits<T>::iterator_category,
        typename std::iterator_traits<T>::pointer,
        typename std::iterator_traits<T>::reference,
        typename std::iterator_traits<T>::value_type>> : std::true_type {};

// Convenience variable template. Declared inline, like the standard *_v
// traits, so the header-defined variable has a single definition across
// translation units.
template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value;
|
||||||
|
|
||||||
|
// Minimal SmallVector implementation. Uses inline storage first, switches to
|
||||||
|
// dynamic storage when it overflows.
|
||||||
|
//
|
||||||
|
// Notes:
|
||||||
|
// * The default inline storage size is MAX_NDIM, as it is mainly used for
|
||||||
|
// shapes and strides, users should choose a better size for other cases.
|
||||||
|
// * data() returns a real (non-null) address even for an empty vector.
|
||||||
|
// * The pointer returned by data() will change after moving the vector as it
|
||||||
|
// points to the inline storage.
|
||||||
|
// * For trivial elements the storage will not be default constructed,
|
||||||
|
// i.e. SmallVector<int>(10) will not be filled with 0 by default.
|
||||||
|
template <typename T, size_t kSize = 10, typename Allocator = std::allocator<T>>
|
||||||
|
class SmallVector {
|
||||||
|
public:
|
||||||
|
using value_type = T;
|
||||||
|
using reference = T&;
|
||||||
|
using const_reference = const T&;
|
||||||
|
using iterator = T*;
|
||||||
|
using const_iterator = const T*;
|
||||||
|
using difference_type = std::ptrdiff_t;
|
||||||
|
using size_type = std::size_t;
|
||||||
|
|
||||||
|
// Default constructor. Members take their in-class default initializers,
// which are declared further down in the class (outside this view).
SmallVector() = default;
|
||||||
|
|
||||||
|
// Constructs a vector that uses a copy of `allocator`; no elements are
// added by this constructor.
explicit SmallVector(const Allocator& allocator)
    : allocator_(allocator) {}
|
||||||
|
|
||||||
|
explicit SmallVector(size_t size, const Allocator& allocator = Allocator())
|
||||||
|
: allocator_(allocator) {
|
||||||
|
resize(size);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector(
|
||||||
|
size_t size,
|
||||||
|
const T& initial_value,
|
||||||
|
const Allocator& allocator = Allocator())
|
||||||
|
: allocator_(allocator) {
|
||||||
|
resize(size, initial_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector(
|
||||||
|
std::initializer_list<T> init,
|
||||||
|
const Allocator& allocator = Allocator())
|
||||||
|
: allocator_(allocator) {
|
||||||
|
if (init.size() > capacity()) {
|
||||||
|
grow(init.size());
|
||||||
|
}
|
||||||
|
assert(capacity() >= init.size()); // sanity check
|
||||||
|
std::uninitialized_move(init.begin(), init.end(), begin_);
|
||||||
|
end_ = begin_ + init.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constructs the vector from the iterator range [begin, end).
//
// Only participates in overload resolution when Iter satisfies
// is_iterator_v, which keeps SmallVector(n, value) with two integers from
// selecting this overload. The range is traversed twice (std::distance, then
// the copy).
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
    : allocator_(allocator) {
  const size_t size = std::distance(begin, end);
  if (capacity() < size) {
    grow(size);
  }
  assert(capacity() >= size); // sanity check
  std::uninitialized_copy(begin, end, begin_);
  end_ = begin_ + size;
}
|
||||||
|
|
||||||
|
SmallVector(const SmallVector& other) : allocator_(other.allocator_) {
|
||||||
|
*this = other;
|
||||||
|
}
|
||||||
|
SmallVector(const SmallVector& other, const Allocator& allocator)
|
||||||
|
: allocator_(allocator) {
|
||||||
|
*this = other;
|
||||||
|
}
|
||||||
|
// Move constructor: takes over other's allocator, then reuses move
// assignment to transfer the elements.
//
// Declared conditionally noexcept so that containers such as
// std::vector<SmallVector<T>> move elements on reallocation instead of
// falling back to copies (they use std::move_if_noexcept). On a freshly
// constructed, empty target the visible move assignment only
// move-constructs elements or takes over other's heap pointers.
// NOTE(review): this assumes free_storage() and reset_to_inline_storage(),
// defined outside this view, do not throw - confirm.
SmallVector(SmallVector&& other) noexcept(
    std::is_nothrow_move_constructible_v<T>)
    : allocator_(std::move(other.allocator_)) {
  *this = std::move(other);
}
|
||||||
|
// Allocator-extended move constructor: transfers other's contents through
// move assignment while using the supplied allocator.
// NOTE(review): move assignment adopts other's heap buffer as-is; with a
// stateful allocator that compares unequal to other's, the buffer would
// later be released by a different allocator - confirm this is acceptable.
SmallVector(SmallVector&& other, const Allocator& allocator)
    : allocator_(allocator) {
  operator=(std::move(other));
}
|
||||||
|
|
||||||
|
// Destructor: all cleanup is delegated to free_storage().
~SmallVector() {
  this->free_storage();
}
|
||||||
|
|
||||||
|
// Copy assignment. Makes *this hold a copy of other's elements.
//
// Three cases:
//  * other does not fit in the current capacity: the current storage is
//    released and a heap buffer of exactly other.size() elements is
//    allocated and copy-constructed into;
//  * trivial elements that fit: a plain element-wise copy;
//  * non-trivial elements that fit: assign over the already-constructed
//    prefix, then either copy-construct the remaining tail or destroy the
//    surplus elements of *this.
SmallVector& operator=(const SmallVector& other) {
  if (this != &other) {
    const size_t other_size = other.size();
    if (other_size > capacity()) {
      // Create large-enough heap-allocated storage.
      free_storage();
      begin_ = allocator_.allocate(other_size);
      end_of_storage_ = begin_ + other_size;
      std::uninitialized_copy(other.begin_, other.end_, begin_);
    } else if constexpr (kHasTrivialElement) {
      std::copy(other.begin_, other.end_, begin_);
    } else {
      // Number of currently constructed elements in *this.
      const ptrdiff_t live = end_ - begin_;
      // Elements that can be assigned onto existing objects.
      const ptrdiff_t to_copy =
          std::min(static_cast<ptrdiff_t>(other_size), live);
      const T* const split = other.begin_ + to_copy;
      std::copy(other.begin_, split, begin_);
      if (split < other.end_) {
        // other is longer: construct the rest in raw storage.
        std::uninitialized_copy(split, other.end_, begin_ + to_copy);
      } else {
        // other is not longer: drop whatever *this has beyond it.
        std::destroy_n(begin_ + to_copy, live - to_copy);
      }
    }
    end_ = begin_ + other_size;
  }
  return *this;
}
|
||||||
|
|
||||||
|
// Move assignment. Transfers other's contents into *this and then resets
// other via reset_to_inline_storage().
//
//  * If other is_big(), its begin/end/end-of-storage pointers are taken
//    over directly after releasing the current storage.
//  * Otherwise the elements are moved one by one into the existing storage
//    of *this, which must already be large enough (asserted).
SmallVector& operator=(SmallVector&& other) {
  if (this == &other) {
    return *this;
  }
  if (!other.is_big()) {
    assert(capacity() >= other.size()); // sanity check
    const size_t other_size = other.size();
    if constexpr (kHasTrivialElement) {
      std::move(other.begin_, other.end_, begin_);
    } else {
      // Number of currently constructed elements in *this.
      const ptrdiff_t live = end_ - begin_;
      // Elements that can be move-assigned onto existing objects.
      const ptrdiff_t to_move =
          std::min(static_cast<ptrdiff_t>(other_size), live);
      T* const split = other.begin_ + to_move;
      std::move(other.begin_, split, begin_);
      if (split < other.end_) {
        // other is longer: move-construct the rest in raw storage.
        std::uninitialized_move(split, other.end_, begin_ + to_move);
      } else {
        // other is not longer: drop whatever *this has beyond it.
        std::destroy_n(begin_ + to_move, live - to_move);
      }
    }
    end_ = begin_ + other_size;
  } else {
    free_storage();
    begin_ = other.begin_;
    end_ = other.end_;
    end_of_storage_ = other.end_of_storage_;
  }
  other.reset_to_inline_storage();
  return *this;
}
|
||||||
|
|
||||||
|
// Two vectors are equal when they have the same number of elements and the
// elements compare equal position by position.
bool operator==(const SmallVector& other) const {
  return size() == other.size() && std::equal(begin_, end_, other.begin_);
}
|
||||||
|
|
||||||
|
// Negation of operator==.
bool operator!=(const SmallVector& other) const {
  const bool same = (*this == other);
  return !same;
}
|
||||||
|
|
||||||
|
// Pointer to the first element of the current storage (inline or heap).
T* data() {
  return begin_;
}
// Read-only variant of data().
const T* data() const {
  return begin_;
}
|
||||||
|
|
||||||
|
iterator begin() {
|
||||||
|
return begin_;
|
||||||
|
}
|
||||||
|
const_iterator begin() const {
|
||||||
|
return begin_;
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator end() {
|
||||||
|
return end_;
|
||||||
|
}
|
||||||
|
const_iterator end() const {
|
||||||
|
return end_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read-only iterator to the first element, regardless of constness of *this.
const_iterator cbegin() const {
  return begin_;
}
|
||||||
|
|
||||||
|
// Read-only iterator one past the last element.
const_iterator cend() const {
  return end_;
}
|
||||||
|
|
||||||
|
auto rbegin() {
|
||||||
|
return std::make_reverse_iterator(end_);
|
||||||
|
}
|
||||||
|
auto rbegin() const {
|
||||||
|
return std::make_reverse_iterator(end_);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rend() {
|
||||||
|
return std::make_reverse_iterator(begin_);
|
||||||
|
}
|
||||||
|
auto rend() const {
|
||||||
|
return std::make_reverse_iterator(begin_);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const {
|
||||||
|
return end_ - begin_;
|
||||||
|
}
|
||||||
|
bool empty() const {
|
||||||
|
return end_ == begin_;
|
||||||
|
}
|
||||||
|
size_t capacity() const {
|
||||||
|
return end_of_storage_ - begin_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// First element. The vector must not be empty (checked with assert only).
T& front() {
  assert(!empty());
  return *begin_;
}
// Read-only first element. The vector must not be empty.
const T& front() const {
  assert(!empty());
  return *begin_;
}
|
||||||
|
|
||||||
|
// Last element. The vector must not be empty (checked with assert only).
T& back() {
  assert(!empty());
  return *(end_ - 1);
}
// Read-only last element. The vector must not be empty.
const T& back() const {
  assert(!empty());
  return *(end_ - 1);
}
|
||||||
|
|
||||||
|
// Bounds-checked element access.
// Throws std::out_of_range when `index` is not smaller than size().
T& at(size_t index) {
  if (index < size()) {
    return begin_[index];
  }
  throw std::out_of_range("SmallVector out of range.");
}
// Const overload; shares the bounds check with the non-const version.
const T& at(size_t index) const {
  auto* self = const_cast<SmallVector*>(this);
  return self->at(index);
}
|
||||||
|
|
||||||
|
// Unchecked element access; the index is validated with assert only.
T& operator[](size_t index) {
  assert(index < size());
  return *(begin_ + index);
}
// Const overload; forwards to the non-const version.
const T& operator[](size_t index) const {
  auto* self = const_cast<SmallVector*>(this);
  return (*self)[index];
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void emplace_back(Args&&... args) {
|
||||||
|
if (MLX_UNLIKELY(end_ == end_of_storage_)) {
|
||||||
|
grow();
|
||||||
|
}
|
||||||
|
void* storage = end_;
|
||||||
|
end_ += 1;
|
||||||
|
new (storage) T(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Appends `x` to the end. The argument is taken by value and then moved into
// the new slot.
void push_back(T x) {
  this->emplace_back(std::move(x));
}
|
||||||
|
|
||||||
|
// Removes the last `count` elements (one by default) and destroys them.
// `count` must not exceed size() (checked with assert only).
void pop_back(size_t count = 1) {
  assert(count <= size());
  end_ = end_ - count;
  std::destroy(end_, end_ + count);
}
|
||||||
|
|
||||||
|
// Inserts a single `value` before `pos` by delegating to the counted overload.
// Returns an iterator to the inserted element.
iterator insert(iterator pos, T value) {
  return insert(pos, size_t{1}, std::move(value));
}
|
||||||
|
|
||||||
|
// Inserts `count` copies of `value` before `pos`.
//
// Returns an iterator to the first inserted element, or `pos` unchanged when
// `count` is zero. Iterators taken before the call may be invalidated because
// the storage can grow.
iterator insert(iterator pos, size_t count, T value) {
  assert(pos <= end_);
  // Nothing to insert. Bail out early: the non-trivial path below computes
  // `count - 1`, which would wrap around for zero, and would overwrite `*pos`.
  if (count == 0) {
    return pos;
  }
  // Remember positions as offsets, since resize() may reallocate.
  size_t offset = pos - begin_;
  size_t old_size = size();
  resize(old_size + count);
  pos = begin_ + offset;
  iterator old_end = begin_ + old_size;
  assert(old_end <= end_);
  // Shift the tail to the right to open a gap of `count` slots at `pos`.
  std::move_backward(pos, old_end, end_);
  if constexpr (kHasTrivialElement) {
    std::fill_n(pos, count, value);
  } else {
    // Copy into all but one slot, then move `value` itself into the last
    // remaining slot to save one copy.
    std::fill_n(pos + 1, count - 1, value);
    *pos = std::move(value);
  }
  return pos;
}
|
||||||
|
|
||||||
|
// Inserts the elements of [begin, end) before `pos` and returns an iterator to
// the first inserted element. When the source iterators are this vector's own
// iterator type, the source range must not overlap the insertion area
// (checked with assert only).
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
iterator insert(iterator pos, Iter begin, Iter end) {
  if constexpr (std::is_same_v<std::decay_t<Iter>, iterator>) {
    // The implementation can not take overlapping range.
    assert(begin < pos || begin >= pos + std::distance(begin, end));
    assert(end <= pos || end > pos + std::distance(begin, end));
  }

  assert(pos <= end_);
  // Positions are kept as offsets because resize() may move the storage.
  const size_t index = pos - begin_;
  const size_t added = std::distance(begin, end);
  const size_t previous = size();
  resize(previous + added);

  iterator gap = begin_ + index;
  iterator previous_end = begin_ + previous;
  assert(previous_end <= end_);
  // Slide the tail to the right, then fill the gap from the source range.
  std::move_backward(gap, previous_end, end_);
  std::copy(begin, end, gap);
  return gap;
}
|
||||||
|
|
||||||
|
// Inserts the contents of an initializer list before `pos` by delegating to
// the iterator-range overload.
iterator insert(iterator pos, std::initializer_list<const T> values) {
  auto first = values.begin();
  auto last = values.end();
  return insert(pos, first, last);
}
|
||||||
|
|
||||||
|
// Removes the elements in [erase_start, erase_end). Elements after the range
// are moved down to close the gap and the vacated tail is destroyed.
// Returns `erase_start`, which now refers to the element that followed the
// erased range (or end() if there was none).
iterator erase(iterator erase_start, iterator erase_end) {
  assert(begin_ <= erase_start);
  assert(erase_start <= erase_end);
  assert(erase_end <= end_);
  iterator tail = std::move(erase_end, end_, erase_start);
  std::destroy(tail, end_);
  end_ = tail;
  return erase_start;
}
|
||||||
|
|
||||||
|
// Removes the single element at `pos` via the range overload.
iterator erase(iterator pos) {
  iterator next = pos + 1;
  return erase(pos, next);
}
|
||||||
|
|
||||||
|
void resize(size_t new_size) {
|
||||||
|
if (new_size > capacity()) {
|
||||||
|
grow(new_size);
|
||||||
|
}
|
||||||
|
T* new_end = begin_ + new_size;
|
||||||
|
if constexpr (!kHasTrivialElement) {
|
||||||
|
if (new_end > end_) {
|
||||||
|
std::uninitialized_default_construct(end_, new_end);
|
||||||
|
} else {
|
||||||
|
std::destroy_n(new_end, end_ - new_end);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
end_ = new_end;
|
||||||
|
}
|
||||||
|
|
||||||
|
void resize(size_t new_size, const T& initial_value) {
|
||||||
|
if (new_size > capacity()) {
|
||||||
|
grow(new_size);
|
||||||
|
}
|
||||||
|
T* new_end = begin_ + new_size;
|
||||||
|
if (new_end > end_) {
|
||||||
|
std::uninitialized_fill(end_, new_end, initial_value);
|
||||||
|
} else {
|
||||||
|
std::destroy_n(new_end, end_ - new_end);
|
||||||
|
}
|
||||||
|
end_ = new_end;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reserve(size_t new_capacity) {
|
||||||
|
if (new_capacity > capacity()) {
|
||||||
|
grow(new_capacity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear without reverting back to inline storage.
|
||||||
|
void clear() {
|
||||||
|
std::destroy_n(begin_, end_ - begin_);
|
||||||
|
end_ = begin_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
||||||
|
// TODO: Move to private after removing external code using this method.
|
||||||
|
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
||||||
|
size_t new_capacity = std::max(min_capacity, 2 * capacity());
|
||||||
|
// Round up to power of 2.
|
||||||
|
new_capacity--;
|
||||||
|
new_capacity |= new_capacity >> 1;
|
||||||
|
new_capacity |= new_capacity >> 2;
|
||||||
|
new_capacity |= new_capacity >> 4;
|
||||||
|
new_capacity |= new_capacity >> 8;
|
||||||
|
new_capacity |= new_capacity >> 16;
|
||||||
|
if constexpr (sizeof(size_t) == sizeof(uint64_t)) {
|
||||||
|
new_capacity |= new_capacity >> 32;
|
||||||
|
}
|
||||||
|
new_capacity++;
|
||||||
|
|
||||||
|
T* new_storage = allocator_.allocate(new_capacity);
|
||||||
|
if (new_storage == nullptr) {
|
||||||
|
throw std::bad_alloc();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t in_use = end_ - begin_;
|
||||||
|
std::uninitialized_move(begin_, end_, new_storage);
|
||||||
|
free_storage();
|
||||||
|
begin_ = new_storage;
|
||||||
|
end_ = new_storage + in_use;
|
||||||
|
end_of_storage_ = new_storage + new_capacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Destroys every live element and, when the vector is on heap storage, hands
// the buffer back to the allocator. The begin/end pointers are left untouched;
// callers are expected to reassign them.
MLX_NOINLINE void free_storage() {
  std::destroy(begin_, end_);
  if (!is_big()) {
    return;
  }
  allocator_.deallocate(begin_, end_of_storage_ - begin_);
}
|
||||||
|
|
||||||
|
// Clear and go back to inline storage. Dynamic storage is *not* freed. For
|
||||||
|
// internal use only.
|
||||||
|
void reset_to_inline_storage() {
|
||||||
|
if constexpr (!kHasTrivialElement) {
|
||||||
|
if (!is_big())
|
||||||
|
std::destroy_n(begin_, end_ - begin_);
|
||||||
|
}
|
||||||
|
begin_ = inline_storage_begin();
|
||||||
|
end_ = begin_;
|
||||||
|
end_of_storage_ = begin_ + kSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_big() const {
|
||||||
|
return begin_ != inline_storage_begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start of the inline buffer, viewed as an array of T.
T* inline_storage_begin() {
  return reinterpret_cast<T*>(&inline_storage_[0]);
}
// Read-only view of the start of the inline buffer.
const T* inline_storage_begin() const {
  return reinterpret_cast<const T*>(&inline_storage_[0]);
}
|
||||||
|
|
||||||
|
Allocator allocator_;
|
||||||
|
|
||||||
|
// Invariants:
|
||||||
|
// 1. The elements in the range between `begin_` (included) and `end_` (not
|
||||||
|
// included) will be initialized at all times.
|
||||||
|
// 2. All other elements outside the range, both in the inline storage and in
|
||||||
|
// the dynamic storage (if it exists), will be uninitialized at all times.
|
||||||
|
|
||||||
|
T* begin_ = inline_storage_begin();
|
||||||
|
T* end_ = begin_;
|
||||||
|
T* end_of_storage_ = begin_ + kSize;
|
||||||
|
|
||||||
|
alignas(T) char inline_storage_[sizeof(T) * kSize];
|
||||||
|
|
||||||
|
static constexpr bool kHasTrivialElement =
|
||||||
|
std::is_trivially_copyable<T>::value &&
|
||||||
|
std::is_trivially_destructible<T>::value;
|
||||||
|
};
|
||||||
|
|
||||||
|
#undef MLX_HAS_BUILTIN
|
||||||
|
#undef MLX_HAS_ATTRIBUTE
|
||||||
|
#undef MLX_LIKELY
|
||||||
|
#undef MLX_UNLIKELY
|
||||||
|
#undef MLX_NOINLINE
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -259,6 +259,25 @@ std::ostream& operator<<(std::ostream& os, array a) {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
|
||||||
|
os << "(";
|
||||||
|
for (int i = 0; i < v.size(); ++i) {
|
||||||
|
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||||
|
}
|
||||||
|
os << ")";
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prints the vector as a parenthesized, comma-separated list, e.g. "(1,2,3)".
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
  os << "(";
  const int64_t* const last = v.end();
  for (const int64_t* it = v.begin(); it != last; ++it) {
    // Every element except the final one is followed by a comma.
    os << *it << ((it + 1 == last) ? "" : ",");
  }
  os << ")";
  return os;
}
|
||||||
|
|
||||||
|
// TODO: Code is duplicated but they should be deleted soon.
|
||||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||||
os << "(";
|
os << "(";
|
||||||
for (int i = 0; i < v.size(); ++i) {
|
for (int i = 0; i < v.size(); ++i) {
|
||||||
|
@ -100,6 +100,8 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
|
|||||||
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
||||||
std::ostream& operator<<(std::ostream& os, array a);
|
std::ostream& operator<<(std::ostream& os, array a);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
|
||||||
|
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
|
||||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
||||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
|
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
|
||||||
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#include "python/src/buffer.h"
|
#include "python/src/buffer.h"
|
||||||
#include "python/src/convert.h"
|
#include "python/src/convert.h"
|
||||||
#include "python/src/indexing.h"
|
#include "python/src/indexing.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
@ -303,7 +304,7 @@ void init_array(nb::module_& m) {
|
|||||||
R"pbdoc(The number of bytes in the array.)pbdoc")
|
R"pbdoc(The number of bytes in the array.)pbdoc")
|
||||||
.def_prop_ro(
|
.def_prop_ro(
|
||||||
"shape",
|
"shape",
|
||||||
[](const mx::array& a) { return nb::tuple(nb::cast(a.shape())); },
|
[](const mx::array& a) { return nb::cast(a.shape()); },
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
The shape of the array as a Python tuple.
|
The shape of the array as a Python tuple.
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/export.h"
|
#include "mlx/export.h"
|
||||||
#include "mlx/graph_utils.h"
|
#include "mlx/graph_utils.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
#include "python/src/trees.h"
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
@ -8,10 +8,10 @@
|
|||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include "python/src/utils.h"
|
|
||||||
|
|
||||||
#include "mlx/fast.h"
|
#include "mlx/fast.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
#include "python/src/load.h"
|
#include "python/src/load.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
@ -7,8 +7,10 @@
|
|||||||
#include <nanobind/stl/unordered_map.h>
|
#include <nanobind/stl/unordered_map.h>
|
||||||
#include <nanobind/stl/variant.h>
|
#include <nanobind/stl/variant.h>
|
||||||
#include <nanobind/stl/vector.h>
|
#include <nanobind/stl/vector.h>
|
||||||
|
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/memory.h"
|
#include "mlx/memory.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
#include "python/src/load.h"
|
#include "python/src/load.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
#include "python/src/utils.h"
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
@ -7,10 +7,10 @@
|
|||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
#include "python/src/utils.h"
|
|
||||||
|
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
|
77
python/src/small_vector.h
Normal file
77
python/src/small_vector.h
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/small_vector.h"
|
||||||
|
|
||||||
|
#include <nanobind/stl/detail/nb_list.h>
|
||||||
|
|
||||||
|
NAMESPACE_BEGIN(NB_NAMESPACE)
|
||||||
|
NAMESPACE_BEGIN(detail)
|
||||||
|
|
||||||
|
template <typename Type, size_t Size, typename Alloc>
|
||||||
|
struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>> {
|
||||||
|
using List = mlx::core::SmallVector<Type, Size, Alloc>;
|
||||||
|
using Caster = make_caster<Type>;
|
||||||
|
|
||||||
|
NB_TYPE_CASTER(
|
||||||
|
List,
|
||||||
|
const_name(NB_TYPING_TUPLE "[") + make_caster<Type>::Name +
|
||||||
|
const_name(", ...]"))
|
||||||
|
|
||||||
|
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept {
|
||||||
|
size_t size;
|
||||||
|
PyObject* temp;
|
||||||
|
|
||||||
|
// Will initialize 'size' and 'temp'. All return values and
|
||||||
|
// return parameters are zero/NULL in the case of a failure.
|
||||||
|
PyObject** o = seq_get(src.ptr(), &size, &temp);
|
||||||
|
|
||||||
|
value.clear();
|
||||||
|
value.reserve(size);
|
||||||
|
|
||||||
|
Caster caster;
|
||||||
|
bool success = o != nullptr;
|
||||||
|
|
||||||
|
flags = flags_for_local_caster<Type>(flags);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
if (!caster.from_python(o[i], flags, cleanup) ||
|
||||||
|
!caster.template can_cast<Type>()) {
|
||||||
|
success = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
value.push_back(caster.operator cast_t<Type>());
|
||||||
|
}
|
||||||
|
|
||||||
|
Py_XDECREF(temp);
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static handle from_cpp(T&& src, rv_policy policy, cleanup_list* cleanup) {
|
||||||
|
object ret = steal(PyTuple_New(src.size()));
|
||||||
|
|
||||||
|
if (ret.is_valid()) {
|
||||||
|
Py_ssize_t index = 0;
|
||||||
|
|
||||||
|
for (auto&& value : src) {
|
||||||
|
handle h = Caster::from_cpp(forward_like_<T>(value), policy, cleanup);
|
||||||
|
|
||||||
|
if (!h.is_valid()) {
|
||||||
|
ret.reset();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
NB_TUPLE_SET_ITEM(ret.ptr(), index++, h.ptr());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret.release();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
NAMESPACE_END(detail)
|
||||||
|
NAMESPACE_END(NB_NAMESPACE)
|
@ -20,6 +20,7 @@
|
|||||||
#include "mlx/transforms_impl.h"
|
#include "mlx/transforms_impl.h"
|
||||||
#include "mlx/utils.h"
|
#include "mlx/utils.h"
|
||||||
#include "python/src/mlx_func.h"
|
#include "python/src/mlx_func.h"
|
||||||
|
#include "python/src/small_vector.h"
|
||||||
#include "python/src/trees.h"
|
#include "python/src/trees.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
|
Loading…
Reference in New Issue
Block a user