diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index dfa21aa0a2..00df2ddeba 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -104,7 +104,7 @@ struct CommandEncoder { }; // Outputs of all kernels in the encoder including temporaries - std::unordered_set outputs() { + std::unordered_set& outputs() { return all_outputs_; };