Bfs width limit (#1568)

* width limit

* fix

* large limit

* put env vars in env namespace
This commit is contained in:
Awni Hannun
2024-11-08 15:00:46 -08:00
committed by GitHub
parent 8c34c9dac4
commit c1fe1ef081
4 changed files with 58 additions and 22 deletions

View File

@@ -1,11 +1,11 @@
// Copyright © 2023-2024 Apple Inc.
#include <cstdlib>
#include <memory>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core::metal {
@@ -13,20 +13,6 @@ bool is_available() {
return true;
}
int max_ops_per_buffer() {
auto get_val = []() {
if (const char* buff_str = std::getenv("MLX_MAX_OPS_PER_BUFFER")) {
return atoi(buff_str);
} else {
return 10;
}
};
static int max_ops_per_buffer_ = get_val();
return max_ops_per_buffer_;
}
#define MAX_OPS_PER_BUFFER max_ops_per_buffer()
inline void check_error(MTL::CommandBuffer* cbuf) {
if (cbuf->status() == MTL::CommandBufferStatusError) {
std::ostringstream msg;
@@ -77,7 +63,8 @@ std::function<void()> make_task(array arr, bool signal) {
out.set_status(array::Status::evaluated);
}
if (signal || d.get_command_buffer_ops(s.index) >= MAX_OPS_PER_BUFFER) {
if (signal ||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
d.end_encoding(s.index);
if (signal) {
command_buffer->encodeSignalEvent(