From 37b440faa89fffc0406c08990f256f3054adfd26 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 16 Aug 2025 09:01:10 +0900 Subject: [PATCH] Clean up code handling both std::vector and SmallVector (#2493) --- mlx/backend/metal/device.h | 20 +++++--------------- mlx/small_vector.h | 12 ++++++++++++ mlx/utils.cpp | 37 ------------------------------------- mlx/utils.h | 17 +++++++++++++---- 4 files changed, 30 insertions(+), 56 deletions(-) diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 00df2ddeba..fefb7cdc0c 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -60,22 +60,12 @@ struct CommandEncoder { enc_->updateFence(fence); } - template - void set_vector_bytes(const SmallVector& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(T), idx); + template >> + void set_vector_bytes(const Vec& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx); } - template - void set_vector_bytes(const SmallVector& vec, int idx) { - return set_vector_bytes(vec, vec.size(), idx); - } - - // TODO: Code is duplicated but they should be deleted soon. - template - void set_vector_bytes(const std::vector& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(T), idx); - } - template - void set_vector_bytes(const std::vector& vec, int idx) { + template >> + void set_vector_bytes(const Vec& vec, int idx) { return set_vector_bytes(vec, vec.size(), idx); } diff --git a/mlx/small_vector.h b/mlx/small_vector.h index 0a3371058a..143101c82f 100644 --- a/mlx/small_vector.h +++ b/mlx/small_vector.h @@ -519,6 +519,18 @@ class SmallVector { std::is_trivially_destructible::value; }; +template +struct is_vector : std::false_type {}; + +template +struct is_vector> : std::true_type {}; + +template +struct is_vector> : std::true_type {}; + +template +inline constexpr bool is_vector_v = is_vector::value; + #undef MLX_HAS_BUILTIN #undef MLX_HAS_ATTRIBUTE #undef MLX_LIKELY diff --git a/mlx/utils.cpp b/mlx/utils.cpp index eac18239ee..2a850d9f99 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) { return os; } -std::ostream& operator<<(std::ostream& os, const SmallVector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const SmallVector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -// TODO: Code is duplicated but they should be deleted soon. -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - namespace env { int get_var(const char* name, int default_value) { diff --git a/mlx/utils.h b/mlx/utils.h index 4513935407..076842f78a 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s); std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, array a); -std::ostream& operator<<(std::ostream& os, const SmallVector& v); -std::ostream& operator<<(std::ostream& os, const SmallVector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; } @@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } +template >> +inline std::ostream& operator<<(std::ostream& os, const Vec& v) { + os << "("; + for (auto it = v.begin(); it != v.end(); ++it) { + os << *it; + if (it != std::prev(v.end())) { + os << ","; + } + } + os << ")"; + return os; +} + inline bool is_power_of_2(int n) { return ((n & (n - 1)) == 0) && n != 0; }