diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 00df2ddeb..fefb7cdc0 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -60,22 +60,12 @@ struct CommandEncoder { enc_->updateFence(fence); } - template - void set_vector_bytes(const SmallVector& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(T), idx); + template >> + void set_vector_bytes(const Vec& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx); } - template - void set_vector_bytes(const SmallVector& vec, int idx) { - return set_vector_bytes(vec, vec.size(), idx); - } - - // TODO: Code is duplicated but they should be deleted soon. - template - void set_vector_bytes(const std::vector& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(T), idx); - } - template - void set_vector_bytes(const std::vector& vec, int idx) { + template >> + void set_vector_bytes(const Vec& vec, int idx) { return set_vector_bytes(vec, vec.size(), idx); } diff --git a/mlx/small_vector.h b/mlx/small_vector.h index 0a3371058..143101c82 100644 --- a/mlx/small_vector.h +++ b/mlx/small_vector.h @@ -519,6 +519,18 @@ class SmallVector { std::is_trivially_destructible::value; }; +template +struct is_vector : std::false_type {}; + +template +struct is_vector> : std::true_type {}; + +template +struct is_vector> : std::true_type {}; + +template +inline constexpr bool is_vector_v = is_vector::value; + #undef MLX_HAS_BUILTIN #undef MLX_HAS_ATTRIBUTE #undef MLX_LIKELY diff --git a/mlx/utils.cpp b/mlx/utils.cpp index eac18239e..2a850d9f9 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) { return os; } -std::ostream& operator<<(std::ostream& os, const SmallVector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const SmallVector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -// TODO: Code is duplicated but they should be deleted soon. -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - namespace env { int get_var(const char* name, int default_value) { diff --git a/mlx/utils.h b/mlx/utils.h index 451393540..076842f78 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s); std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, array a); -std::ostream& operator<<(std::ostream& os, const SmallVector& v); -std::ostream& operator<<(std::ostream& os, const SmallVector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; } @@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } +template >> +inline std::ostream& operator<<(std::ostream& os, const Vec& v) { + os << "("; + for (auto it = v.begin(); it != v.end(); ++it) { + os << *it; + if (it != std::prev(v.end())) { + os << ","; + } + } + os << ")"; + return os; +} + inline bool is_power_of_2(int n) { return ((n & (n - 1)) == 0) && n != 0; }