diff --git a/docs/src/usage/compile.rst b/docs/src/usage/compile.rst index 97d5503a3..091505fe4 100644 --- a/docs/src/usage/compile.rst +++ b/docs/src/usage/compile.rst @@ -33,12 +33,12 @@ Let's start with a simple example: # Compile the function compiled_fun = mx.compile(fun) - # Prints: array(2.36788, dtype=float32) + # Prints: array(2.36788, dtype=float32) print(compiled_fun(x, y)) The output of both the regular function and the compiled function is the same up to numerical precision. - + The first time you call a compiled function, MLX will build the compute graph, optimize it, and generate and compile code. This can be relatively slow. However, MLX will cache compiled functions, so calling a compiled @@ -96,7 +96,7 @@ element-wise operations: .. code-block:: python - def gelu(x): + def gelu(x): return x * (1 + mx.erf(x / math.sqrt(2))) / 2 If you use this function with small arrays, it will be overhead bound. If you @@ -136,13 +136,6 @@ Now make an array, and benchmark both functions: On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is five times faster. -.. note:: - - As of the latest MLX, CPU functions are not fully compiled. Compiling CPU - functions can still be helpful, but won't typically result in as large a - speedup as compiling operations that run on the GPU. - - Debugging --------- @@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence, print(fun(mx.array(1.0))) -Compiling Training Graphs +Compiling Training Graphs ------------------------- This section will step through how to use :func:`compile` with a simple example @@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`. To start, here is the simple example without any compilation: -.. code-block:: python +.. code-block:: python import mlx.core as mx import mlx.nn as nn @@ -330,7 +323,7 @@ To start, here is the simple example without any compilation: To compile the update we can put it all in a function and compile it with the appropriate input and output captures. Here's the same example but compiled: -.. code-block:: python +.. code-block:: python import mlx.core as mx import mlx.nn as nn @@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled: # The state that will be captured as input and output state = [model.state, optimizer.state] - + @partial(mx.compile, inputs=state, outputs=state) def step(x, y): loss_and_grad_fn = nn.value_and_grad(model, loss_fn) @@ -410,7 +403,7 @@ Compiling transformed functions works just as expected: In order to compile as much as possible, a transformation of a compiled function will not by default be compiled. To compile the transformed - function simply pass it through :func:`compile`. + function simply pass it through :func:`compile`. You can also compile functions which themselves call compiled functions. A good practice is to compile the outer most function to give :func:`compile` diff --git a/docs/src/usage/function_transforms.rst b/docs/src/usage/function_transforms.rst index 77e58058a..9a15bbf1c 100644 --- a/docs/src/usage/function_transforms.rst +++ b/docs/src/usage/function_transforms.rst @@ -25,7 +25,7 @@ Here is a simple example: The output of :func:`grad` on :func:`sin` is simply another function. In this case it is the gradient of the sine function which is exactly the cosine -function. To get the second derivative you can do: +function. To get the second derivative you can do: .. code-block:: shell @@ -50,7 +50,7 @@ Automatic Differentiation .. 
_auto diff: Automatic differentiation in MLX works on functions rather than on implicit -graphs. +graphs. .. note:: @@ -114,7 +114,7 @@ way to do that is the following: def loss_fn(params, x, y): w, b = params["weight"], params["bias"] - h = w * x + b + h = w * x + b return mx.mean(mx.square(h - y)) params = {"weight": mx.array(1.0), "bias": mx.array(0.0)} @@ -132,7 +132,7 @@ way to do that is the following: Notice the tree structure of the parameters is preserved in the gradients. -In some cases you may want to stop gradients from propagating through a +In some cases you may want to stop gradients from propagating through a part of the function. You can use the :func:`stop_gradient` for that. @@ -166,14 +166,14 @@ A naive way to add the elements from two sets of vectors is with a loop: Instead you can use :func:`vmap` to automatically vectorize the addition: .. code-block:: python - + # Vectorize over the second dimension of x and the # first dimension of y vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0)) The ``in_axes`` parameter can be used to specify which dimensions of the corresponding input to vectorize over. Similarly, use ``out_axes`` to specify -where the vectorized axes should be in the outputs. +where the vectorized axes should be in the outputs. Let's time these two different versions: diff --git a/docs/src/usage/indexing.rst b/docs/src/usage/indexing.rst index 458541923..62994a0fb 100644 --- a/docs/src/usage/indexing.rst +++ b/docs/src/usage/indexing.rst @@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`: .. code-block:: shell >>> arr = mx.arange(10) - >>> idx = mx.array([5, 7]) + >>> idx = mx.array([5, 7]) >>> arr[idx] array([5, 7], dtype=int32) @@ -82,7 +82,7 @@ general, MLX has limited support for operations for which outputs operations which MLX does not yet support include :func:`numpy.nonzero` and the single input version of :func:`numpy.where`. -In Place Updates +In Place Updates ---------------- In place updates to indexed arrays are possible in MLX. For example: diff --git a/docs/src/usage/lazy_evaluation.rst b/docs/src/usage/lazy_evaluation.rst index bd64f919c..466edaaed 100644 --- a/docs/src/usage/lazy_evaluation.rst +++ b/docs/src/usage/lazy_evaluation.rst @@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an :func:`eval` is performed. MLX uses lazy evaluation because it has some nice features, some of which we -describe below. +describe below. Transforming Compute Graphs ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -116,7 +116,7 @@ saving functions) will also evaluate the array. Calling :func:`array.item` on a scalar array will also evaluate it. In the example above, printing the loss (``print(loss)``) or adding the loss scalar to -a list (``losses.append(loss.item())``) would cause a graph evaluation. If +a list (``losses.append(loss.item())``) would cause a graph evaluation. If these lines are before ``mx.eval(loss, model.parameters())`` then this will be a partial evaluation, computing only the forward pass. diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index 6edb94b8b..c589f1887 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -3,10 +3,10 @@ Conversion to NumPy and Other Frameworks ======================================== -MLX array supports conversion between other frameworks with either: +MLX array supports conversion between other frameworks with either: -* The `Python Buffer Protocol `_. -* `DLPack `_. +* The `Python Buffer Protocol `_. +* `DLPack `_. 
Let's convert an array to NumPy and back. @@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed. PyTorch ------- -.. warning:: +.. warning:: PyTorch Support for :obj:`memoryview` is experimental and can break for multi-dimensional arrays. Casting to NumPy first is advised for now. diff --git a/docs/src/usage/quick_start.rst b/docs/src/usage/quick_start.rst index 251f5344c..bc1d92ad6 100644 --- a/docs/src/usage/quick_start.rst +++ b/docs/src/usage/quick_start.rst @@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products and :func:`jvp` for Jacobian-vector products. Use :func:`value_and_grad` to efficiently compute both a function's output and -gradient with respect to the function's input. +gradient with respect to the function's input. diff --git a/docs/src/usage/saving_and_loading.rst b/docs/src/usage/saving_and_loading.rst index c142bc776..43f2a7999 100644 --- a/docs/src/usage/saving_and_loading.rst +++ b/docs/src/usage/saving_and_loading.rst @@ -8,33 +8,33 @@ Saving and Loading Arrays MLX supports multiple array serialization formats. .. list-table:: Serialization Formats - :widths: 20 8 25 25 + :widths: 20 8 25 25 :header-rows: 1 - * - Format - - Extension + * - Format + - Extension - Function - - Notes - * - NumPy - - ``.npy`` + - Notes + * - NumPy + - ``.npy`` - :func:`save` - Single arrays only - * - NumPy archive - - ``.npz`` + * - NumPy archive + - ``.npz`` - :func:`savez` and :func:`savez_compressed` - - Multiple arrays + - Multiple arrays * - Safetensors - - ``.safetensors`` + - ``.safetensors`` - :func:`save_safetensors` - - Multiple arrays - * - GGUF - - ``.gguf`` + - Multiple arrays + * - GGUF + - ``.gguf`` - :func:`save_gguf` - Multiple arrays The :func:`load` function will load any of the supported serialization formats. It determines the format from the extensions. The output of -:func:`load` depends on the format. +:func:`load` depends on the format. Here's an example of saving a single array to a file: diff --git a/docs/src/usage/unified_memory.rst b/docs/src/usage/unified_memory.rst index a53477d65..57f54bc23 100644 --- a/docs/src/usage/unified_memory.rst +++ b/docs/src/usage/unified_memory.rst @@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory. In MLX, rather than moving arrays to devices, you specify the device when you run the operation. Any device can perform any operation on ``a`` and ``b`` -without needing to move them from one memory location to another. For example: +without needing to move them from one memory location to another. For example: .. 
code-block:: python diff --git a/mlx/array.cpp b/mlx/array.cpp index e7040eb5f..374c2d36f 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -180,6 +180,7 @@ void array::move_shared_buffer( auto char_offset = sizeof(char) * itemsize() * offset; array_desc_->data_ptr = static_cast( static_cast(other.array_desc_->data_ptr) + char_offset); + other.array_desc_->data_ptr = nullptr; } void array::move_shared_buffer(array other) { diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 3c9f04dfe..8c1f80291 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -205,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) { // Allocate new buffer if needed size_t res_opt = MTL::ResourceStorageModeShared; - res_opt |= MTL::ResourceHazardTrackingModeTracked; + res_opt |= MTL::ResourceHazardTrackingModeUntracked; lk.unlock(); buf = device_->newBuffer(size, res_opt); lk.lock(); diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index a7a62644d..1cc8a2a76 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -918,14 +918,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out) { "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions."); } - // Clear copies - if (!copies.empty()) { - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + // Record copies + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 68898e914..fd002af31 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -77,12 +77,7 @@ void CustomKernel::eval_gpu( MTL::Size grid_dims = MTL::Size(gx, gy, gz); compute_encoder->dispatchThreads(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 3a565e902..cc0694ca8 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -20,7 +20,6 @@ namespace { // TODO nicer way to set this or possibly expose as an environment variable constexpr int MAX_BUFFERS_PER_QUEUE = 12; -constexpr int MAX_DISPATCHES_PER_ENCODER = 2; constexpr const char* default_mtllib_path = METAL_PATH; @@ -121,33 +120,41 @@ MTL::Library* load_library( } // namespace -CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) { - enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); - enc->retain(); +CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) { + enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); + enc_->retain(); } CommandEncoder::~CommandEncoder() { - enc->endEncoding(); - enc->release(); + enc_->endEncoding(); + enc_->release(); +} + +void CommandEncoder::set_array( + const array& a, + int idx, + int64_t offset /* = 0 */) { + auto r_buf = static_cast(const_cast(a.buffer().ptr())); + if (auto it = outputs_.find(r_buf); it != outputs_.end()) { + // Insert a barrier + enc_->memoryBarrier(&r_buf, 1); + + // Remove the output + outputs_.erase(it); + } + auto a_buf = static_cast(a.buffer().ptr()); + auto base_offset = a.data() - + 
static_cast(const_cast(a_buf)->contents()); + base_offset += offset; + enc_->setBuffer(a_buf, base_offset, idx); } void CommandEncoder::set_input_array( const array& a, int idx, int64_t offset /* = 0 */) { - auto r_buf = static_cast(const_cast(a.buffer().ptr())); - if (auto it = outputs.find(r_buf); it != outputs.end()) { - // Insert a barrier - enc->memoryBarrier(&r_buf, 1); - - // Remove the output - outputs.erase(it); - } - auto a_buf = static_cast(a.buffer().ptr()); - auto base_offset = a.data() - - static_cast(const_cast(a_buf)->contents()); - base_offset += offset; - enc->setBuffer(a_buf, base_offset, idx); + all_inputs_.insert(a.buffer().ptr()); + set_array(a, idx, offset); } void CommandEncoder::set_output_array( @@ -155,40 +162,26 @@ void CommandEncoder::set_output_array( int idx, int64_t offset /* = 0 */) { // Add barriers before adding the output to the output set - set_input_array(a, idx, offset); + set_array(a, idx, offset); + all_outputs_.insert(a.buffer().ptr()); auto buf = static_cast(a.buffer().ptr()); - if (concurrent) { - concurrent_outputs.insert(buf); + if (concurrent_) { + concurrent_outputs_.insert(buf); } else { - outputs.insert(buf); + outputs_.insert(buf); } } void CommandEncoder::dispatchThreadgroups( MTL::Size grid_dims, MTL::Size group_dims) { - num_dispatches++; - enc->dispatchThreadgroups(grid_dims, group_dims); - maybe_split(); + enc_->dispatchThreadgroups(grid_dims, group_dims); } void CommandEncoder::dispatchThreads( MTL::Size grid_dims, MTL::Size group_dims) { - num_dispatches++; - enc->dispatchThreads(grid_dims, group_dims); - maybe_split(); -} - -void CommandEncoder::maybe_split() { - if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) { - enc->endEncoding(); - enc->release(); - num_dispatches = 0; - outputs.clear(); - enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); - enc->retain(); - } + enc_->dispatchThreads(grid_dims, group_dims); } Device::Device() { @@ -199,12 +192,6 @@ Device::Device() { Device::~Device() { auto pool = new_scoped_memory_pool(); - for (auto& q : queue_map_) { - q.second->release(); - } - for (auto& b : buffer_map_) { - b.second.second->release(); - } for (auto& k : kernel_map_) { k.second->release(); } @@ -225,61 +212,125 @@ void Device::new_queue(int index) { throw std::runtime_error( "[metal::Device] Failed to make new command queue."); } - queue_map_.insert({index, q}); - buffer_map_.insert({index, {0, nullptr}}); - encoder_map_.insert({index, nullptr}); + stream_map_.emplace(index, q); } int Device::get_command_buffer_ops(int index) { - return buffer_map_[index].first; + return get_stream_(index).buffer_ops; } void Device::increment_command_buffer_ops(int index) { - buffer_map_[index].first++; + get_stream_(index).buffer_ops++; } MTL::CommandBuffer* Device::get_command_buffer(int index) { - auto bit = buffer_map_.find(index); - if (bit->second.second == nullptr) { - auto qit = queue_map_.find(index); - if (qit == queue_map_.end()) { - throw std::runtime_error( - "[metal::Device] Attempting to get command buffer for invalid queue."); - } - - auto cb = qit->second->commandBufferWithUnretainedReferences(); - - if (!cb) { + auto& stream = get_stream_(index); + if (stream.buffer == nullptr) { + stream.buffer = stream.queue->commandBufferWithUnretainedReferences(); + if (!stream.buffer) { throw std::runtime_error( "[metal::Device] Unable to create new command buffer"); } - // Increment ref count so the buffer is not garbage collected - cb->retain(); - - bit->second = {0, cb}; + stream.buffer->retain(); } - 
return bit->second.second;
+  return stream.buffer;
 }
 
 void Device::commit_command_buffer(int index) {
-  auto bit = buffer_map_.find(index);
-  bit->second.second->commit();
-  bit->second.second->release();
-  bit->second = {0, nullptr};
+  auto& stream = get_stream_(index);
+  stream.buffer->commit();
+  stream.buffer->release();
+  stream.buffer = nullptr;
+  stream.buffer_ops = 0;
+}
+
+void Device::add_temporary(array arr, int index) {
+  get_stream_(index).temporaries.push_back(std::move(arr));
+}
+
+void Device::add_temporaries(std::vector arrays, int index) {
+  if (arrays.empty()) {
+    return;
+  }
+  auto& stream = get_stream_(index);
+  stream.temporaries.insert(
+      stream.temporaries.end(),
+      std::make_move_iterator(arrays.begin()),
+      std::make_move_iterator(arrays.end()));
 }
 
 void Device::end_encoding(int index) {
-  encoder_map_[index] = nullptr;
+  auto& stream = get_stream_(index);
+  if (stream.encoder != nullptr) {
+    // Each command encoder has a unique fence. We also store a map of
+    // all previous outputs of command encoders to their corresponding fence.
+    // - The command encoder records its inputs and outputs.
+    // - Wait on a fence if any inputs in the encoder are outputs of a previous
+    // encoder.
+    // - Update the map of outputs to include this command encoder's outputs.
+    // - Always signal this command encoder's fence.
+    // - Add a completion handler for this command encoder that removes outputs
+    // from the map to limit the growth of the map and avoid unnecessary waits
+    // - Temporaries are a special case as they do not cross command encoder
+    // boundaries. These can be removed early from the encoder's inputs and
+    // outputs since they don't need synchronization.
+    auto& enc = *stream.encoder;
+    // Remove temporaries from inputs and outputs
+    for (auto& t : stream.temporaries) {
+      if (t.data() != nullptr) {
+        enc.outputs().erase(t.buffer().ptr());
+        enc.inputs().erase(t.buffer().ptr());
+      }
+    }
+
+    // Keep references to the fences we waited on and put them
+    // in the completion handler so they are not prematurely released
+    std::unordered_set> waiting_on;
+    {
+      std::lock_guard lk(stream.fence_mtx);
+      for (auto in : enc.inputs()) {
+        if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
+          // If we've already waited on a fence, don't wait on it again.
+ if (waiting_on.find(it->second) == waiting_on.end()) { + enc->waitForFence(it->second->fence); + waiting_on.insert(it->second); + } + } + } + for (auto out : enc.outputs()) { + stream.outputs[out] = stream.fence; + } + } + enc->updateFence(stream.fence->fence); + stream.buffer->addCompletedHandler( + [&stream, + waiting_on = std::move(waiting_on), + fence = std::move(stream.fence), + outputs = std::move(enc.outputs()), + temporaries = + std::move(stream.temporaries)](MTL::CommandBuffer*) mutable { + temporaries.clear(); + std::lock_guard lk(stream.fence_mtx); + for (auto o : outputs) { + if (auto it = stream.outputs.find(o); it != stream.outputs.end()) { + if (it->second == fence) { + stream.outputs.erase(it); + } + } + } + }); + } + stream.encoder = nullptr; } CommandEncoder& Device::get_command_encoder(int index) { - auto eit = encoder_map_.find(index); - if (eit->second == nullptr) { - auto cb = get_command_buffer(index); - eit->second = std::make_unique(cb); + auto& stream = get_stream_(index); + if (stream.encoder == nullptr) { + stream.encoder = std::make_unique(stream.buffer); + stream.fence = std::make_shared(device_->newFence()); } - return *(eit->second); + return *stream.encoder; } void Device::register_library( diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 7f851f929..d15a4aaf8 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -45,13 +45,13 @@ struct CommandEncoder { struct ConcurrentContext { ConcurrentContext(CommandEncoder& enc) : enc(enc) { - enc.concurrent = true; + enc.concurrent_ = true; } ~ConcurrentContext() { - enc.concurrent = false; - enc.outputs.insert( - enc.concurrent_outputs.begin(), enc.concurrent_outputs.end()); - enc.concurrent_outputs.clear(); + enc.concurrent_ = false; + enc.outputs_.insert( + enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end()); + enc.concurrent_outputs_.clear(); } private: @@ -59,7 +59,7 @@ struct CommandEncoder { }; MTL::ComputeCommandEncoder* operator->() { - return enc; + return enc_; } void set_input_array(const array& a, int idx, int64_t offset = 0); @@ -70,18 +70,60 @@ struct CommandEncoder { ConcurrentContext start_concurrent() { return ConcurrentContext(*this); } - ~CommandEncoder(); - private: - void maybe_split(); + // Inputs to all kernels in the encoder including temporaries + std::unordered_set& inputs() { + return all_inputs_; + }; - int num_dispatches{0}; - MTL::CommandBuffer* cbuf; - MTL::ComputeCommandEncoder* enc; - bool concurrent{false}; - std::unordered_set outputs; - std::unordered_set concurrent_outputs; + // Outputs of all kernels in the encoder including temporaries + std::unordered_set outputs() { + return all_outputs_; + }; + + private: + void set_array(const array& a, int idx, int64_t offset); + MTL::ComputeCommandEncoder* enc_; + bool concurrent_{false}; + std::unordered_set outputs_; + std::unordered_set concurrent_outputs_; + std::unordered_set all_inputs_; + std::unordered_set all_outputs_; +}; + +struct Fence { + Fence(MTL::Fence* fence) : fence(fence) {} + ~Fence() { + fence->release(); + } + MTL::Fence* fence; +}; + +struct DeviceStream { + DeviceStream(MTL::CommandQueue* queue) : queue(queue) {}; + ~DeviceStream() { + queue->release(); + if (buffer != nullptr) { + buffer->release(); + } + }; + MTL::CommandQueue* queue; + // A map of prior command encoder outputs to their corresponding fence + std::unordered_map> outputs; + // Used to allow thread-safe access to the outputs map + std::mutex fence_mtx; + + // The buffer and buffer op count 
are updated + // between command buffers + MTL::CommandBuffer* buffer{nullptr}; + int buffer_ops{0}; + + // The command encoder, fence, and temporaries are updated between command + // encoders + std::unique_ptr encoder{nullptr}; + std::shared_ptr fence; + std::vector temporaries; }; class Device { @@ -136,7 +178,14 @@ class Device { MTL::ArgumentEncoder* argument_encoder( const std::vector& arg_descs) const; + // Record temporary arrays for the given stream index + void add_temporary(array arr, int index); + void add_temporaries(std::vector arrays, int index); + private: + DeviceStream& get_stream_(int index) { + return stream_map_.find(index)->second; + } MTL::Library* get_library_cache_(const std::string& name); MTL::Library* get_library_(const std::string& name); @@ -170,9 +219,7 @@ class Device { const std::vector& linked_functions = {}); MTL::Device* device_; - std::unordered_map queue_map_; - std::unordered_map> buffer_map_; - std::unordered_map> encoder_map_; + std::unordered_map stream_map_; std::shared_mutex kernel_mtx_; std::unordered_map kernel_map_; diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 791e1fa00..43ded5378 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -575,10 +575,7 @@ void fft_op( auto plan = plan_fft(n); if (plan.four_step) { four_step_fft(in, out, axis, inverse, real, plan, copies, s); - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); + d.add_temporaries(std::move(copies), s.index); return; } @@ -744,12 +741,7 @@ void fft_op( compute_encoder->dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } void fft_op( @@ -792,8 +784,7 @@ void nd_fft_op( } auto& d = metal::device(s.device); - d.get_command_buffer(s.index)->addCompletedHandler( - [temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); }); + d.add_temporaries(std::move(temp_arrs), s.index); } void FFT::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/metal/hadamard.cpp b/mlx/backend/metal/hadamard.cpp index dd89b415e..dc2268f7d 100644 --- a/mlx/backend/metal/hadamard.cpp +++ b/mlx/backend/metal/hadamard.cpp @@ -174,12 +174,7 @@ void Hadamard::eval_gpu(const std::vector& inputs, array& out) { launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 3620a5676..3d8973fae 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -226,13 +226,8 @@ void steel_matmul_regular( compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - // Clear copies - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + // Record copies + d.add_temporaries(std::move(copies), s.index); } void steel_matmul( @@ -382,12 +377,7 @@ void steel_matmul( compute_encoder.dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = 
std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); return; } @@ -435,8 +425,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {}); + d.add_temporary(std::move(zero), s.index); return; } @@ -588,12 +577,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); return; } ///////////////////////////////////////////////////////////////////////////// @@ -798,12 +782,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); return; } @@ -916,12 +895,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); return; } @@ -1056,12 +1030,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { @@ -1080,8 +1049,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {}); + d.add_temporary(std::move(zero), s.index); return; } @@ -1356,12 +1324,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); return; } @@ -1471,13 +1434,7 @@ void BlockMaskedMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - // Clear copies - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } void GatherMM::eval_gpu(const std::vector& inputs, array& out) { @@ -1496,8 +1453,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { if (a_pre.size() == 0 || b_pre.size() == 0) { array zero = array(0, a_pre.dtype()); fill_gpu(zero, out, s); - 
auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {}); + d.add_temporary(std::move(zero), s.index); return; } @@ -1703,12 +1659,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); return; } @@ -1847,13 +1798,7 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreadgroups(grid_dims, group_dims); - // Clear copies - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/normalization.cpp b/mlx/backend/metal/normalization.cpp index b9053fa4f..cdab18368 100644 --- a/mlx/backend/metal/normalization.cpp +++ b/mlx/backend/metal/normalization.cpp @@ -91,12 +91,8 @@ void RMSNorm::eval_gpu( compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1); compute_encoder.dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + + d.add_temporaries(std::move(copies), s.index); } void RMSNormVJP::eval_gpu( @@ -204,10 +200,7 @@ void RMSNormVJP::eval_gpu( strided_reduce_general_dispatch( gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); + d.add_temporaries(std::move(copies), s.index); } void LayerNorm::eval_gpu( @@ -292,12 +285,8 @@ void LayerNorm::eval_gpu( compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7); compute_encoder.dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + + d.add_temporaries(std::move(copies), s.index); } void LayerNormVJP::eval_gpu( @@ -425,10 +414,7 @@ void LayerNormVJP::eval_gpu( gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); } - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index d1c594cb4..30828da70 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -145,6 +145,7 @@ void launch_qmm( } compute_encoder.dispatchThreadgroups(grid_dims, group_dims); + d.add_temporaries(std::move(copies), s.index); } void qmm_op( @@ -350,12 +351,7 @@ void fast::AffineQuantize::eval_gpu( : MTL::Size(nthreads, 1, 1); compute_encoder.dispatchThreads(grid_dims, group_dims); - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 5881effa4..d8906a819 100644 --- 
a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -660,12 +660,7 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { in, out, op_name, plan, axes_, compute_encoder, d, s); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } // Nothing to reduce just initialize the output diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 31811bb7c..7a3fc03ba 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -262,12 +262,7 @@ void ScaledDotProductAttention::eval_gpu( sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core::fast diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp index a113353a7..8737b25e4 100644 --- a/mlx/backend/metal/scan.cpp +++ b/mlx/backend/metal/scan.cpp @@ -107,13 +107,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - auto command_buffer = d.get_command_buffer(s.index); - command_buffer->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/softmax.cpp b/mlx/backend/metal/softmax.cpp index 7af5d6f53..5bb7e66a4 100644 --- a/mlx/backend/metal/softmax.cpp +++ b/mlx/backend/metal/softmax.cpp @@ -88,12 +88,8 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { compute_encoder->setBytes(&axis_size, sizeof(int), 2); compute_encoder.dispatchThreads(grid_dims, group_dims); } - if (!copies.empty()) { - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); - } + + d.add_temporaries(std::move(copies), s.index); } } // namespace mlx::core diff --git a/mlx/backend/metal/sort.cpp b/mlx/backend/metal/sort.cpp index 0de69f9c5..925a5ccd9 100644 --- a/mlx/backend/metal/sort.cpp +++ b/mlx/backend/metal/sort.cpp @@ -252,11 +252,7 @@ void multi_block_sort( (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General, s); - // Clear copies - d.get_command_buffer(s.index)->addCompletedHandler( - [copies = std::move(copies)](MTL::CommandBuffer*) mutable { - copies.clear(); - }); + d.add_temporaries(std::move(copies), s.index); } void gpu_merge_sort( diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 81e6ccedf..8e496ab06 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -769,7 +769,6 @@ class TestCompile(mlx_tests.MLXTestCase): out = mx.compile(fn)(a, b) expected = fn(a, b) - print((out - expected).abs().max()) self.assertTrue(mx.allclose(out, expected)) def test_compile_many_inputs(self):
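
The change that repeats across the Metal backend files above (conv, custom_kernel, fft, hadamard, matmul, normalization, quantized, reduce, scaled_dot_product_attention, scan, softmax, sort) is mechanical: rather than each op keeping its temporary ``copies`` alive by capturing them in a per-op ``addCompletedHandler``, the op now hands them to the device, which tracks temporaries per stream and releases them when the corresponding command buffer completes. A minimal sketch of the new call pattern, with ``my_op_eval_gpu`` as a hypothetical stand-in for any of the ``eval_gpu`` implementations touched here (``array``, ``Stream``, and ``metal::device`` are the MLX types and helpers that appear in the diff):

.. code-block:: cpp

    #include <vector>

    // Sketch only: `my_op_eval_gpu` is a hypothetical op used for illustration.
    void my_op_eval_gpu(const std::vector<array>& inputs, array& out, const Stream& s) {
      auto& d = metal::device(s.device);
      std::vector<array> copies; // contiguous copies made while encoding

      // ... encode kernels, pushing any intermediate arrays onto `copies` ...

      // Old pattern (removed throughout this diff): keep the copies alive by
      // capturing them in a completed handler on the command buffer.
      //
      //   d.get_command_buffer(s.index)->addCompletedHandler(
      //       [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
      //         copies.clear();
      //       });
      //
      // New pattern: record them with the device; they are released once the
      // stream's command buffer completes (see Device::end_encoding).
      d.add_temporaries(std::move(copies), s.index);
    }

Single temporaries, such as the ``zero`` array that ``fill_gpu`` uses in the empty-input matmul paths, go through the scalar ``add_temporary`` overload instead.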
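Two synchronization mechanisms replace the automatic hazard tracking that the allocator change (``ResourceHazardTrackingModeTracked`` to ``ResourceHazardTrackingModeUntracked``) turns off. Within a single concurrent compute encoder, ``CommandEncoder::set_array`` issues an explicit memory barrier whenever a kernel binds a buffer that an earlier dispatch in the same encoder wrote, while ``set_input_array`` and ``set_output_array`` also record every buffer the encoder touches for use at ``end_encoding``. The following condensed reading restores the template arguments that were stripped from the diff text above; treat the restored types as a best guess from context rather than a verbatim quote:

.. code-block:: cpp

    // Condensed from CommandEncoder in the diff; assumes the class declaration
    // in mlx/backend/metal/device.h. Restored template arguments are a guess.
    void CommandEncoder::set_array(const array& a, int idx, int64_t offset) {
      auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
      if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
        // An earlier dispatch in this encoder wrote this buffer; with hazard
        // tracking untracked, the read has to be ordered explicitly.
        enc_->memoryBarrier(&r_buf, 1);
        outputs_.erase(it);
      }
      auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
      auto base_offset = a.data<char>() -
          static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
      base_offset += offset;
      enc_->setBuffer(a_buf, base_offset, idx);
    }

    void CommandEncoder::set_input_array(const array& a, int idx, int64_t offset) {
      all_inputs_.insert(a.buffer().ptr()); // remembered for cross-encoder fences
      set_array(a, idx, offset);
    }

    void CommandEncoder::set_output_array(const array& a, int idx, int64_t offset) {
      set_array(a, idx, offset);
      all_outputs_.insert(a.buffer().ptr());
      auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
      if (concurrent_) {
        // Kernels inside a concurrent region are not ordered against each
        // other; their outputs become barriers only when the region ends.
        concurrent_outputs_.insert(buf);
      } else {
        outputs_.insert(buf);
      }
    }

This also removes the old ``maybe_split`` path, so an encoder is no longer ended and restarted after a fixed number of dispatches.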
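Across encoders, ordering relies on ``MTL::Fence`` objects instead: each ``DeviceStream`` keeps a map from buffers written by earlier encoders to the fence those encoders signalled, and ``Device::end_encoding`` waits on the fences behind the new encoder's inputs, publishes its outputs under its own fence, signals that fence, and installs a completion handler that prunes the map and drops the stream's temporaries. A condensed sketch of the wait/signal portion (the mutex and the completion handler are elided; names follow the diff):

.. code-block:: cpp

    // Inside Device::end_encoding(index), once per command encoder.
    auto& stream = get_stream_(index);
    auto& enc = *stream.encoder;

    // Fences we wait on are kept alive until the command buffer completes.
    std::unordered_set<std::shared_ptr<Fence>> waiting_on;
    for (auto in : enc.inputs()) {
      if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
        // The input was produced by a previous encoder on this stream:
        // wait on that encoder's fence, but only once per distinct fence.
        if (waiting_on.find(it->second) == waiting_on.end()) {
          enc->waitForFence(it->second->fence);
          waiting_on.insert(it->second);
        }
      }
    }
    // Later encoders that consume these outputs will wait on this fence.
    for (auto out : enc.outputs()) {
      stream.outputs[out] = stream.fence;
    }
    // Signal after all of this encoder's dispatches have been encoded.
    enc->updateFence(stream.fence->fence);

Temporaries are erased from ``enc.inputs()`` and ``enc.outputs()`` before this runs since they never cross an encoder boundary; the ``array::move_shared_buffer`` change that nulls the moved-from ``data_ptr`` appears to support the ``t.data() != nullptr`` guard there, so a temporary whose buffer has been donated to another array is not stripped from the synchronization sets.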