mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 01:46:37 +08:00
Clean up code handling both std::vector and SmallVector (#2493)
This commit is contained in:
parent
888b13ed63
commit
37b440faa8
@ -60,22 +60,12 @@ struct CommandEncoder {
|
||||
enc_->updateFence(fence);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void set_vector_bytes(const SmallVector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
|
||||
void set_vector_bytes(const Vec& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx);
|
||||
}
|
||||
template <typename T>
|
||||
void set_vector_bytes(const SmallVector<T>& vec, int idx) {
|
||||
return set_vector_bytes(vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
// TODO: Code is duplicated but they should be deleted soon.
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, size_t nelems, int idx) {
|
||||
enc_->setBytes(vec.data(), nelems * sizeof(T), idx);
|
||||
}
|
||||
template <typename T>
|
||||
void set_vector_bytes(const std::vector<T>& vec, int idx) {
|
||||
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
|
||||
void set_vector_bytes(const Vec& vec, int idx) {
|
||||
return set_vector_bytes(vec, vec.size(), idx);
|
||||
}
|
||||
|
||||
|
@ -519,6 +519,18 @@ class SmallVector {
|
||||
std::is_trivially_destructible<T>::value;
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct is_vector : std::false_type {};
|
||||
|
||||
template <typename T, size_t Size, typename Allocator>
|
||||
struct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};
|
||||
|
||||
template <typename T, typename Allocator>
|
||||
struct is_vector<std::vector<T, Allocator>> : std::true_type {};
|
||||
|
||||
template <typename Vec>
|
||||
inline constexpr bool is_vector_v = is_vector<Vec>::value;
|
||||
|
||||
#undef MLX_HAS_BUILTIN
|
||||
#undef MLX_HAS_ATTRIBUTE
|
||||
#undef MLX_LIKELY
|
||||
|
@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
// TODO: Code is duplicated but they should be deleted soon.
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
|
||||
os << "(";
|
||||
for (int i = 0; i < v.size(); ++i) {
|
||||
os << v[i] << ((i == v.size() - 1) ? "" : ",");
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
namespace env {
|
||||
|
||||
int get_var(const char* name, int default_value) {
|
||||
|
17
mlx/utils.h
17
mlx/utils.h
@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
||||
std::ostream& operator<<(std::ostream& os, array a);
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const SmallVector<int64_t>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v);
|
||||
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
||||
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
|
||||
}
|
||||
@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||
return os << static_cast<float>(v);
|
||||
}
|
||||
|
||||
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
|
||||
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
|
||||
os << "(";
|
||||
for (auto it = v.begin(); it != v.end(); ++it) {
|
||||
os << *it;
|
||||
if (it != std::prev(v.end())) {
|
||||
os << ",";
|
||||
}
|
||||
}
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
inline bool is_power_of_2(int n) {
|
||||
return ((n & (n - 1)) == 0) && n != 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user