diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index ed0c082e9..6135a54f7 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -204,11 +204,17 @@ void eval(const std::vector& outputs) { std::unordered_set cache; std::unordered_map> deps; - auto synchronizer = array( - {}, - bool_, - std::make_unique(default_stream(default_device())), - outputs); + // Make an effort to choose a good output stream + Stream stream = default_stream(default_device()); + for (auto& o : outputs) { + if (!o.is_evaled() && o.has_primitive()) { + stream = o.primitive().stream(); + break; + } + } + + auto synchronizer = + array({}, bool_, std::make_unique(stream), outputs); recurse = [&](const array& a) { auto id = a.id();