mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-29 13:01:14 +08:00
Add missing && when forwarding args (#894)
Without the && args would be copied and perfect forwarding won't work. Also add template utils to make sure the function only forwards array and not vector<array>.
This commit is contained in:
parent
8e686764ac
commit
28fcd2b519
11
mlx/array.h
11
mlx/array.h
@ -510,4 +510,15 @@ void array::init(It src) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Utilities for determining whether a template parameter is array. */
|
||||||
|
template <typename T>
|
||||||
|
inline constexpr bool is_array_v =
|
||||||
|
std::is_same_v<std::remove_cv_t<std::remove_reference_t<T>>, array>;
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -89,9 +89,8 @@ collapse_contiguous_dims(const std::vector<array>& xs) {
|
|||||||
return collapse_contiguous_dims(xs[0].shape(), strides);
|
return collapse_contiguous_dims(xs[0].shape(), strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
inline std::tuple<std::vector<int>, std::vector<std::vector<size_t>>>
|
inline auto collapse_contiguous_dims(Arrays&&... xs) {
|
||||||
collapse_contiguous_dims(Arrays... xs) {
|
|
||||||
return collapse_contiguous_dims(
|
return collapse_contiguous_dims(
|
||||||
std::vector<array>{std::forward<Arrays>(xs)...});
|
std::vector<array>{std::forward<Arrays>(xs)...});
|
||||||
}
|
}
|
||||||
|
@ -14,15 +14,15 @@ struct NodeNamer {
|
|||||||
|
|
||||||
void print_graph(std::ostream& os, const std::vector<array>& outputs);
|
void print_graph(std::ostream& os, const std::vector<array>& outputs);
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
void print_graph(std::ostream& os, Arrays... outputs) {
|
void print_graph(std::ostream& os, Arrays&&... outputs) {
|
||||||
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
print_graph(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
|
||||||
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
|
void export_to_dot(std::ostream& os, const std::vector<array>& outputs);
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
void export_to_dot(std::ostream& os, Arrays... outputs) {
|
void export_to_dot(std::ostream& os, Arrays&&... outputs) {
|
||||||
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
export_to_dot(os, std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ namespace mlx::core {
|
|||||||
|
|
||||||
void eval(std::vector<array> outputs);
|
void eval(std::vector<array> outputs);
|
||||||
|
|
||||||
template <typename... Arrays>
|
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
|
||||||
void eval(Arrays&&... outputs) {
|
void eval(Arrays&&... outputs) {
|
||||||
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user