From fd94be28ead745a1c80656a409b73d5b1e56bd2a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 13 Jan 2024 13:34:27 -0800 Subject: [PATCH] fix test + choose stream with slight care --- mlx/transforms.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index ed0c082e9..6135a54f7 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -204,11 +204,17 @@ void eval(const std::vector& outputs) { std::unordered_set cache; std::unordered_map> deps; - auto synchronizer = array( - {}, - bool_, - std::make_unique(default_stream(default_device())), - outputs); + // Make an effort to choose a good output stream + Stream stream = default_stream(default_device()); + for (auto& o : outputs) { + if (!o.is_evaled() && o.has_primitive()) { + stream = o.primitive().stream(); + break; + } + } + + auto synchronizer = + array({}, bool_, std::make_unique(stream), outputs); recurse = [&](const array& a) { auto id = a.id();