From c1fe1ef081fbcaad9bb2cc82a612da6b7752a224 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 8 Nov 2024 15:00:46 -0800 Subject: [PATCH] Bfs width limit (#1568) * width limit * fix * large limit * put env vars in env namespace --- mlx/backend/metal/metal.cpp | 19 +++---------------- mlx/transforms.cpp | 30 +++++++++++++++++++++++++----- mlx/utils.cpp | 15 ++++++++++++++- mlx/utils.h | 16 ++++++++++++++++ 4 files changed, 58 insertions(+), 22 deletions(-) diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 2a5e6334e..4b662bb36 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -1,11 +1,11 @@ // Copyright © 2023-2024 Apple Inc. -#include #include #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" +#include "mlx/utils.h" namespace mlx::core::metal { @@ -13,20 +13,6 @@ bool is_available() { return true; } -int max_ops_per_buffer() { - auto get_val = []() { - if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) { - return atoi(buff_str); - } else { - return 10; - } - }; - static int max_ops_per_buffer_ = get_val(); - return max_ops_per_buffer_; -} - -#define MAX_OPS_PER_BUFFER max_ops_per_buffer() - inline void check_error(MTL::CommandBuffer* cbuf) { if (cbuf->status() == MTL::CommandBufferStatusError) { std::ostringstream msg; @@ -77,7 +63,8 @@ std::function make_task(array arr, bool signal) { out.set_status(array::Status::evaluated); } - if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) { + if (signal || + d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) { d.end_encoding(s.index); if (signal) { command_buffer->encodeSignalEvent( diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 1d5127389..7a24b2f94 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -1,5 +1,6 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include #include #include @@ -40,7 +41,7 @@ int detail::InTracing::tracing_counter{0}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { - std::vector tape; + std::deque tape; // stream events to use for synchronization std::unordered_map events; @@ -72,6 +73,7 @@ array eval_impl(std::vector outputs, bool async) { while (!dfs.empty()) { auto& [a_ref, idx] = dfs.top(); auto& a = a_ref.get(); + if (idx < a.inputs().size()) { // Add an input, and continue auto& in = a.inputs()[idx++]; @@ -130,14 +132,32 @@ array eval_impl(std::vector outputs, bool async) { dfs.pop(); } - // Build the tape in BFS order + // Build the tape in BFS order with a width limit + int max_width = env::bfs_max_width(); + dfs = std::stack, int>>(); tape.push_back(synchronizer); - for (int i = 0; !cache.empty() && i < tape.size(); ++i) { - auto& a = tape[i]; - for (auto& in : a.inputs()) { + for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) { + auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i]; + int j = 0; + if (i >= tape.size()) { + j = dfs.top().second; + dfs.pop(); + } else { + i++; + } + for (; j < a.inputs().size(); ++j) { + auto& in = a.inputs()[j]; if (in.status() != array::Status::unscheduled) { continue; } + + // If the width limit is exceeded, push the array on the stack + // and go down a level + if ((tape.size() - i) >= max_width) { + dfs.emplace(a, j); + break; + } + auto it = cache.find(in.id()); it->second -= 1; diff --git a/mlx/utils.cpp b/mlx/utils.cpp index e3c2c72bd..847432370 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,9 +1,10 @@ // Copyright © 2023 Apple Inc. +#include #include #include -#include "utils.h" +#include "mlx/utils.h" namespace mlx::core { @@ -336,4 +337,16 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } +namespace env { + +int get_var(const char* name, int default_value) { + if (const char* buff_str = std::getenv(name)) { + return atoi(buff_str); + } else { + return default_value; + } +} + +} // namespace env + } // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h index e536da55f..e5d1ad9ae 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -120,4 +120,20 @@ inline int next_power_of_2(int n) { return pow(2, std::ceil(std::log2(n))); } +namespace env { + +int get_var(const char* name, int default_value); + +inline int bfs_max_width() { + static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20); + return bfs_max_width_; +} + +inline int max_ops_per_buffer() { + static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10); + return max_ops_per_buffer_; +} + +} // namespace env + } // namespace mlx::core