mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Bfs width limit (#1568)
* width limit * fix * large limit * put env vars in env namespace
This commit is contained in:
parent
8c34c9dac4
commit
c1fe1ef081
@ -1,11 +1,11 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::metal {
|
||||
|
||||
@ -13,20 +13,6 @@ bool is_available() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the number of ops after which the current command buffer is ended
// (make_task compares the per-stream op count against this value).
//
// The value is read once, on first call, from the MLX_MAX_OPS_PER_BUFFER
// environment variable and cached for the lifetime of the process. When the
// variable is unset, does not start with an integer, or does not survive
// narrowing to int, the default of 10 is used instead of silently becoming 0
// (which is what atoi would yield for malformed text).
int max_ops_per_buffer() {
  auto get_val = []() {
    constexpr int default_ops = 10;
    const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER");
    if (buff_str == nullptr) {
      return default_ops;
    }
    char* end = nullptr;
    const long parsed = std::strtol(buff_str, &end, 10);
    const int narrowed = static_cast<int>(parsed);
    // end == buff_str: no digits were consumed at all.
    // narrowed != parsed: the value does not fit in an int.
    if (end == buff_str || narrowed != parsed) {
      return default_ops;
    }
    return narrowed;
  };
  // Function-local static: initialised exactly once, thread-safely.
  static int max_ops_per_buffer_ = get_val();
  return max_ops_per_buffer_;
}

#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
||||
|
||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||
std::ostringstream msg;
|
||||
@ -77,7 +63,8 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
out.set_status(array::Status::evaluated);
|
||||
}
|
||||
|
||||
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
||||
if (signal ||
|
||||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
|
||||
d.end_encoding(s.index);
|
||||
if (signal) {
|
||||
command_buffer->encodeSignalEvent(
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <future>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
@ -40,7 +41,7 @@ int detail::InTracing::tracing_counter{0};
|
||||
int detail::RetainGraph::tracing_counter{0};
|
||||
|
||||
array eval_impl(std::vector<array> outputs, bool async) {
|
||||
std::vector<array> tape;
|
||||
std::deque<array> tape;
|
||||
|
||||
// stream events to use for synchronization
|
||||
std::unordered_map<uint32_t, Event> events;
|
||||
@ -72,6 +73,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
while (!dfs.empty()) {
|
||||
auto& [a_ref, idx] = dfs.top();
|
||||
auto& a = a_ref.get();
|
||||
|
||||
if (idx < a.inputs().size()) {
|
||||
// Add an input, and continue
|
||||
auto& in = a.inputs()[idx++];
|
||||
@ -130,14 +132,32 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
dfs.pop();
|
||||
}
|
||||
|
||||
// Build the tape in BFS order
|
||||
// Build the tape in BFS order with a width limit
|
||||
int max_width = env::bfs_max_width();
|
||||
dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
|
||||
tape.push_back(synchronizer);
|
||||
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
|
||||
auto& a = tape[i];
|
||||
for (auto& in : a.inputs()) {
|
||||
for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
|
||||
auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
|
||||
int j = 0;
|
||||
if (i >= tape.size()) {
|
||||
j = dfs.top().second;
|
||||
dfs.pop();
|
||||
} else {
|
||||
i++;
|
||||
}
|
||||
for (; j < a.inputs().size(); ++j) {
|
||||
auto& in = a.inputs()[j];
|
||||
if (in.status() != array::Status::unscheduled) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the width limit is exceeded, push the array on the stack
|
||||
// and go down a level
|
||||
if ((tape.size() - i) >= max_width) {
|
||||
dfs.emplace(a, j);
|
||||
break;
|
||||
}
|
||||
|
||||
auto it = cache.find(in.id());
|
||||
it->second -= 1;
|
||||
|
||||
|
@ -1,9 +1,10 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
@ -336,4 +337,16 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
namespace env {

// Reads the environment variable `name` and returns it as an int.
//
// Returns `default_value` when:
//   - the variable is not set,
//   - its text does not begin with an integer (e.g. "" or "abc"), or
//   - the parsed value does not survive narrowing to int.
// Leading whitespace, an optional sign, and trailing non-digit characters
// are tolerated (strtol parses the longest numeric prefix), so inputs that
// atoi previously accepted such as "10abc" still yield 10.
int get_var(const char* name, int default_value) {
  const char* buff_str = std::getenv(name);
  if (buff_str == nullptr) {
    return default_value;
  }

  char* end = nullptr;
  const long parsed = std::strtol(buff_str, &end, 10);
  if (end == buff_str) {
    // Nothing was converted; do not let malformed text turn into 0.
    return default_value;
  }

  const int narrowed = static_cast<int>(parsed);
  if (narrowed != parsed) {
    // Value is outside the int range.
    return default_value;
  }
  return narrowed;
}

} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
16
mlx/utils.h
16
mlx/utils.h
@ -120,4 +120,20 @@ inline int next_power_of_2(int n) {
|
||||
return pow(2, std::ceil(std::log2(n)));
|
||||
}
|
||||
|
||||
namespace env {

// Returns the value of environment variable `name` converted to int when it
// is set, otherwise `default_value`. Defined out of line.
int get_var(const char* name, int default_value);

// Width limit used when building the evaluation tape in BFS order.
//
// Read once from MLX_BFS_MAX_WIDTH (default 20) and cached for the rest of
// the process. A limit of 0 is raised to 1: the tape builder checks
// "pending entries >= limit" before scheduling an input, and with a limit of
// 0 that check is always true, so it would never make progress.
// Negative values are passed through unchanged.
// NOTE(review): the caller compares against an unsigned size, so a negative
// limit presumably acts as "no limit" - confirm this is intended.
inline int bfs_max_width() {
  static int bfs_max_width_ = [] {
    const int requested = get_var("MLX_BFS_MAX_WIDTH", 20);
    return requested == 0 ? 1 : requested;
  }();
  return bfs_max_width_;
}

// Op-count threshold for a single command buffer.
//
// Read once from MLX_MAX_OPS_PER_BUFFER (default 10) and cached for the rest
// of the process.
inline int max_ops_per_buffer() {
  static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
  return max_ops_per_buffer_;
}

} // namespace env
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user