Bfs width limit (#1568)

* width limit

* fix

* large limit

* put env vars in env namespace
This commit is contained in:
Awni Hannun 2024-11-08 15:00:46 -08:00 committed by GitHub
parent 8c34c9dac4
commit c1fe1ef081
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 22 deletions

View File

@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <memory>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal {
@ -13,20 +13,6 @@ bool is_available() {
return true;
}
// Returns the integer value of the MLX_MAX_OPS_PER_BUFFER environment
// variable, or 10 when the variable is not set. The environment is consulted
// on the first call only; every later call returns the cached result.
int max_ops_per_buffer() {
  static int cached_limit = [] {
    const char* raw = std::getenv("MLX_MAX_OPS_PER_BUFFER");
    return raw ? std::atoi(raw) : 10;
  }();
  return cached_limit;
}
// Shorthand for the cached limit: expands to a call to max_ops_per_buffer()
// at every use site.
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@ -77,7 +63,8 @@ std::function<void()> make_task(array arr, bool signal) {
out.set_status(array::Status::evaluated);
}
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (signal ||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(

View File

@ -1,5 +1,6 @@
// Copyright © 2023-2024 Apple Inc.
#include <algorithm>
#include <deque>
#include <future>
#include <numeric>
#include <set>
@ -40,7 +41,7 @@ int detail::InTracing::tracing_counter{0};
int detail::RetainGraph::tracing_counter{0};
array eval_impl(std::vector<array> outputs, bool async) {
std::vector<array> tape;
std::deque<array> tape;
// stream events to use for synchronization
std::unordered_map<uint32_t, Event> events;
@ -72,6 +73,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
while (!dfs.empty()) {
auto& [a_ref, idx] = dfs.top();
auto& a = a_ref.get();
if (idx < a.inputs().size()) {
// Add an input, and continue
auto& in = a.inputs()[idx++];
@ -130,14 +132,32 @@ array eval_impl(std::vector<array> outputs, bool async) {
dfs.pop();
}
// Build the tape in BFS order
// Build the tape in BFS order with a width limit
int max_width = env::bfs_max_width();
dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
tape.push_back(synchronizer);
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
auto& a = tape[i];
for (auto& in : a.inputs()) {
for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
int j = 0;
if (i >= tape.size()) {
j = dfs.top().second;
dfs.pop();
} else {
i++;
}
for (; j < a.inputs().size(); ++j) {
auto& in = a.inputs()[j];
if (in.status() != array::Status::unscheduled) {
continue;
}
// If the width limit is exceeded, push the array on the stack
// and go down a level
if ((tape.size() - i) >= max_width) {
dfs.emplace(a, j);
break;
}
auto it = cache.find(in.id());
it->second -= 1;

View File

@ -1,9 +1,10 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <sstream>
#include <vector>
#include "utils.h"
#include "mlx/utils.h"
namespace mlx::core {
@ -336,4 +337,16 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
return os;
}
namespace env {

// Read an integer setting from the environment.
//
// name:          name of the environment variable to look up.
// default_value: returned when the variable is unset, when its text does not
//                begin with a number, or when the number does not fit in an
//                int.
//
// Leading whitespace and an optional sign are accepted, and any characters
// after the leading number are ignored (so "12abc" yields 12).
int get_var(const char* name, int default_value) {
  const char* buff_str = std::getenv(name);
  if (buff_str == nullptr) {
    return default_value;
  }

  // strtoll reports whether any digits were consumed and saturates on
  // overflow, whereas atoi returns 0 for garbage and has undefined behavior
  // for out-of-range input.
  char* parse_end = nullptr;
  const long long parsed = std::strtoll(buff_str, &parse_end, 10);
  if (parse_end == buff_str) {
    // No digits were consumed: the variable is empty or not numeric.
    return default_value;
  }

  // Round-trip through int: a value outside the int range (including a
  // saturated strtoll result) cannot compare equal after narrowing.
  const int value = static_cast<int>(parsed);
  if (value != parsed) {
    return default_value;
  }
  return value;
}

} // namespace env
} // namespace mlx::core

View File

@ -120,4 +120,20 @@ inline int next_power_of_2(int n) {
return pow(2, std::ceil(std::log2(n)));
}
namespace env {

// Looks up the environment variable `name` and returns it as an int, or
// `default_value` when the variable is not set. Only declared here; the
// definition lives outside this header.
int get_var(const char* name, int default_value);

// Value of MLX_BFS_MAX_WIDTH, or 20 when it is not set. The variable is read
// on the first call only; later calls return the cached result.
inline int bfs_max_width() {
  static const int cached_width = get_var("MLX_BFS_MAX_WIDTH", 20);
  return cached_width;
}

// Value of MLX_MAX_OPS_PER_BUFFER, or 10 when it is not set. The variable is
// read on the first call only; later calls return the cached result.
inline int max_ops_per_buffer() {
  static const int cached_limit = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
  return cached_limit;
}

} // namespace env
} // namespace mlx::core