From 28fcd2b519f0fabcd681f2c33e14d71983cad819 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 26 Mar 2024 06:55:54 +0900 Subject: [PATCH] Add missing && when forwarding args (#894) Without the && args would be copied and perfect forwarding won't work. Also add template utils to make sure the function only forwards array and not vector. --- mlx/array.h | 11 +++++++++++ mlx/backend/common/utils.h | 5 ++--- mlx/graph_utils.h | 8 ++++---- mlx/transforms.h | 2 +- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/mlx/array.h b/mlx/array.h index 9b7cf0f8f..266353196 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -510,4 +510,15 @@ void array::init(It src) { } } +/* Utilities for determining whether a template parameter is array. */ +template +inline constexpr bool is_array_v = + std::is_same_v>, array>; + +template +inline constexpr bool is_arrays_v = (is_array_v && ...); + +template +using enable_for_arrays_t = typename std::enable_if_t>; + } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 24ded411b..a1a1207c7 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -89,9 +89,8 @@ collapse_contiguous_dims(const std::vector& xs) { return collapse_contiguous_dims(xs[0].shape(), strides); } -template -inline std::tuple, std::vector>> -collapse_contiguous_dims(Arrays... xs) { +template > +inline auto collapse_contiguous_dims(Arrays&&... xs) { return collapse_contiguous_dims( std::vector{std::forward(xs)...}); } diff --git a/mlx/graph_utils.h b/mlx/graph_utils.h index 5e024704e..d70228dad 100644 --- a/mlx/graph_utils.h +++ b/mlx/graph_utils.h @@ -14,15 +14,15 @@ struct NodeNamer { void print_graph(std::ostream& os, const std::vector& outputs); -template -void print_graph(std::ostream& os, Arrays... outputs) { +template > +void print_graph(std::ostream& os, Arrays&&... outputs) { print_graph(os, std::vector{std::forward(outputs)...}); } void export_to_dot(std::ostream& os, const std::vector& outputs); -template -void export_to_dot(std::ostream& os, Arrays... outputs) { +template > +void export_to_dot(std::ostream& os, Arrays&&... outputs) { export_to_dot(os, std::vector{std::forward(outputs)...}); } diff --git a/mlx/transforms.h b/mlx/transforms.h index c3d84287d..a83765100 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -8,7 +8,7 @@ namespace mlx::core { void eval(std::vector outputs); -template +template > void eval(Arrays&&... outputs) { eval(std::vector{std::forward(outputs)...}); }