Add missing && in eval (#864)

Without the && args would be copied and perfect forwarding won't work.

To avoid eval calling itself recursively, the vector version of eval is
changed to take by value instead, which will save a copy of array when a
rvalue is passed.
This commit is contained in:
Cheng 2024-03-21 22:15:48 +09:00 committed by GitHub
parent a5681ebc52
commit 4650d94d98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 5 deletions

View File

@ -37,7 +37,7 @@ class Synchronizer : public Primitive {
// are currently under a function transformation.
int detail::InTracing::tracing_counter{0};
void eval(const std::vector<array>& outputs) {
void eval(std::vector<array> outputs) {
std::function<void(const array&)> recurse;
std::queue<array> tape;
std::unordered_set<std::uintptr_t> cache;
@ -52,8 +52,8 @@ void eval(const std::vector<array>& outputs) {
}
}
auto synchronizer =
array({}, bool_, std::make_unique<Synchronizer>(stream), outputs);
auto synchronizer = array(
{}, bool_, std::make_unique<Synchronizer>(stream), std::move(outputs));
size_t depth_counter = 0;
recurse = [&](const array& a) {

View File

@ -6,10 +6,10 @@
namespace mlx::core {
void eval(const std::vector<array>& outputs);
void eval(std::vector<array> outputs);
template <typename... Arrays>
void eval(Arrays... outputs) {
void eval(Arrays&&... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
}