From 7b7e2352cdaa97b063889dd87d59b1a6f4a8bed2 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 20 Mar 2025 16:48:43 -0700
Subject: [PATCH] fix malloc or wait deadlock (#1976)

---
 docs/src/dev/extensions.rst                  |  6 ++---
 examples/extensions/axpby/axpby.cpp          |  8 +++---
 mlx/allocator.cpp                            | 26 ++-----------------
 mlx/allocator.h                              |  8 ++----
 mlx/backend/common/binary.h                  | 10 +++----
 mlx/backend/common/common.cpp                |  2 +-
 mlx/backend/common/compiled.cpp              |  4 +--
 mlx/backend/common/copy.h                    |  4 +--
 mlx/backend/common/load.cpp                  |  2 +-
 mlx/backend/common/ternary.h                 |  6 ++---
 mlx/backend/cpu/arg_reduce.cpp               |  2 +-
 mlx/backend/cpu/conv.cpp                     |  8 +++---
 mlx/backend/cpu/distributed.cpp              |  6 ++---
 mlx/backend/cpu/eigh.cpp                     |  7 +++--
 mlx/backend/cpu/fft.cpp                      |  2 +-
 mlx/backend/cpu/indexing.cpp                 |  4 +--
 mlx/backend/cpu/inverse.cpp                  |  4 +--
 mlx/backend/cpu/luf.cpp                      |  7 +++--
 mlx/backend/cpu/masked_mm.cpp                |  4 +--
 mlx/backend/cpu/matmul.cpp                   |  2 +-
 mlx/backend/cpu/primitives.cpp               | 14 +++++-----
 mlx/backend/cpu/qrf.cpp                      | 14 +++++-----
 mlx/backend/cpu/quantized.cpp                | 10 +++----
 mlx/backend/cpu/reduce.cpp                   |  2 +-
 mlx/backend/cpu/scan.cpp                     |  2 +-
 mlx/backend/cpu/softmax.cpp                  |  2 +-
 mlx/backend/cpu/sort.cpp                     |  4 +--
 mlx/backend/cpu/svd.cpp                      | 12 ++++-----
 mlx/backend/cpu/unary.h                      |  4 +--
 mlx/backend/metal/allocator.cpp              | 21 ++++++++-------
 mlx/backend/metal/allocator.h                |  5 ++--
 mlx/backend/metal/conv.cpp                   | 12 ++++-----
 mlx/backend/metal/copy.cpp                   |  2 +-
 mlx/backend/metal/custom_kernel.cpp          |  2 +-
 mlx/backend/metal/fence.cpp                  |  2 +-
 mlx/backend/metal/fft.cpp                    | 11 ++++----
 mlx/backend/metal/hadamard.cpp               |  4 +--
 mlx/backend/metal/indexing.cpp               |  4 +--
 mlx/backend/metal/matmul.cpp                 | 12 ++++-----
 mlx/backend/metal/metal.h                    | 17 +++++++-----
 mlx/backend/metal/normalization.cpp          | 18 ++++++-------
 mlx/backend/metal/primitives.cpp             | 14 +++++-----
 mlx/backend/metal/quantized.cpp              | 10 +++----
 mlx/backend/metal/reduce.cpp                 |  8 +++---
 mlx/backend/metal/rope.cpp                   |  4 +--
 .../metal/scaled_dot_product_attention.cpp   | 15 +++++------
 mlx/backend/metal/scan.cpp                   |  2 +-
 mlx/backend/metal/slicing.cpp                |  2 +-
 mlx/backend/metal/softmax.cpp                |  2 +-
 mlx/backend/metal/sort.cpp                   | 19 +++++++-------
 mlx/backend/metal/unary.cpp                  |  4 +--
 mlx/backend/no_metal/metal.cpp               |  5 +++-
 mlx/transforms.cpp                           |  9 ++++++-
 python/src/metal.cpp                         | 16 +++++-------
 python/tests/test_eval.py                    | 12 +++++++++
 55 files changed, 201 insertions(+), 217 deletions(-)

diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst
index 4272371cb..b8c3a4995 100644
--- a/docs/src/dev/extensions.rst
+++ b/docs/src/dev/extensions.rst
@@ -247,9 +247,7 @@ point-wise. This is captured in the templated function :meth:`axpby_impl`.
     float alpha_,
     float beta_,
     mx::Stream stream) {
-  // Allocate the output with `malloc_or_wait` which synchronously allocates
-  // memory, potentially waiting if the system is under memory pressure
-  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc(out.nbytes()));
 
   // Get the CPU command encoder and register input and output arrays
   auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -393,7 +391,7 @@ below.
   auto& d = metal::device(s.device);
 
   // Allocate output memory
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   // Resolve name of kernel
   std::ostringstream kname;
diff --git a/examples/extensions/axpby/axpby.cpp b/examples/extensions/axpby/axpby.cpp
index d3663e71e..291246617 100644
--- a/examples/extensions/axpby/axpby.cpp
+++ b/examples/extensions/axpby/axpby.cpp
@@ -72,9 +72,7 @@ void axpby_impl(
     float alpha_,
     float beta_,
     mx::Stream stream) {
-  // Allocate the output with `malloc_or_wait` which synchronously allocates
-  // memory, potentially waiting if the system is under memory pressure
-  out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(mx::allocator::malloc(out.nbytes()));
 
   // Get the CPU command encoder and register input and output arrays
   auto& encoder = mx::cpu::get_command_encoder(stream);
@@ -160,12 +158,12 @@ void Axpby::eval_gpu(
   // Allocate output memory with strides based on specialization
   if (contiguous_kernel) {
     out.set_data(
-        mx::allocator::malloc_or_wait(x.data_size() * out.itemsize()),
+        mx::allocator::malloc(x.data_size() * out.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
   } else {
-    out.set_data(mx::allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(mx::allocator::malloc(out.nbytes()));
   }
 
   // Resolve name of kernel (corresponds to axpby.metal)
diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp
index 8d7273a78..2d97a6db3 100644
--- a/mlx/allocator.cpp
+++ b/mlx/allocator.cpp
@@ -9,7 +9,7 @@ namespace mlx::core::allocator {
 
 Buffer malloc(size_t size) {
-  auto buffer = allocator().malloc(size, /* allow_swap */ true);
+  auto buffer = allocator().malloc(size);
   if (size && !buffer.ptr()) {
     std::ostringstream msg;
     msg << "[malloc] Unable to allocate " << size << " bytes.";
@@ -22,7 +22,7 @@ void free(Buffer buffer) {
   allocator().free(buffer);
 }
 
-Buffer CommonAllocator::malloc(size_t size, bool) {
+Buffer CommonAllocator::malloc(size_t size) {
   void* ptr = std::malloc(size + sizeof(size_t));
   if (ptr != nullptr) {
     *static_cast<size_t*>(ptr) = size;
@@ -41,26 +41,4 @@ size_t CommonAllocator::size(Buffer buffer) const {
   return *static_cast<size_t*>(buffer.ptr());
 }
 
-Buffer malloc_or_wait(size_t size) {
-  auto buffer = allocator().malloc(size);
-
-  while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
-    scheduler::wait_for_one();
-    buffer = allocator().malloc(size);
-  }
-
-  // Try swapping if needed
-  if (size && !buffer.ptr()) {
-    buffer = allocator().malloc(size, /* allow_swap = */ true);
-  }
-
-  if (size && !buffer.ptr()) {
-    std::ostringstream msg;
-    msg << "[malloc_or_wait] Unable to allocate " << size << " bytes.";
-    throw std::runtime_error(msg.str());
-  }
-
-  return buffer;
-}
-
 } // namespace mlx::core::allocator
diff --git a/mlx/allocator.h b/mlx/allocator.h
index 2809c7f7f..d4e3e1d6e 100644
--- a/mlx/allocator.h
+++ b/mlx/allocator.h
@@ -32,14 +32,10 @@ Buffer malloc(size_t size);
 
 void free(Buffer buffer);
 
-// Wait for running tasks to finish and free up memory
-// if allocation fails
-Buffer malloc_or_wait(size_t size);
-
 class Allocator {
   /** Abstract base class for a memory allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
+  virtual Buffer malloc(size_t size) = 0;
   virtual void free(Buffer buffer) = 0;
   virtual size_t size(Buffer buffer) const = 0;
 
@@ -56,7 +52,7 @@ Allocator& allocator();
 
 class CommonAllocator : public Allocator {
   /** A general CPU allocator. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
diff --git a/mlx/backend/common/binary.h b/mlx/backend/common/binary.h
index 0c56b71a0..ac6e1891d 100644
--- a/mlx/backend/common/binary.h
+++ b/mlx/backend/common/binary.h
@@ -44,14 +44,14 @@ inline void set_binary_op_output_data(
   switch (bopt) {
     case BinaryOpType::ScalarScalar:
       out.set_data(
-          allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
+          allocator::malloc(out.itemsize()), 1, a.strides(), a.flags());
       break;
     case BinaryOpType::ScalarVector:
       if (b_donatable) {
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(b.data_size() * out.itemsize()),
+            allocator::malloc(b.data_size() * out.itemsize()),
             b.data_size(),
             b.strides(),
             b.flags());
@@ -62,7 +62,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(a);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            allocator::malloc(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());
@@ -75,7 +75,7 @@ inline void set_binary_op_output_data(
         out.copy_shared_buffer(b);
       } else {
         out.set_data(
-            allocator::malloc_or_wait(a.data_size() * out.itemsize()),
+            allocator::malloc(a.data_size() * out.itemsize()),
             a.data_size(),
             a.strides(),
             a.flags());
@@ -88,7 +88,7 @@ inline void set_binary_op_output_data(
           b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
         out.copy_shared_buffer(b);
       } else {
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+        out.set_data(allocator::malloc(out.nbytes()));
       }
       break;
   }
diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp
index 3eb98d09a..57813e062 100644
--- a/mlx/backend/common/common.cpp
+++ b/mlx/backend/common/common.cpp
@@ -103,7 +103,7 @@ void ExpandDims::eval(const std::vector<array>& inputs, array& out) {
 
 void NumberOfElements::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   double numel = 1;
   for (auto ax : axes_) {
diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp
index dfa9e700a..f7b5598ab 100644
--- a/mlx/backend/common/compiled.cpp
+++ b/mlx/backend/common/compiled.cpp
@@ -188,7 +188,7 @@ void compiled_allocate_outputs(
     }
     for (; o < outputs.size(); ++o) {
       outputs[o].set_data(
-          allocator::malloc_or_wait(data_size * outputs[o].itemsize()),
+          allocator::malloc(data_size * outputs[o].itemsize()),
           data_size,
           strides,
           flags);
@@ -211,7 +211,7 @@ void compiled_allocate_outputs(
       }
     }
     for (; o < outputs.size(); ++o) {
-      outputs[o].set_data(allocator::malloc_or_wait(outputs[o].nbytes()));
+      outputs[o].set_data(allocator::malloc(outputs[o].nbytes()));
     }
   }
 }
diff --git a/mlx/backend/common/copy.h b/mlx/backend/common/copy.h
index 5d8f7a58e..0c9f28c94 100644
--- a/mlx/backend/common/copy.h
+++ b/mlx/backend/common/copy.h
@@ -31,14 +31,14 @@ inline bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
       return true;
     } else {
       out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          allocator::malloc(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
           in.flags());
       return false;
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     return false;
   }
 }
diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp
index 3f194f1e2..ce41963de 100644
--- a/mlx/backend/common/load.cpp
+++ b/mlx/backend/common/load.cpp
@@ -28,7 +28,7 @@ void swap_endianness(uint8_t* data_bytes, size_t N) {
 
 namespace mlx::core {
 
 void Load::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto read_task = [out_ptr = out.data<char>(),
                     size = out.size(),
                     itemsize = out.itemsize(),
diff --git a/mlx/backend/common/ternary.h b/mlx/backend/common/ternary.h
index ad6df2fc4..d98dd8d68 100644
--- a/mlx/backend/common/ternary.h
+++ b/mlx/backend/common/ternary.h
@@ -48,12 +48,12 @@ inline void set_ternary_op_output_data(
   switch (topt) {
     case TernaryOpType::ScalarScalarScalar:
       out.set_data(
-          allocator::malloc_or_wait(out.itemsize()), 1, b.strides(), b.flags());
+          allocator::malloc(out.itemsize()), 1, b.strides(), b.flags());
       break;
     case TernaryOpType::VectorVectorVector:
       if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
         out.set_data(
-            allocator::malloc_or_wait(out.itemsize() * b.data_size()),
+            allocator::malloc(out.itemsize() * b.data_size()),
             b.data_size(),
             b.strides(),
             b.flags());
@@ -64,7 +64,7 @@ inline void set_ternary_op_output_data(
       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
             (b.flags().row_contiguous && maybe_donate(b)) ||
             (c.flags().row_contiguous && maybe_donate(c)))) {
-        out.set_data(allocator::malloc_or_wait(out.nbytes()));
+        out.set_data(allocator::malloc(out.nbytes()));
       }
       break;
   }
diff --git a/mlx/backend/cpu/arg_reduce.cpp b/mlx/backend/cpu/arg_reduce.cpp
index c9bdc35b0..a8ba3efe2 100644
--- a/mlx/backend/cpu/arg_reduce.cpp
+++ b/mlx/backend/cpu/arg_reduce.cpp
@@ -68,7 +68,7 @@ void arg_reduce_dispatch(
 void ArgReduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
diff --git a/mlx/backend/cpu/conv.cpp b/mlx/backend/cpu/conv.cpp
index 8259e0597..d52f92f8b 100644
--- a/mlx/backend/cpu/conv.cpp
+++ b/mlx/backend/cpu/conv.cpp
@@ -921,7 +921,7 @@ void explicit_gemm_conv_1D_cpu(
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
     temps.push_back(gemm_out);
   }
 
@@ -1048,7 +1048,7 @@ void explicit_gemm_conv_2D_cpu(
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
     temps.push_back(gemm_out);
   }
 
@@ -1214,7 +1214,7 @@ void explicit_gemm_conv_ND_cpu(
   if (out.dtype() != float32) {
     gemm_out = array(out.shape(), float32, nullptr, {});
-    gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes()));
+    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
     temps.push_back(gemm_out);
   }
 
@@ -1327,7 +1327,7 @@ void conv_3D_cpu(
 } // namespace
 
 void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& in = inputs[0];
   auto& wt = inputs[1];
diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp
index 1d6d9de06..1afa027a8 100644
--- a/mlx/backend/cpu/distributed.cpp
+++ b/mlx/backend/cpu/distributed.cpp
@@ -30,7 +30,7 @@ void AllReduce::eval_cpu(
     if (in.is_donatable()) {
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      out.set_data(allocator::malloc(out.nbytes()));
     }
     return in;
   } else {
@@ -58,7 +58,7 @@ void AllGather::eval_cpu(
   assert(outputs.size() == 1);
 
   auto [in, copied] = ensure_row_contiguous(inputs[0], stream());
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
   distributed::detail::all_gather(group(), in, outputs[0], stream());
   if (copied) {
     auto& enc = cpu::get_command_encoder(stream());
@@ -87,7 +87,7 @@ void Recv::eval_cpu(
   assert(inputs.size() == 0);
   assert(outputs.size() == 1);
 
-  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
   distributed::detail::recv(group(), outputs[0], src_, stream());
 }
diff --git a/mlx/backend/cpu/eigh.cpp b/mlx/backend/cpu/eigh.cpp
index 348a27199..b50f2c722 100644
--- a/mlx/backend/cpu/eigh.cpp
+++ b/mlx/backend/cpu/eigh.cpp
@@ -55,9 +55,8 @@ void eigh_impl(
       liwork = iwork;
     }
 
-    auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
-    auto iwork_buf =
-        array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+    auto work_buf = array::Data{allocator::malloc(sizeof(T) * lwork)};
+    auto iwork_buf = array::Data{allocator::malloc(sizeof(int) * liwork)};
     for (size_t i = 0; i < size / (N * N); ++i) {
       syevd(
           &jobz,
@@ -98,7 +97,7 @@ void Eigh::eval_cpu(
       ? outputs[1]
       : array(a.shape(), a.dtype(), nullptr, {});
 
-  values.set_data(allocator::malloc_or_wait(values.nbytes()));
+  values.set_data(allocator::malloc(values.nbytes()));
 
   copy(
       a,
diff --git a/mlx/backend/cpu/fft.cpp b/mlx/backend/cpu/fft.cpp
index 90227575b..d9e5f8050 100644
--- a/mlx/backend/cpu/fft.cpp
+++ b/mlx/backend/cpu/fft.cpp
@@ -22,7 +22,7 @@ void FFT::eval_cpu(const std::vector<array>& inputs, array& out) {
     s *= out.itemsize();
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   std::vector shape;
   if (out.dtype() == float32) {
diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp
index 6a32dc1d4..70d6b3eb7 100644
--- a/mlx/backend/cpu/indexing.cpp
+++ b/mlx/backend/cpu/indexing.cpp
@@ -197,7 +197,7 @@ void dispatch_gather(
 }
 
 void Gather::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& src = inputs[0];
   std::vector<array> inds;
@@ -354,7 +354,7 @@ void dispatch_gather_axis(
 }
 
 void GatherAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& src = inputs[0];
   auto& inds = inputs[1];
diff --git a/mlx/backend/cpu/inverse.cpp b/mlx/backend/cpu/inverse.cpp
index 9d9b71497..2e79addcb 100644
--- a/mlx/backend/cpu/inverse.cpp
+++ b/mlx/backend/cpu/inverse.cpp
@@ -11,7 +11,7 @@ namespace mlx::core {
 
 template <typename T>
 void general_inv(T* inv, int N) {
   int info;
-  auto ipiv = array::Data{allocator::malloc_or_wait(sizeof(int) * N)};
+  auto ipiv = array::Data{allocator::malloc(sizeof(int) * N)};
 
   // Compute LU factorization.
   getrf(
       /* m = */ &N,
@@ -49,7 +49,7 @@ void general_inv(T* inv, int N) {
   }
 
   const int lwork = workspace_size;
-  auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+  auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
 
   // Compute inverse.
   getri(
diff --git a/mlx/backend/cpu/luf.cpp b/mlx/backend/cpu/luf.cpp
index d85de146c..9ac9361a4 100644
--- a/mlx/backend/cpu/luf.cpp
+++ b/mlx/backend/cpu/luf.cpp
@@ -30,8 +30,7 @@ void luf_impl(
   auto strides = lu.strides();
   strides[ndim - 1] = M;
   strides[ndim - 2] = 1;
-  lu.set_data(
-      allocator::malloc_or_wait(lu.nbytes()), lu.nbytes(), strides, flags);
+  lu.set_data(allocator::malloc(lu.nbytes()), lu.nbytes(), strides, flags);
   copy_inplace(
       a,
       lu,
@@ -44,8 +43,8 @@ void luf_impl(
       stream);
 
   auto a_ptr = lu.data();
-  pivots.set_data(allocator::malloc_or_wait(pivots.nbytes()));
-  row_indices.set_data(allocator::malloc_or_wait(row_indices.nbytes()));
+  pivots.set_data(allocator::malloc(pivots.nbytes()));
+  row_indices.set_data(allocator::malloc(row_indices.nbytes()));
   auto pivots_ptr = pivots.data();
   auto row_indices_ptr = row_indices.data();
   size_t num_matrices = a.size() / (M * N);
diff --git a/mlx/backend/cpu/masked_mm.cpp b/mlx/backend/cpu/masked_mm.cpp
index 75b8a5b3d..0be7c79ce 100644
--- a/mlx/backend/cpu/masked_mm.cpp
+++ b/mlx/backend/cpu/masked_mm.cpp
@@ -59,7 +59,7 @@ void BlockMaskedMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[BlockMaskedMM::eval] Currently only supports float32.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
@@ -318,7 +318,7 @@ void GatherMM::eval_cpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
        "[GatherMM::eval] Currently only supports float32.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
diff --git a/mlx/backend/cpu/matmul.cpp b/mlx/backend/cpu/matmul.cpp
index 082531894..8ae99ab2d 100644
--- a/mlx/backend/cpu/matmul.cpp
+++ b/mlx/backend/cpu/matmul.cpp
@@ -115,7 +115,7 @@ void matmul_general(
 }
 
 void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (inputs[0].shape(-1) == 0) {
     auto& encoder = cpu::get_command_encoder(stream());
     encoder.set_output_array(out);
diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp
index a561f1de8..1dfae8524 100644
--- a/mlx/backend/cpu/primitives.cpp
+++ b/mlx/backend/cpu/primitives.cpp
@@ -21,7 +21,7 @@ namespace mlx::core {
 void reshape(const array& in, array& out) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     copy_inplace(in, out, CopyType::General, out.primitive().stream());
   } else {
     shared_buffer_reshape(in, out_strides, out);
@@ -39,7 +39,7 @@ static std::pair<array, bool> compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
+    offset.set_data(allocator::malloc(offset.itemsize()));
   }
 
   auto& encoder = cpu::get_command_encoder(stream);
@@ -124,7 +124,7 @@ void Transpose::eval_cpu(const std::vector<array>& inputs, array& out) {
 
 void Arange::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   switch (out.dtype()) {
     case bool_:
       throw std::runtime_error("Bool type unsupported for arange.");
@@ -186,7 +186,7 @@ void Concatenate::eval_cpu(const std::vector<array>& inputs, array& out) {
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto strides = out.strides();
   auto flags = out.flags();
@@ -276,7 +276,7 @@ void RandomBits::eval_cpu(const std::vector<array>& inputs, array& out) {
   size_t elems_per_key = out.size() / num_keys;
   size_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto kptr = inputs[0].data<uint32_t>();
   auto cptr = out.data<char>();
@@ -335,7 +335,7 @@ void DynamicSlice::eval_cpu(const std::vector<array>& inputs, array& out) {
     return;
   }
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto [in_offset, donated] =
       compute_dynamic_offset(inputs[1], in.strides(), axes_, stream());
   copy_inplace(
@@ -450,7 +450,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
   } else {
     auto tmp = array(
         in.shape(), in.dtype() == bool_ ? uint8 : in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
     if (in.dtype() == bool_) {
       auto in_tmp = array(in.shape(), uint8, nullptr, {});
       in_tmp.copy_shared_buffer(in);
diff --git a/mlx/backend/cpu/qrf.cpp b/mlx/backend/cpu/qrf.cpp
index 8c6f94140..9e01d188b 100644
--- a/mlx/backend/cpu/qrf.cpp
+++ b/mlx/backend/cpu/qrf.cpp
@@ -25,12 +25,11 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   auto strides = in.strides();
   strides[in.ndim() - 2] = 1;
   strides[in.ndim() - 1] = M;
-  in.set_data(
-      allocator::malloc_or_wait(in.nbytes()), in.nbytes(), strides, flags);
+  in.set_data(allocator::malloc(in.nbytes()), in.nbytes(), strides, flags);
   copy_inplace(a, in, CopyType::GeneralGeneral, stream);
   auto& encoder = cpu::get_command_encoder(stream);
-  q.set_data(allocator::malloc_or_wait(q.nbytes()));
-  r.set_data(allocator::malloc_or_wait(r.nbytes()));
+  q.set_data(allocator::malloc(q.nbytes()));
+  r.set_data(allocator::malloc(r.nbytes()));
 
   auto in_ptr = in.data<T>();
   auto r_ptr = r.data<T>();
@@ -41,8 +40,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
   encoder.set_output_array(r);
   encoder.dispatch([in_ptr, q_ptr, r_ptr, M, N, lda, num_matrices]() {
     int num_reflectors = std::min(M, N);
-    auto tau =
-        allocator::malloc_or_wait(sizeof(T) * num_matrices * num_reflectors);
+    auto tau = allocator::malloc(sizeof(T) * num_matrices * num_reflectors);
 
     T optimal_work;
     int lwork = -1;
@@ -53,7 +51,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
 
     // Update workspace size
     lwork = optimal_work;
-    auto work = allocator::malloc_or_wait(sizeof(T) * lwork);
+    auto work = allocator::malloc(sizeof(T) * lwork);
 
     // Loop over matrices
     for (int i = 0; i < num_matrices; ++i) {
@@ -96,7 +94,7 @@ void qrf_impl(const array& a, array& q, array& r, Stream stream) {
         &lwork,
         &info);
     lwork = optimal_work;
-    work = allocator::malloc_or_wait(sizeof(T) * lwork);
+    work = allocator::malloc(sizeof(T) * lwork);
 
     // Loop over matrices
     for (int i = 0; i < num_matrices; ++i) {
diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp
index 02bddab2f..f0ac9d57f 100644
--- a/mlx/backend/cpu/quantized.cpp
+++ b/mlx/backend/cpu/quantized.cpp
@@ -515,7 +515,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous(scales_pre);
   auto biases = ensure_row_contiguous(biases_pre);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -565,7 +565,7 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto scales = ensure_row_contiguous_last_dims(scales_pre);
   auto biases = ensure_row_contiguous_last_dims(biases_pre);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.add_temporaries(std::move(temps));
@@ -691,12 +691,12 @@ void fast::AffineQuantize::eval_cpu(
   auto [w, copied] = ensure_row_contiguous(inputs[0]);
 
   auto& out = outputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& scales = outputs[1];
   auto& biases = outputs[2];
-  scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-  biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+  scales.set_data(allocator::malloc(scales.nbytes()));
+  biases.set_data(allocator::malloc(biases.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   if (copied) {
     encoder.add_temporary(w);
diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp
index 3f0c3b2ae..ce25feb11 100644
--- a/mlx/backend/cpu/reduce.cpp
+++ b/mlx/backend/cpu/reduce.cpp
@@ -433,7 +433,7 @@ void reduce_dispatch_min_max(
 void Reduce::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
   encoder.set_output_array(out);
diff --git a/mlx/backend/cpu/scan.cpp b/mlx/backend/cpu/scan.cpp
index 205ae414d..1a44ebd39 100644
--- a/mlx/backend/cpu/scan.cpp
+++ b/mlx/backend/cpu/scan.cpp
@@ -244,7 +244,7 @@ void Scan::eval_cpu(const std::vector<array>& inputs, array& out) {
     in = arr_copy;
     encoder.add_temporary(arr_copy);
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   encoder.set_input_array(in);
   encoder.set_output_array(out);
diff --git a/mlx/backend/cpu/softmax.cpp b/mlx/backend/cpu/softmax.cpp
index 591bad020..78e4a3e68 100644
--- a/mlx/backend/cpu/softmax.cpp
+++ b/mlx/backend/cpu/softmax.cpp
@@ -129,7 +129,7 @@ void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
       out.copy_shared_buffer(x);
     } else {
       out.set_data(
-          allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+          allocator::malloc(x.data_size() * x.itemsize()),
           x.data_size(),
           x.strides(),
           x.flags());
diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp
index 4439df61b..b00e301b8 100644
--- a/mlx/backend/cpu/sort.cpp
+++ b/mlx/backend/cpu/sort.cpp
@@ -288,7 +288,7 @@ void ArgSort::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
 
   // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
@@ -379,7 +379,7 @@ void ArgPartition::eval_cpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
 
   // Allocate output
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& encoder = cpu::get_command_encoder(stream());
   encoder.set_input_array(in);
diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp
index 86745e545..24d93f8e5 100644
--- a/mlx/backend/cpu/svd.cpp
+++ b/mlx/backend/cpu/svd.cpp
@@ -50,9 +50,9 @@ void svd_impl(
     array& s = outputs[1];
     array& vt = outputs[2];
 
-    u.set_data(allocator::malloc_or_wait(u.nbytes()));
-    s.set_data(allocator::malloc_or_wait(s.nbytes()));
-    vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+    u.set_data(allocator::malloc(u.nbytes()));
+    s.set_data(allocator::malloc(s.nbytes()));
+    vt.set_data(allocator::malloc(vt.nbytes()));
 
     encoder.set_output_array(u);
     encoder.set_output_array(s);
@@ -64,7 +64,7 @@ void svd_impl(
   } else {
     array& s = outputs[0];
 
-    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+    s.set_data(allocator::malloc(s.nbytes()));
 
     encoder.set_output_array(s);
 
@@ -91,7 +91,7 @@ void svd_impl(
 
     // Will contain the indices of eigenvectors that failed to converge (not
     // used here but required by lapack).
-    auto iwork = array::Data{allocator::malloc_or_wait(sizeof(int) * 12 * K)};
+    auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)};
 
     static const int lwork_query = -1;
 
@@ -132,7 +132,7 @@ void svd_impl(
     }
 
     const int lwork = workspace_dimension;
-    auto scratch = array::Data{allocator::malloc_or_wait(sizeof(T) * lwork)};
+    auto scratch = array::Data{allocator::malloc(sizeof(T) * lwork)};
 
     // Loop over matrices.
     for (int i = 0; i < num_matrices; i++) {
diff --git a/mlx/backend/cpu/unary.h b/mlx/backend/cpu/unary.h
index a12bbbf00..fa539541c 100644
--- a/mlx/backend/cpu/unary.h
+++ b/mlx/backend/cpu/unary.h
@@ -18,13 +18,13 @@ void set_unary_output_data(const array& in, array& out) {
     } else {
       auto size = in.data_size();
       out.set_data(
-          allocator::malloc_or_wait(size * out.itemsize()),
+          allocator::malloc(size * out.itemsize()),
           size,
           in.strides(),
           in.flags());
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 }
diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp
index 8f5b28226..d7b84a165 100644
--- a/mlx/backend/metal/allocator.cpp
+++ b/mlx/backend/metal/allocator.cpp
@@ -192,16 +192,19 @@ size_t MetalAllocator::set_cache_limit(size_t limit) {
   return limit;
 };
 
-size_t MetalAllocator::set_memory_limit(size_t limit, bool relaxed) {
+size_t MetalAllocator::set_memory_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, block_limit_);
-  relaxed_ = relaxed;
   gc_limit_ = std::min(
       block_limit_,
       static_cast<size_t>(0.95 * device_->recommendedMaxWorkingSetSize()));
   return limit;
 };
 
+size_t MetalAllocator::get_memory_limit() {
+  return block_limit_;
+}
+
 size_t MetalAllocator::set_wired_limit(size_t limit) {
   std::unique_lock lk(mutex_);
   std::swap(limit, wired_limit_);
@@ -209,7 +212,7 @@ size_t MetalAllocator::set_wired_limit(size_t limit) {
   return limit;
 };
 
-Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
+Buffer MetalAllocator::malloc(size_t size) {
   // Metal doesn't like empty buffers
   if (size == 0) {
     return Buffer{nullptr};
@@ -236,11 +239,6 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
   if (!buf) {
     size_t mem_required = get_active_memory() + get_cache_memory() + size;
 
-    // If there is too much memory pressure, fail (likely causes a wait).
-    if (!(allow_swap && relaxed_) && mem_required >= block_limit_) {
-      return Buffer{nullptr};
-    }
-
     auto pool = metal::new_scoped_memory_pool();
 
     // If we have a lot of memory pressure or are over the maximum cache size,
@@ -328,8 +326,11 @@ MetalAllocator& allocator() {
 size_t set_cache_limit(size_t limit) {
   return allocator().set_cache_limit(limit);
 }
-size_t set_memory_limit(size_t limit, bool relaxed /* = true */) {
-  return allocator().set_memory_limit(limit, relaxed);
+size_t set_memory_limit(size_t limit) {
+  return allocator().set_memory_limit(limit);
+}
+size_t get_memory_limit() {
+  return allocator().get_memory_limit();
 }
 size_t set_wired_limit(size_t limit) {
   if (limit >
diff --git a/mlx/backend/metal/allocator.h b/mlx/backend/metal/allocator.h
index df301f55e..8b77ff6c1 100644
--- a/mlx/backend/metal/allocator.h
+++ b/mlx/backend/metal/allocator.h
@@ -56,7 +56,7 @@ class BufferCache {
 class MetalAllocator : public allocator::Allocator {
   /** Allocator for Metal GPUs. */
  public:
-  virtual Buffer malloc(size_t size, bool allow_swap = false) override;
+  virtual Buffer malloc(size_t size) override;
   virtual void free(Buffer buffer) override;
   virtual size_t size(Buffer buffer) const override;
   size_t get_active_memory() {
@@ -73,7 +73,8 @@ class MetalAllocator : public allocator::Allocator {
     return buffer_cache_.cache_size();
   };
   size_t set_cache_limit(size_t limit);
-  size_t set_memory_limit(size_t limit, bool relaxed);
+  size_t set_memory_limit(size_t limit);
+  size_t get_memory_limit();
   size_t set_wired_limit(size_t limit);
   void clear_cache();
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 3e42f7d2f..c4803a380 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -37,7 +37,7 @@ void explicit_gemm_conv_ND_gpu(
   Shape unfolded_shape{implicit_M, implicit_K};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
 
-  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::ostringstream kname;
@@ -115,7 +115,7 @@ void explicit_gemm_conv_group_ND_gpu(
   // Prepare unfolding array
   Shape unfolded_shape{implicit_M, implicit_K * groups};
   array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
-  in_unfolded.set_data(allocator::malloc_or_wait(in_unfolded.nbytes()));
+  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::ostringstream kname;
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
   // Do filter transform
   Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
   array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
-  filt_wg.set_data(allocator::malloc_or_wait(filt_wg.nbytes()));
+  filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
   copies_w.push_back(filt_wg);
   {
     int bc = 32;
@@ -640,7 +640,7 @@ void winograd_conv_2D_gpu(
   // Do input transform
   Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
   array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
-  inp_wg.set_data(allocator::malloc_or_wait(inp_wg.nbytes()));
+  inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
   copies_w.push_back(inp_wg);
   {
     int bc = 32;
@@ -667,7 +667,7 @@ void winograd_conv_2D_gpu(
   // Do batched gemm
   Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
   array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
-  out_wg.set_data(allocator::malloc_or_wait(out_wg.nbytes()));
+  out_wg.set_data(allocator::malloc(out_wg.nbytes()));
   copies_w.push_back(out_wg);
   {
     std::vector<array> empty_copies;
@@ -855,7 +855,7 @@ void conv_3D_gpu(
 } // namespace
 
 void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp
index 1750a6908..3399201de 100644
--- a/mlx/backend/metal/copy.cpp
+++ b/mlx/backend/metal/copy.cpp
@@ -202,7 +202,7 @@ void fill_gpu(const array& val, array& out, const Stream& s) {
   if (out.size() == 0) {
     return;
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   bool large = out.data_size() > UINT32_MAX;
   auto& d = metal::device(s.device);
   std::string kernel_name = std::string(large ? "s2" : "s") + "_copy" +
diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp
index e775f6798..8a672289a 100644
--- a/mlx/backend/metal/custom_kernel.cpp
+++ b/mlx/backend/metal/custom_kernel.cpp
@@ -19,7 +19,7 @@ void CustomKernel::eval_gpu(
       copies.emplace_back(init_value_.value(), out.dtype());
       fill_gpu(copies.back(), out, s);
     } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      out.set_data(allocator::malloc(out.nbytes()));
     }
   }
diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp
index f5502231a..e784d34ae 100644
--- a/mlx/backend/metal/fence.cpp
+++ b/mlx/backend/metal/fence.cpp
@@ -20,7 +20,7 @@ struct FenceImpl {
       auto p = metal::new_scoped_memory_pool();
       fence = static_cast(d->newSharedEvent());
     } else {
-      auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
+      auto buf = allocator::malloc(sizeof(uint32_t)).ptr();
       fence = static_cast(buf);
       cpu_value()[0] = 0;
     }
diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index 52131b4a8..95678279e 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -281,7 +281,7 @@ std::tuple<array, array, array> compute_raders_constants(
   }
 
   array b_q_fft({rader_n - 1}, complex64, nullptr, {});
-  b_q_fft.set_data(allocator::malloc_or_wait(b_q_fft.nbytes()));
+  b_q_fft.set_data(allocator::malloc(b_q_fft.nbytes()));
   auto b_q_fft_ptr =
       reinterpret_cast<std::complex<float>*>(b_q_fft.data<complex64_t>());
   std::ptrdiff_t item_size = b_q_fft.itemsize();
@@ -327,11 +327,11 @@ std::pair<array, array> compute_bluestein_constants(int n, int bluestein_n) {
   }
 
   array w_k({n}, complex64, nullptr, {});
-  w_k.set_data(allocator::malloc_or_wait(w_k.nbytes()));
+  w_k.set_data(allocator::malloc(w_k.nbytes()));
   std::copy(w_k_vec.begin(), w_k_vec.end(), w_k.data<complex64_t>());
 
   array w_q({bluestein_n}, complex64, nullptr, {});
-  w_q.set_data(allocator::malloc_or_wait(w_q.nbytes()));
+  w_q.set_data(allocator::malloc(w_q.nbytes()));
   auto w_q_ptr =
       reinterpret_cast<std::complex<float>*>(w_q.data<complex64_t>());
@@ -551,8 +551,7 @@ void fft_op(
     flags.row_contiguous = is_row_contiguous;
     flags.contiguous = data_size == x_copy.size();
 
-    x_copy.set_data(
-        allocator::malloc_or_wait(x.nbytes()), data_size, strides, flags);
+    x_copy.set_data(allocator::malloc(x.nbytes()), data_size, strides, flags);
     copy_gpu_inplace(x, x_copy, CopyType::GeneralGeneral, s);
     copies.push_back(x_copy);
     return x_copy;
@@ -583,7 +582,7 @@ void fft_op(
   // TODO: allow donation here
   if (!inplace) {
     out.set_data(
-        allocator::malloc_or_wait(out.nbytes()),
+        allocator::malloc(out.nbytes()),
         out_data_size,
         out_strides,
         in_contiguous.flags());
diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp
index 7b711e28b..a7dfc5f17 100644
--- a/mlx/backend/metal/hadamard.cpp
+++ b/mlx/backend/metal/hadamard.cpp
@@ -84,7 +84,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
   if (in_contiguous.is_donatable()) {
     out.copy_shared_buffer(in_contiguous);
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
 
   int n, m;
@@ -161,7 +161,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Upload 2:
     // y = h12 @ tmp
     array temp(in.shape(), in.dtype(), nullptr, {});
-    temp.set_data(allocator::malloc_or_wait(temp.nbytes()));
+    temp.set_data(allocator::malloc(temp.nbytes()));
     copies.push_back(temp);
 
     launch_hadamard(in_contiguous, temp, "n" + kernel_name, 1.0);
diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp
index 13696842c..d2a263051 100644
--- a/mlx/backend/metal/indexing.cpp
+++ b/mlx/backend/metal/indexing.cpp
@@ -43,7 +43,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(msg.str());
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -393,7 +393,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& src = inputs[0];
   auto& idx = inputs[1];
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp
index 4d3cf21ee..3f736505f 100644
--- a/mlx/backend/metal/matmul.cpp
+++ b/mlx/backend/metal/matmul.cpp
@@ -382,7 +382,7 @@ void steel_matmul(
     int split_k_partition_size = gemm_k_iterations * bk;
 
     array C_split({split_k_partitions, M, N}, float32, nullptr, {});
-    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    C_split.set_data(allocator::malloc(C_split.nbytes()));
     copies.push_back(C_split);
 
     bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -513,7 +513,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -677,7 +677,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     throw std::runtime_error(
         "[matmul] Does not yet support non-floating point types.");
   }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -860,7 +860,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     int split_k_partition_size = gemm_k_iterations * bk;
 
     array C_split({split_k_partitions, M, N}, float32, nullptr, {});
-    C_split.set_data(allocator::malloc_or_wait(C_split.nbytes()));
+    C_split.set_data(allocator::malloc(C_split.nbytes()));
     copies.push_back(C_split);
 
     bool mn_aligned = M % bm == 0 && N % bn == 0;
@@ -1096,7 +1096,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -1484,7 +1484,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
     return;
   }
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h
index d5c64f79d..82151c538 100644
--- a/mlx/backend/metal/metal.h
+++ b/mlx/backend/metal/metal.h
@@ -38,17 +38,20 @@ void reset_peak_memory();
 size_t get_cache_memory();
 
 /* Set the memory limit.
- * Calls to malloc will wait on scheduled tasks if the limit is exceeded. If
- * there are no more scheduled tasks an error will be raised if relaxed
- * is false or memory will be allocated (including the potential for
- * swap) if relaxed is true.
+ * The memory limit is a guideline for the maximum amount of memory to use
+ * during graph evaluation. If the memory limit is exceeded and there is no
+ * more RAM (including swap when available), allocations will result in an
+ * exception.
  *
- * The memory limit defaults to 1.5 times the maximum recommended working set
- * size reported by the device.
+ * When Metal is available, the memory limit defaults to 1.5 times the
+ * maximum recommended working set size reported by the device.
 *
 * Returns the previous memory limit.
 * */
-size_t set_memory_limit(size_t limit, bool relaxed = true);
+size_t set_memory_limit(size_t limit);
+
+/* Get the current memory limit. */
+size_t get_memory_limit();
 
 /* Set the free cache limit.
  * If using more than the given limit, free memory will be reclaimed
diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp
index b9a82b876..c1d993d2a 100644
--- a/mlx/backend/metal/normalization.cpp
+++ b/mlx/backend/metal/normalization.cpp
@@ -29,7 +29,7 @@ void RMSNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
@@ -129,7 +129,7 @@ void RMSNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
+    gx.set_data(allocator::malloc(gx.nbytes()));
   }
   if (g_copied && !g_in_gx) {
     d.add_temporary(g, s.index);
@@ -146,11 +146,11 @@ void RMSNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
       d.add_temporary(gw_temp, s.index);
     }
   }
-  gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
+  gw.set_data(allocator::malloc(gw.nbytes()));
 
   const int simd_size = 32;
   const int n_reads = RMS_N_READS;
@@ -226,7 +226,7 @@ void LayerNorm::eval_gpu(
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
        x.data_size(),
        x.strides(),
        x.flags());
@@ -331,7 +331,7 @@ void LayerNormVJP::eval_gpu(
     gx.copy_shared_buffer(g);
     g_in_gx = true;
   } else {
-    gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
+    gx.set_data(allocator::malloc(gx.nbytes()));
   }
   if (g_copied && !g_in_gx) {
     d.add_temporary(g, s.index);
@@ -348,12 +348,12 @@ void LayerNormVJP::eval_gpu(
     if (!g_in_gx && donate_g) {
       gw_temp.copy_shared_buffer(g);
     } else {
-      gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
+      gw_temp.set_data(allocator::malloc(gw_temp.nbytes()));
       d.add_temporary(gw_temp, s.index);
     }
   }
-  gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
-  gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
+  gw.set_data(allocator::malloc(gw.nbytes()));
+  gb.set_data(allocator::malloc(gb.nbytes()));
 
   // Finish with the gradient for b in case we had a b
   auto& compute_encoder = d.get_command_encoder(s.index);
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 76f7e3946..67576f03f 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -28,7 +28,7 @@ void arange_set_scalars(T start, T next, metal::CommandEncoder& enc) {
 void reshape(const array& in, array& out, Stream s) {
   auto [copy_necessary, out_strides] = prepare_reshape(in, out);
   if (copy_necessary) {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     copy_gpu_inplace(
         in,
         out,
@@ -58,7 +58,7 @@ static array compute_dynamic_offset(
   if (donate) {
     offset.copy_shared_buffer(indices);
   } else {
-    offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
+    offset.set_data(allocator::malloc(offset.itemsize()));
   }
   d.add_temporary(offset, s.index);
@@ -100,7 +100,7 @@ static array compute_dynamic_offset(
 
 void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 0);
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -161,7 +161,7 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out)
 void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   auto& in = inputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto& s = stream();
   auto& d = metal::device(s.device);
   std::string op_name;
@@ -333,7 +333,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
   size_t elems_per_key = out.size() / num_keys;
   size_t bytes_per_key = out.itemsize() * elems_per_key;
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   if (out.size() == 0) {
     return;
   }
@@ -397,7 +397,7 @@ void DynamicSlice::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   auto& start = inputs[1];
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
   auto s = stream();
   auto in_offset = compute_dynamic_offset(start, in.strides(), axes_, s);
   copy_gpu_inplace(
@@ -554,7 +554,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
         in, strides, in.flags(), in.data_size() * ibytes / obytes);
   } else {
     auto tmp = array(in.shape(), in.dtype(), nullptr, {});
-    tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
+    tmp.set_data(allocator::malloc(tmp.nbytes()));
     copy_gpu_inplace(in, tmp, CopyType::General, stream());
 
     auto flags = out.flags();
diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp
index d6fddd058..cc32797eb 100644
--- a/mlx/backend/metal/quantized.cpp
+++ b/mlx/backend/metal/quantized.cpp
@@ -224,7 +224,7 @@ void qvm_split_k(
   auto temp_shape = out.shape();
   temp_shape.insert(temp_shape.end() - 2, split_k);
   array intermediate(temp_shape, x.dtype(), nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   std::ostringstream kname;
@@ -277,7 +277,7 @@ void qmm_op(
     int bits,
     bool gather,
     const Stream& s) {
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   MTL::Size group_dims;
   MTL::Size grid_dims;
@@ -394,7 +394,7 @@ void fast::AffineQuantize::eval_gpu(
     std::vector<array>& outputs) {
   auto& w_pre = inputs[0];
   auto& out = outputs[0];
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -425,8 +425,8 @@ void fast::AffineQuantize::eval_gpu(
   } else {
     auto& scales = outputs[1];
     auto& biases = outputs[2];
-    scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
-    biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
+    scales.set_data(allocator::malloc(scales.nbytes()));
+    biases.set_data(allocator::malloc(biases.nbytes()));
     compute_encoder.set_output_array(out, 1);
     compute_encoder.set_output_array(scales, 2);
     compute_encoder.set_output_array(biases, 3);
diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp
index 36e1266e6..c5650bdd7 100644
--- a/mlx/backend/metal/reduce.cpp
+++ b/mlx/backend/metal/reduce.cpp
@@ -347,7 +347,7 @@ void all_reduce_dispatch(
   // Allocate an intermediate tensor to hold results if needed
   array intermediate({n_rows}, out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   // 1st pass
@@ -641,7 +641,7 @@ void strided_reduce_longcolumn(
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
   array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   // Prepare the arguments for the kernel
@@ -812,7 +812,7 @@ void strided_reduce_2pass(
   intermediate_shape.insert(
       intermediate_shape.end(), out.shape().begin(), out.shape().end());
   array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
   d.add_temporary(intermediate, s.index);
 
   // Prepare the arguments for the kernel
@@ -950,7 +950,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Minimum of 4 bytes since we use size 4 structs for all reduce
   // and metal will complain o/w
   size_t min_bytes = std::max(out.nbytes(), 4ul);
-  out.set_data(allocator::malloc_or_wait(min_bytes));
+  out.set_data(allocator::malloc(min_bytes));
   std::string op_name;
   switch (reduce_type_) {
     case Reduce::And:
diff --git a/mlx/backend/metal/rope.cpp b/mlx/backend/metal/rope.cpp
index ce3384c54..060758333 100644
--- a/mlx/backend/metal/rope.cpp
+++ b/mlx/backend/metal/rope.cpp
@@ -43,14 +43,14 @@ void RoPE::eval_gpu(
       donated = true;
       out.copy_shared_buffer(in);
     } else {
-      out.set_data(allocator::malloc_or_wait(out.nbytes()));
+      out.set_data(allocator::malloc(out.nbytes()));
     }
     strides[0] = mat_size;
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
   } else if (dispatch_ndim == 3) {
     // Handle non-contiguous 3D inputs
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
     strides[0] = in.strides()[ndim - 3];
     strides[1] = in.strides()[ndim - 2];
     strides[2] = in.strides()[ndim - 1];
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index c5b544852..f64d057ce 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -248,9 +248,9 @@ void sdpa_vector_2pass(
   intermediate_shape.pop_back();
   array sums(intermediate_shape, float32, nullptr, {});
   array maxs(std::move(intermediate_shape), float32, nullptr, {});
-  intermediate.set_data(allocator::malloc_or_wait(intermediate.nbytes()));
-  sums.set_data(allocator::malloc_or_wait(sums.nbytes()));
-  maxs.set_data(allocator::malloc_or_wait(maxs.nbytes()));
+  intermediate.set_data(allocator::malloc(intermediate.nbytes()));
+  sums.set_data(allocator::malloc(sums.nbytes()));
+  maxs.set_data(allocator::malloc(maxs.nbytes()));
   d.add_temporary(intermediate, s.index);
   d.add_temporary(sums, s.index);
   d.add_temporary(maxs, s.index);
@@ -383,7 +383,7 @@ void ScaledDotProductAttention::eval_gpu(
       o.copy_shared_buffer(q);
     } else {
       if (o.shape(2) == 1) {
-        o.set_data(allocator::malloc_or_wait(o.nbytes()));
+        o.set_data(allocator::malloc(o.nbytes()));
       } else {
         auto strides = o.strides();
         strides[2] = o.shape(1) * o.shape(3);
@@ -391,10 +391,7 @@ void ScaledDotProductAttention::eval_gpu(
         auto flags = q.flags();
         flags.row_contiguous = q.shape(1) == 1;
         o.set_data(
-            allocator::malloc_or_wait(o.nbytes()),
-            o.size(),
-            std::move(strides),
-            flags);
+            allocator::malloc(o.nbytes()), o.size(), std::move(strides), flags);
       }
     }
 
@@ -432,7 +429,7 @@ void ScaledDotProductAttention::eval_gpu(
     };
 
     o.set_data(
-        allocator::malloc_or_wait(o.nbytes()),
+        allocator::malloc(o.nbytes()),
         data_size,
         {str_oB, str_oH, str_oL, str_oD},
         flags);
diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp
index dd123662f..c7e0087b7 100644
--- a/mlx/backend/metal/scan.cpp
+++ b/mlx/backend/metal/scan.cpp
@@ -24,7 +24,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(in);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+        allocator::malloc(in.data_size() * out.itemsize()),
         in.data_size(),
         in.strides(),
         in.flags());
diff --git a/mlx/backend/metal/slicing.cpp b/mlx/backend/metal/slicing.cpp
index d34e1a747..6ab08a108 100644
--- a/mlx/backend/metal/slicing.cpp
+++ b/mlx/backend/metal/slicing.cpp
@@ -29,7 +29,7 @@ void concatenate_gpu(
   }
   std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin());
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto strides = out.strides();
   auto flags = out.flags();
diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp
index 854f31f4b..b089188b8 100644
--- a/mlx/backend/metal/softmax.cpp
+++ b/mlx/backend/metal/softmax.cpp
@@ -33,7 +33,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
     out.copy_shared_buffer(x);
   } else {
     out.set_data(
-        allocator::malloc_or_wait(x.data_size() * x.itemsize()),
+        allocator::malloc(x.data_size() * x.itemsize()),
         x.data_size(),
         x.strides(),
         x.flags());
diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp
index ea16fbf89..543dfd180 100644
--- a/mlx/backend/metal/sort.cpp
+++ b/mlx/backend/metal/sort.cpp
@@ -150,12 +150,11 @@ void multi_block_sort(
   array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
 
   // Do allocations
-  dev_vals_0.set_data(allocator::malloc_or_wait(dev_vals_0.nbytes()));
-  dev_vals_1.set_data(allocator::malloc_or_wait(dev_vals_1.nbytes()));
-  dev_idxs_0.set_data(allocator::malloc_or_wait(dev_idxs_0.nbytes()));
-  dev_idxs_1.set_data(allocator::malloc_or_wait(dev_idxs_1.nbytes()));
-  block_partitions.set_data(
-      allocator::malloc_or_wait(block_partitions.nbytes()));
+  dev_vals_0.set_data(allocator::malloc(dev_vals_0.nbytes()));
+  dev_vals_1.set_data(allocator::malloc(dev_vals_1.nbytes()));
+  dev_idxs_0.set_data(allocator::malloc(dev_idxs_0.nbytes()));
+  dev_idxs_1.set_data(allocator::malloc(dev_idxs_1.nbytes()));
+  block_partitions.set_data(allocator::malloc(block_partitions.nbytes()));
 
   std::vector<array> copies = {
       dev_vals_0, dev_vals_1, dev_idxs_0, dev_idxs_1, block_partitions};
@@ -319,7 +318,7 @@ void gpu_merge_sort(
 void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -331,7 +330,7 @@ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
 void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -344,7 +343,7 @@ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct arg partition to sort for now
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -357,7 +356,7 @@ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   // We direct partition to sort for now
   assert(inputs.size() == 1);
 
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  out.set_data(allocator::malloc(out.nbytes()));
 
   auto& s = stream();
   auto& d = metal::device(s.device);
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 741c7d70a..be43c41c2 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -97,13 +97,13 @@ void unary_op_gpu(
       out.copy_shared_buffer(in);
     } else {
       out.set_data(
-          allocator::malloc_or_wait(in.data_size() * out.itemsize()),
+          allocator::malloc(in.data_size() * out.itemsize()),
           in.data_size(),
           in.strides(),
           in.flags());
     }
   } else {
-    out.set_data(allocator::malloc_or_wait(out.nbytes()));
+    out.set_data(allocator::malloc(out.nbytes()));
   }
   unary_op_gpu_inplace(inputs, out, op, s);
 }
diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp
index 359e63b5d..03c68c734 100644
--- a/mlx/backend/no_metal/metal.cpp
+++ b/mlx/backend/no_metal/metal.cpp
@@ -42,7 +42,10 @@ void reset_peak_memory() {}
 size_t get_cache_memory() {
   return 0;
 }
-size_t set_memory_limit(size_t, bool) {
+size_t set_memory_limit(size_t) {
+  return 0;
+}
+size_t get_memory_limit() {
   return 0;
 }
 size_t set_cache_limit(size_t) {
diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp
index 958899bec..105a0fa28 100644
--- a/mlx/transforms.cpp
+++ b/mlx/transforms.cpp
@@ -218,7 +218,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
       cpu::eval(arr);
     }
 
-    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS) {
+    if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
+        (metal::get_active_memory() > metal::get_memory_limit() &&
+         scheduler::n_active_tasks() > 0)) {
       // Commit any open streams
       for (auto& [_, e] : events) {
         if (e.stream().device == Device::gpu) {
@@ -226,6 +228,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
         }
       }
       scheduler::wait_for_one();
+      // TODO memory api should be moved out of metal
+      while (metal::get_active_memory() > metal::get_memory_limit() &&
+             scheduler::n_active_tasks() > 0) {
+        scheduler::wait_for_one();
+      }
     }
 
     auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
diff --git a/python/src/metal.cpp b/python/src/metal.cpp
index 5ef5691fb..fef856dd9 100644
--- a/python/src/metal.cpp
+++ b/python/src/metal.cpp
@@ -57,23 +57,19 @@ void init_metal(nb::module_& m) {
       "set_memory_limit",
       &mx::metal::set_memory_limit,
       "limit"_a,
-      nb::kw_only(),
-      "relaxed"_a = true,
       R"pbdoc(
         Set the memory limit.
 
-        Memory allocations will wait on scheduled tasks to complete if the limit
-        is exceeded. If there are no more scheduled tasks an error will be raised
-        if ``relaxed`` is ``False``. Otherwise memory will be allocated
-        (including the potential for swap) if ``relaxed`` is ``True``.
+        The memory limit is a guideline for the maximum amount of memory to use
+        during graph evaluation. If the memory limit is exceeded and there is no
+        more RAM (including swap when available), allocations will result in an
+        exception.
 
-        The memory limit defaults to 1.5 times the maximum recommended working set
-        size reported by the device.
+        When Metal is available, the memory limit defaults to 1.5 times the
+        maximum recommended working set size reported by the device.
 
         Args:
           limit (int): Memory limit in bytes.
-          relaxed (bool, optional): If ``False`` an error is raised if the limit
-            is exceeded. Default: ``True``
 
         Returns:
           int: The previous memory limit in bytes.
diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py
index 510402b06..ebcf64c7a 100644
--- a/python/tests/test_eval.py
+++ b/python/tests/test_eval.py
@@ -185,6 +185,18 @@ class TestEval(mlx_tests.MLXTestCase):
             x = mx.abs(x, stream=s2)
         mx.eval(x)
 
+        s1 = mx.default_stream(mx.gpu)
+        s2 = mx.new_stream(mx.gpu)
+        old_limit = mx.metal.set_memory_limit(1000)
+
+        x = mx.ones((512, 512), stream=s2)
+        for _ in range(80):
+            x = mx.abs(x, stream=s1)
+            y = mx.abs(x, stream=s2)
+            z = mx.abs(y, stream=s2)
+        mx.eval(z)
+        mx.metal.set_memory_limit(old_limit)
+
 
 if __name__ == "__main__":
     unittest.main()
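
For reference, a minimal sketch of how the API behaves after this patch, written from the Python side that the bindings and test above exercise. It assumes only what the patch itself defines (`mx.metal.set_memory_limit` with the `relaxed` keyword removed); the array shapes and the 1 GiB figure are illustrative values, not part of the patch.

```python
import mlx.core as mx

# The limit is now a guideline rather than a hard gate inside malloc:
# when active memory exceeds it, eval_impl waits on in-flight tasks
# (see the loop added in mlx/transforms.cpp), and malloc only raises
# if the system truly cannot provide the memory. This is what removes
# the malloc-or-wait deadlock between streams.
old_limit = mx.metal.set_memory_limit(1 << 30)  # 1 GiB, illustrative

x = mx.ones((4096, 4096))
for _ in range(8):
    x = mx.abs(x)
mx.eval(x)  # throttled via task waits, not by blocking in the allocator

mx.metal.set_memory_limit(old_limit)  # restore the previous limit
```

The test added to `python/tests/test_eval.py` uses the same pattern across two GPU streams with a deliberately tiny limit (1000 bytes), which deadlocked under the old `malloc_or_wait` behavior.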