From 4650d94d98a615462a27c6c099f0d2bda8cb4e86 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 21 Mar 2024 22:15:48 +0900 Subject: [PATCH] Add missing && in eval (#864) Without the && args would be copied and perfect forwarding won't work. To avoid eval calling itself recursively, the vector version of eval is changed to take by value instead, which will save a copy of array when a rvalue is passed. --- mlx/transforms.cpp | 6 +++--- mlx/transforms.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 88581c52d..7a7a134fd 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -37,7 +37,7 @@ class Synchronizer : public Primitive { // are currently under a function transformation. int detail::InTracing::tracing_counter{0}; -void eval(const std::vector& outputs) { +void eval(std::vector outputs) { std::function recurse; std::queue tape; std::unordered_set cache; @@ -52,8 +52,8 @@ void eval(const std::vector& outputs) { } } - auto synchronizer = - array({}, bool_, std::make_unique(stream), outputs); + auto synchronizer = array( + {}, bool_, std::make_unique(stream), std::move(outputs)); size_t depth_counter = 0; recurse = [&](const array& a) { diff --git a/mlx/transforms.h b/mlx/transforms.h index 95fdfcffe..c3d84287d 100644 --- a/mlx/transforms.h +++ b/mlx/transforms.h @@ -6,10 +6,10 @@ namespace mlx::core { -void eval(const std::vector& outputs); +void eval(std::vector outputs); template -void eval(Arrays... outputs) { +void eval(Arrays&&... outputs) { eval(std::vector{std::forward(outputs)...}); }