Add missing && in eval (#864)

Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
This commit is contained in:
Cheng 2024-03-21 22:15:48 +09:00 committed by GitHub
parent a5681ebc52
commit 4650d94d98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 5 deletions

View File

@ -37,7 +37,7 @@ class Synchronizer : public Primitive {
// are currently under a function transformation. // are currently under a function transformation.
int detail::InTracing::tracing_counter{0}; int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) { void eval(std::vector<array> outputs) {
std::function<void(const array&)> recurse; std::function<void(const array&)> recurse;
std::queue<array> tape; std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache; std::unordered_set<std::uintptr_t> cache;
@ -52,8 +52,8 @@ void eval(const std::vector<array>& outputs) {
} }
} }
auto synchronizer = auto synchronizer = array(
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs); {}, bool_, std::make_unique<Synchronizer>(stream), std::move(outputs));
size_t depth_counter = 0; size_t depth_counter = 0;
recurse = [&](const array& a) { recurse = [&](const array& a) {

View File

@ -6,10 +6,10 @@
namespace mlx::core { namespace mlx::core {
void eval(const std::vector<array>& outputs); void eval(std::vector<array> outputs);
template <typename... Arrays> template <typename... Arrays>
void eval(Arrays... outputs) { void eval(Arrays&&... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...}); eval(std::vector<array>{std::forward<Arrays>(outputs)...});
} }