mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Bfs width limit (#1568)
* width limit * fix * large limit * put env vars in env namespace
This commit is contained in:
parent
8c34c9dac4
commit
c1fe1ef081
@ -1,11 +1,11 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <cstdlib>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core::metal {
|
namespace mlx::core::metal {
|
||||||
|
|
||||||
@ -13,20 +13,6 @@ bool is_available() {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int max_ops_per_buffer() {
|
|
||||||
auto get_val = []() {
|
|
||||||
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
|
|
||||||
return atoi(buff_str);
|
|
||||||
} else {
|
|
||||||
return 10;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
static int max_ops_per_buffer_ = get_val();
|
|
||||||
return max_ops_per_buffer_;
|
|
||||||
}
|
|
||||||
|
|
||||||
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
|
|
||||||
|
|
||||||
inline void check_error(MTL::CommandBuffer* cbuf) {
|
inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||||
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
if (cbuf->status() == MTL::CommandBufferStatusError) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
@ -77,7 +63,8 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
out.set_status(array::Status::evaluated);
|
out.set_status(array::Status::evaluated);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
|
if (signal ||
|
||||||
|
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
|
||||||
d.end_encoding(s.index);
|
d.end_encoding(s.index);
|
||||||
if (signal) {
|
if (signal) {
|
||||||
command_buffer->encodeSignalEvent(
|
command_buffer->encodeSignalEvent(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <deque>
|
||||||
#include <future>
|
#include <future>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
@ -40,7 +41,7 @@ int detail::InTracing::tracing_counter{0};
|
|||||||
int detail::RetainGraph::tracing_counter{0};
|
int detail::RetainGraph::tracing_counter{0};
|
||||||
|
|
||||||
array eval_impl(std::vector<array> outputs, bool async) {
|
array eval_impl(std::vector<array> outputs, bool async) {
|
||||||
std::vector<array> tape;
|
std::deque<array> tape;
|
||||||
|
|
||||||
// stream events to use for synchronization
|
// stream events to use for synchronization
|
||||||
std::unordered_map<uint32_t, Event> events;
|
std::unordered_map<uint32_t, Event> events;
|
||||||
@ -72,6 +73,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
while (!dfs.empty()) {
|
while (!dfs.empty()) {
|
||||||
auto& [a_ref, idx] = dfs.top();
|
auto& [a_ref, idx] = dfs.top();
|
||||||
auto& a = a_ref.get();
|
auto& a = a_ref.get();
|
||||||
|
|
||||||
if (idx < a.inputs().size()) {
|
if (idx < a.inputs().size()) {
|
||||||
// Add an input, and continue
|
// Add an input, and continue
|
||||||
auto& in = a.inputs()[idx++];
|
auto& in = a.inputs()[idx++];
|
||||||
@ -130,14 +132,32 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
dfs.pop();
|
dfs.pop();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the tape in BFS order
|
// Build the tape in BFS order with a width limit
|
||||||
|
int max_width = env::bfs_max_width();
|
||||||
|
dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
|
||||||
tape.push_back(synchronizer);
|
tape.push_back(synchronizer);
|
||||||
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
|
for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
|
||||||
auto& a = tape[i];
|
auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
|
||||||
for (auto& in : a.inputs()) {
|
int j = 0;
|
||||||
|
if (i >= tape.size()) {
|
||||||
|
j = dfs.top().second;
|
||||||
|
dfs.pop();
|
||||||
|
} else {
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
for (; j < a.inputs().size(); ++j) {
|
||||||
|
auto& in = a.inputs()[j];
|
||||||
if (in.status() != array::Status::unscheduled) {
|
if (in.status() != array::Status::unscheduled) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the width limit is exceeded, push the array on the stack
|
||||||
|
// and go down a level
|
||||||
|
if ((tape.size() - i) >= max_width) {
|
||||||
|
dfs.emplace(a, j);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
auto it = cache.find(in.id());
|
auto it = cache.find(in.id());
|
||||||
it->second -= 1;
|
it->second -= 1;
|
||||||
|
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
@ -336,4 +337,16 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& v) {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace env {
|
||||||
|
|
||||||
|
int get_var(const char* name, int default_value) {
|
||||||
|
if (const char* buff_str = std::getenv(name)) {
|
||||||
|
return atoi(buff_str);
|
||||||
|
} else {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
16
mlx/utils.h
16
mlx/utils.h
@ -120,4 +120,20 @@ inline int next_power_of_2(int n) {
|
|||||||
return pow(2, std::ceil(std::log2(n)));
|
return pow(2, std::ceil(std::log2(n)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace env {
|
||||||
|
|
||||||
|
int get_var(const char* name, int default_value);
|
||||||
|
|
||||||
|
inline int bfs_max_width() {
|
||||||
|
static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20);
|
||||||
|
return bfs_max_width_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int max_ops_per_buffer() {
|
||||||
|
static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
|
||||||
|
return max_ops_per_buffer_;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user