mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Use SmallVector for shapes and strides
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "python/src/buffer.h"
|
||||
#include "python/src/convert.h"
|
||||
#include "python/src/indexing.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/export.h"
|
||||
#include "mlx/graph_utils.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
@@ -8,10 +8,10 @@
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/fast.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "python/src/small_vector.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/linalg.h"
|
||||
#include "python/src/small_vector.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
@@ -7,8 +7,10 @@
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/memory.h"
|
||||
#include "python/src/small_vector.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/load.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
|
||||
17
python/src/small_vector.h
Normal file
17
python/src/small_vector.h
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/small_vector.h"
|
||||
|
||||
#include <nanobind/stl/detail/nb_list.h>
|
||||
|
||||
NAMESPACE_BEGIN(NB_NAMESPACE)
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <typename Type, size_t Size, typename Alloc>
|
||||
struct type_caster<mlx::core::SmallVector<Type, Size, Alloc>>
|
||||
: list_caster<mlx::core::SmallVector<Type, Size, Alloc>, Type> {};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
NAMESPACE_END(NB_NAMESPACE)
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "mlx/transforms_impl.h"
|
||||
#include "mlx/utils.h"
|
||||
#include "python/src/mlx_func.h"
|
||||
#include "python/src/small_vector.h"
|
||||
#include "python/src/trees.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
|
||||
Reference in New Issue
Block a user