Use SmallVector for shapes and strides

This commit is contained in:
Cheng
2025-08-01 21:38:10 +09:00
parent be9bc96da4
commit 68e9c60d22
30 changed files with 677 additions and 101 deletions

View File

@@ -15,6 +15,7 @@
#include "python/src/buffer.h"
#include "python/src/convert.h"
#include "python/src/indexing.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
#include "mlx/mlx.h"

View File

@@ -9,7 +9,7 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;

View File

@@ -10,6 +10,7 @@
#include "mlx/array.h"
#include "mlx/export.h"
#include "mlx/graph_utils.h"
#include "python/src/small_vector.h"
#include "python/src/trees.h"
namespace mx = mlx::core;

View File

@@ -8,10 +8,10 @@
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "python/src/utils.h"
#include "mlx/fast.h"
#include "mlx/ops.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;

View File

@@ -8,6 +8,7 @@
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "python/src/small_vector.h"
namespace mx = mlx::core;
namespace nb = nanobind;

View File

@@ -9,6 +9,7 @@
#include <nanobind/stl/vector.h>
#include "mlx/linalg.h"
#include "python/src/small_vector.h"
namespace mx = mlx::core;
namespace nb = nanobind;

View File

@@ -12,6 +12,7 @@
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;

View File

@@ -7,8 +7,10 @@
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/backend/metal/metal.h"
#include "mlx/memory.h"
#include "python/src/small_vector.h"
namespace mx = mlx::core;
namespace nb = nanobind;

View File

@@ -16,6 +16,7 @@
#include "mlx/ops.h"
#include "mlx/utils.h"
#include "python/src/load.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;

View File

@@ -7,10 +7,10 @@
#include <chrono>
#include "python/src/utils.h"
#include "mlx/ops.h"
#include "mlx/random.h"
#include "python/src/small_vector.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;

17
python/src/small_vector.h Normal file
View File

@@ -0,0 +1,17 @@
// Copyright © 2025 Apple Inc.
#pragma once
#include "mlx/small_vector.h"
#include <nanobind/stl/detail/nb_list.h>
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)
template <typename Type, size_t Size, typename Alloc>
struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>>
: list_caster<mlx::core::SmallVector<Type, Size, Alloc>, Type> {};
NAMESPACE_END(detail)
NAMESPACE_END(NB_NAMESPACE)

View File

@@ -20,6 +20,7 @@
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include "python/src/mlx_func.h"
#include "python/src/small_vector.h"
#include "python/src/trees.h"
namespace mx = mlx::core;