Remove Hazard tracking with Fences (#1509)

* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
Awni Hannun 2024-10-21 19:33:32 -07:00 committed by GitHub
parent d15fa13daf
commit c26208f67d
25 changed files with 268 additions and 299 deletions

View File

@@ -33,12 +33,12 @@ Let's start with a simple example:
    # Compile the function
    compiled_fun = mx.compile(fun)

    # Prints: array(2.36788, dtype=float32)
    print(compiled_fun(x, y))

The output of both the regular function and the compiled function is the same
up to numerical precision.

The first time you call a compiled function, MLX will build the compute
graph, optimize it, and generate and compile code. This can be relatively
slow. However, MLX will cache compiled functions, so calling a compiled
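
A minimal sketch of the caching behavior described above (reusing the ``fun``,
``x``, and ``y`` from the surrounding example):

.. code-block:: python

    import mlx.core as mx

    def fun(x, y):
        return mx.exp(-x) + y

    x, y = mx.array(1.0), mx.array(2.0)
    compiled_fun = mx.compile(fun)

    # First call: trace, optimize, generate, and compile (relatively slow)
    print(compiled_fun(x, y))
    # Later calls with the same shapes and dtypes reuse the cached compilation
    print(compiled_fun(x, y))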
@@ -96,7 +96,7 @@ element-wise operations:
.. code-block:: python

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

If you use this function with small arrays, it will be overhead bound. If you
@@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster.

-.. note::
-   As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
-   functions can still be helpful, but won't typically result in as large a
-   speedup as compiling operations that run on the GPU.
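
A minimal benchmark sketch of the comparison above (the input shape and
iteration count are illustrative, not the exact setup behind the M1 Max
numbers):

.. code-block:: python

    import math
    import time
    import mlx.core as mx

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    compiled_gelu = mx.compile(gelu)
    x = mx.random.uniform(shape=(32, 1000, 4096))

    for f in (gelu, compiled_gelu):
        mx.eval(f(x))  # warm up, and compile on the first call
        tic = time.perf_counter()
        for _ in range(10):
            mx.eval(f(x))
        print(f"{(time.perf_counter() - tic) / 10 * 1e3:.2f} ms per call")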
Debugging
---------
@@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
    print(fun(mx.array(1.0)))

Compiling Training Graphs
-------------------------

This section will step through how to use :func:`compile` with a simple example
@@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
To start, here is the simple example without any compilation:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
@@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
To compile the update we can put it all in a function and compile it with the
appropriate input and output captures. Here's the same example but compiled:

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn
@@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
    # The state that will be captured as input and output
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(x, y):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
@@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
In order to compile as much as possible, a transformation of a compiled
function will not by default be compiled. To compile the transformed
function simply pass it through :func:`compile`.

You can also compile functions which themselves call compiled functions. A
good practice is to compile the outermost function to give :func:`compile`
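
A minimal sketch of the two orderings (``mx.grad`` of a compiled function
versus compiling the transformed function):

.. code-block:: python

    import mlx.core as mx

    def fun(x):
        return mx.sin(x)

    # Transformation of a compiled function: grad itself is not compiled
    grad_fn = mx.grad(mx.compile(fun))

    # Pass the transformed function through compile to compile the gradient
    compiled_grad_fn = mx.compile(mx.grad(fun))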

View File

@@ -25,7 +25,7 @@ Here is a simple example:
The output of :func:`grad` on :func:`sin` is simply another function. In this
case it is the gradient of the sine function which is exactly the cosine
function. To get the second derivative you can do:

.. code-block:: shell
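
The example body falls outside this hunk; a minimal sketch of the second
derivative, written here as a Python snippet rather than a REPL transcript:

.. code-block:: python

    import mlx.core as mx

    # d/dx sin(x) = cos(x), so applying grad twice gives -sin(x)
    d2fdx2 = mx.grad(mx.grad(mx.sin))
    print(d2fdx2(mx.array(0.0)))  # -sin(0) = 0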
@@ -50,7 +50,7 @@ Automatic Differentiation
.. _auto diff:

Automatic differentiation in MLX works on functions rather than on implicit
graphs.

.. note::
@@ -114,7 +114,7 @@ way to do that is the following:
    def loss_fn(params, x, y):
        w, b = params["weight"], params["bias"]
        h = w * x + b
        return mx.mean(mx.square(h - y))

    params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
@@ -132,7 +132,7 @@ way to do that is the following:
Notice the tree structure of the parameters is preserved in the gradients.

In some cases you may want to stop gradients from propagating through a
part of the function. You can use the :func:`stop_gradient` for that.
@@ -166,14 +166,14 @@ A naive way to add the elements from two sets of vectors is with a loop:
Instead you can use :func:`vmap` to automatically vectorize the addition:

.. code-block:: python

    # Vectorize over the second dimension of x and the
    # first dimension of y
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
where the vectorized axes should be in the outputs.

Let's time these two different versions:
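
A minimal sketch of the vectorized addition (the input shapes are
illustrative):

.. code-block:: python

    import mlx.core as mx

    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    # Vectorize over the second dimension of x and the first dimension of y
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))

    # Each mapped call adds a (4096,) slice of xs to a (4096,) row of ys;
    # the vectorized axis is placed first in the output by default
    print(vmap_add(xs, ys).shape)  # (100, 4096)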

View File

@@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
.. code-block:: shell

    >>> arr = mx.arange(10)
    >>> idx = mx.array([5, 7])
    >>> arr[idx]
    array([5, 7], dtype=int32)
@@ -82,7 +82,7 @@ general, MLX has limited support for operations for which outputs
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.

In Place Updates
----------------

In place updates to indexed arrays are possible in MLX. For example:
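
The example itself falls outside this hunk; a minimal sketch of an in place
update:

.. code-block:: python

    import mlx.core as mx

    a = mx.array([1, 2, 3])
    a[2] = 0  # update the indexed element in place
    print(a)  # array([1, 2, 0], dtype=int32)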

View File

@@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
:func:`eval` is performed.

MLX uses lazy evaluation because it has some nice features, some of which we
describe below.

Transforming Compute Graphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -116,7 +116,7 @@ saving functions) will also evaluate the array.
Calling :func:`array.item` on a scalar array will also evaluate it. In the
example above, printing the loss (``print(loss)``) or adding the loss scalar to
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
these lines are before ``mx.eval(loss, model.parameters())`` then this
will be a partial evaluation, computing only the forward pass.
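
A minimal sketch of implicit evaluation via :func:`array.item` (``forward``
is a stand-in for any graph that produces a scalar loss):

.. code-block:: python

    import mlx.core as mx

    def forward(x):
        return mx.mean(mx.square(x))

    loss = forward(mx.ones((8,)))  # nothing is computed yet
    print(loss.item())             # item() evaluates the graph up to loss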

View File

@@ -3,10 +3,10 @@
Conversion to NumPy and Other Frameworks
========================================

MLX array supports conversion between other frameworks with either:

* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.

Let's convert an array to NumPy and back.
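
A minimal sketch of the round trip:

.. code-block:: python

    import mlx.core as mx
    import numpy as np

    a = mx.arange(3.0)
    b = np.array(a)   # to NumPy via the buffer protocol (makes a copy)
    c = mx.array(b)   # and back to an MLX array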
@@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
PyTorch
-------

.. warning::

    PyTorch Support for :obj:`memoryview` is experimental and can break for
    multi-dimensional arrays. Casting to NumPy first is advised for now.

View File

@@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
and :func:`jvp` for Jacobian-vector products.

Use :func:`value_and_grad` to efficiently compute both a function's output and
gradient with respect to the function's input.
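
A minimal sketch of :func:`value_and_grad`:

.. code-block:: python

    import mlx.core as mx

    def loss_fn(w):
        return mx.mean(mx.square(w))

    # One call computes both the loss and its gradient with respect to w
    loss, dloss_dw = mx.value_and_grad(loss_fn)(mx.array([1.0, 2.0]))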

View File

@@ -8,33 +8,33 @@ Saving and Loading Arrays
MLX supports multiple array serialization formats.

.. list-table:: Serialization Formats
   :widths: 20 8 25 25
   :header-rows: 1

   * - Format
     - Extension
     - Function
     - Notes
   * - NumPy
     - ``.npy``
     - :func:`save`
     - Single arrays only
   * - NumPy archive
     - ``.npz``
     - :func:`savez` and :func:`savez_compressed`
     - Multiple arrays
   * - Safetensors
     - ``.safetensors``
     - :func:`save_safetensors`
     - Multiple arrays
   * - GGUF
     - ``.gguf``
     - :func:`save_gguf`
     - Multiple arrays

The :func:`load` function will load any of the supported serialization
formats. It determines the format from the extension. The output of
:func:`load` depends on the format.

Here's an example of saving a single array to a file:
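
The example itself falls outside this hunk; a minimal sketch of a save and
load round trip:

.. code-block:: python

    import mlx.core as mx

    a = mx.array([1.0])
    mx.save("array", a)        # the ``.npy`` extension is added automatically
    b = mx.load("array.npy")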

View File

@@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
In MLX, rather than moving arrays to devices, you specify the device when you
run the operation. Any device can perform any operation on ``a`` and ``b``
without needing to move them from one memory location to another. For example:

.. code-block:: python
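
The example body falls outside this hunk; a minimal sketch of specifying the
device per operation:

.. code-block:: python

    import mlx.core as mx

    a = mx.random.normal((100,))
    b = mx.random.normal((100,))

    # The same arrays can be used on either device, with no copies
    mx.add(a, b, stream=mx.cpu)
    mx.add(a, b, stream=mx.gpu)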

View File

@@ -180,6 +180,7 @@ void array::move_shared_buffer(
  auto char_offset = sizeof(char) * itemsize() * offset;
  array_desc_->data_ptr = static_cast<void*>(
      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+  other.array_desc_->data_ptr = nullptr;
}

void array::move_shared_buffer(array other) {

View File

@@ -205,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
  // Allocate new buffer if needed
  size_t res_opt = MTL::ResourceStorageModeShared;
-  res_opt |= MTL::ResourceHazardTrackingModeTracked;
+  res_opt |= MTL::ResourceHazardTrackingModeUntracked;
  lk.unlock();
  buf = device_->newBuffer(size, res_opt);
  lk.lock();

View File

@@ -918,14 +918,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
        "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
  }

-  // Clear copies
-  if (!copies.empty()) {
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  // Record copies
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

View File

@@ -77,12 +77,7 @@ void CustomKernel::eval_gpu(
  MTL::Size grid_dims = MTL::Size(gx, gy, gz);
  compute_encoder->dispatchThreads(grid_dims, group_dims);

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core::fast

View File

@@ -20,7 +20,6 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
-constexpr int MAX_DISPATCHES_PER_ENCODER = 2;

constexpr const char* default_mtllib_path = METAL_PATH;
@@ -121,33 +120,41 @@
} // namespace

-CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
-  enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-  enc->retain();
+CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
+  enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+  enc_->retain();
}

CommandEncoder::~CommandEncoder() {
-  enc->endEncoding();
-  enc->release();
+  enc_->endEncoding();
+  enc_->release();
+}
+
+void CommandEncoder::set_array(
+    const array& a,
+    int idx,
+    int64_t offset /* = 0 */) {
+  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
+  if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
+    // Insert a barrier
+    enc_->memoryBarrier(&r_buf, 1);
+    // Remove the output
+    outputs_.erase(it);
+  }
+  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
+  auto base_offset = a.data<char>() -
+      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
+  base_offset += offset;
+  enc_->setBuffer(a_buf, base_offset, idx);
}

void CommandEncoder::set_input_array(
    const array& a,
    int idx,
    int64_t offset /* = 0 */) {
-  auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
-  if (auto it = outputs.find(r_buf); it != outputs.end()) {
-    // Insert a barrier
-    enc->memoryBarrier(&r_buf, 1);
-    // Remove the output
-    outputs.erase(it);
-  }
-  auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
-  auto base_offset = a.data<char>() -
-      static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
-  base_offset += offset;
-  enc->setBuffer(a_buf, base_offset, idx);
+  all_inputs_.insert(a.buffer().ptr());
+  set_array(a, idx, offset);
}

void CommandEncoder::set_output_array(
void CommandEncoder::set_output_array( void CommandEncoder::set_output_array(
@@ -155,40 +162,26 @@ void CommandEncoder::set_output_array(
    int idx,
    int64_t offset /* = 0 */) {
  // Add barriers before adding the output to the output set
-  set_input_array(a, idx, offset);
+  set_array(a, idx, offset);
+  all_outputs_.insert(a.buffer().ptr());
  auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
-  if (concurrent) {
-    concurrent_outputs.insert(buf);
+  if (concurrent_) {
+    concurrent_outputs_.insert(buf);
  } else {
-    outputs.insert(buf);
+    outputs_.insert(buf);
  }
}

void CommandEncoder::dispatchThreadgroups(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
-  num_dispatches++;
-  enc->dispatchThreadgroups(grid_dims, group_dims);
-  maybe_split();
+  enc_->dispatchThreadgroups(grid_dims, group_dims);
}

void CommandEncoder::dispatchThreads(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
-  num_dispatches++;
-  enc->dispatchThreads(grid_dims, group_dims);
-  maybe_split();
-}
-
-void CommandEncoder::maybe_split() {
-  if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
-    enc->endEncoding();
-    enc->release();
-    num_dispatches = 0;
-    outputs.clear();
-    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    enc->retain();
-  }
+  enc_->dispatchThreads(grid_dims, group_dims);
}

Device::Device() {
@@ -199,12 +192,6 @@ Device::Device() {
Device::~Device() {
  auto pool = new_scoped_memory_pool();
-  for (auto& q : queue_map_) {
-    q.second->release();
-  }
-  for (auto& b : buffer_map_) {
-    b.second.second->release();
-  }
  for (auto& k : kernel_map_) {
    k.second->release();
  }
@@ -225,61 +212,125 @@ void Device::new_queue(int index) {
    throw std::runtime_error(
        "[metal::Device] Failed to make new command queue.");
  }
-  queue_map_.insert({index, q});
-  buffer_map_.insert({index, {0, nullptr}});
-  encoder_map_.insert({index, nullptr});
+  stream_map_.emplace(index, q);
}

int Device::get_command_buffer_ops(int index) {
-  return buffer_map_[index].first;
+  return get_stream_(index).buffer_ops;
}

void Device::increment_command_buffer_ops(int index) {
-  buffer_map_[index].first++;
+  get_stream_(index).buffer_ops++;
}

MTL::CommandBuffer* Device::get_command_buffer(int index) {
-  auto bit = buffer_map_.find(index);
-  if (bit->second.second == nullptr) {
-    auto qit = queue_map_.find(index);
-    if (qit == queue_map_.end()) {
-      throw std::runtime_error(
-          "[metal::Device] Attempting to get command buffer for invalid queue.");
-    }
-    auto cb = qit->second->commandBufferWithUnretainedReferences();
-    if (!cb) {
+  auto& stream = get_stream_(index);
+  if (stream.buffer == nullptr) {
+    stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
+    if (!stream.buffer) {
      throw std::runtime_error(
          "[metal::Device] Unable to create new command buffer");
    }
    // Increment ref count so the buffer is not garbage collected
-    cb->retain();
-    bit->second = {0, cb};
+    stream.buffer->retain();
  }
-  return bit->second.second;
+  return stream.buffer;
}

void Device::commit_command_buffer(int index) {
-  auto bit = buffer_map_.find(index);
-  bit->second.second->commit();
-  bit->second.second->release();
-  bit->second = {0, nullptr};
+  auto& stream = get_stream_(index);
+  stream.buffer->commit();
+  stream.buffer->release();
+  stream.buffer = nullptr;
+  stream.buffer_ops = 0;
+}
+
+void Device::add_temporary(array arr, int index) {
+  get_stream_(index).temporaries.push_back(std::move(arr));
+}
+
+void Device::add_temporaries(std::vector<array> arrays, int index) {
+  if (arrays.empty()) {
+    return;
+  }
+  auto& stream = get_stream_(index);
+  stream.temporaries.insert(
+      stream.temporaries.end(),
+      std::make_move_iterator(arrays.begin()),
+      std::make_move_iterator(arrays.end()));
}

void Device::end_encoding(int index) {
-  encoder_map_[index] = nullptr;
+  auto& stream = get_stream_(index);
+  if (stream.encoder != nullptr) {
+    // Each command encoder has a unique fence. We also store a map of
+    // all previous outputs of command encoders to their corresponding fence.
+    // - The command encoder records its inputs and outputs.
+    // - Wait on a fence if any inputs in the encoder are outputs of a previous
+    //   encoder.
+    // - Update the map of outputs to include this command encoder's outputs.
+    // - Always signal this command encoder's fence.
+    // - Add a completion handler for this command encoder that removes outputs
+    //   from the map to limit the growth of the map and avoid unnecessary
+    //   waits.
+    // - Temporaries are a special case as they do not cross command encoder
+    //   boundaries. These can be removed early from the encoder's inputs and
+    //   outputs since they don't need synchronization.
+    auto& enc = *stream.encoder;
+    // Remove temporaries from inputs and outputs
+    for (auto& t : stream.temporaries) {
+      if (t.data<void>() != nullptr) {
+        enc.outputs().erase(t.buffer().ptr());
+        enc.inputs().erase(t.buffer().ptr());
+      }
+    }
+    // Keep references to the fences we waited on and put them
+    // in the completion handler so they are not prematurely released
+    std::unordered_set<std::shared_ptr<Fence>> waiting_on;
+    {
+      std::lock_guard<std::mutex> lk(stream.fence_mtx);
+      for (auto in : enc.inputs()) {
+        if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
+          // If we've already waited on a fence, don't wait on it again.
+          if (waiting_on.find(it->second) == waiting_on.end()) {
+            enc->waitForFence(it->second->fence);
+            waiting_on.insert(it->second);
+          }
+        }
+      }
+      for (auto out : enc.outputs()) {
+        stream.outputs[out] = stream.fence;
+      }
+    }
+    enc->updateFence(stream.fence->fence);
+    stream.buffer->addCompletedHandler(
+        [&stream,
+         waiting_on = std::move(waiting_on),
+         fence = std::move(stream.fence),
+         outputs = std::move(enc.outputs()),
+         temporaries =
+             std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
+          temporaries.clear();
+          std::lock_guard<std::mutex> lk(stream.fence_mtx);
+          for (auto o : outputs) {
+            if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
+              if (it->second == fence) {
+                stream.outputs.erase(it);
+              }
            }
          }
        });
  }
+  stream.encoder = nullptr;
}

CommandEncoder& Device::get_command_encoder(int index) {
-  auto eit = encoder_map_.find(index);
-  if (eit->second == nullptr) {
-    auto cb = get_command_buffer(index);
-    eit->second = std::make_unique<CommandEncoder>(cb);
+  auto& stream = get_stream_(index);
+  if (stream.encoder == nullptr) {
+    stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
+    stream.fence = std::make_shared<Fence>(device_->newFence());
  }
-  return *(eit->second);
+  return *stream.encoder;
}

void Device::register_library(
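
The synchronization scheme in ``end_encoding`` above can be summarized outside
of Metal. The following Python sketch mirrors the bookkeeping with hypothetical
names (``Stream``, ``Encoder``, ``wait_for``, ``signal`` are illustrative, not
MLX or Metal API):

.. code-block:: python

    from dataclasses import dataclass, field

    @dataclass
    class Encoder:
        inputs: set = field(default_factory=set)
        outputs: set = field(default_factory=set)
        def wait_for(self, fence): print("wait on", fence)
        def signal(self, fence): print("signal", fence)

    @dataclass
    class Stream:
        outputs: dict = field(default_factory=dict)  # buffer -> fence
        temporaries: list = field(default_factory=list)

    def end_encoding(stream, enc, fence):
        # Temporaries never cross encoder boundaries: drop them from the
        # encoder's inputs and outputs so they are not synchronized on
        for t in stream.temporaries:
            enc.inputs.discard(t)
            enc.outputs.discard(t)

        # Wait once per distinct fence guarding any input that a previous
        # encoder wrote
        waiting_on = set()
        for buf in enc.inputs:
            prior = stream.outputs.get(buf)
            if prior is not None and prior not in waiting_on:
                enc.wait_for(prior)
                waiting_on.add(prior)

        # Record this encoder's outputs against its own fence, then signal it.
        # A completion handler (not shown) later removes these entries so the
        # map does not grow without bound.
        for buf in enc.outputs:
            stream.outputs[buf] = fence
        enc.signal(fence)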

View File

@@ -45,13 +45,13 @@ struct CommandEncoder {
  struct ConcurrentContext {
    ConcurrentContext(CommandEncoder& enc) : enc(enc) {
-      enc.concurrent = true;
+      enc.concurrent_ = true;
    }
    ~ConcurrentContext() {
-      enc.concurrent = false;
-      enc.outputs.insert(
-          enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
-      enc.concurrent_outputs.clear();
+      enc.concurrent_ = false;
+      enc.outputs_.insert(
+          enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
+      enc.concurrent_outputs_.clear();
    }

   private:
@@ -59,7 +59,7 @@ struct CommandEncoder {
  };

  MTL::ComputeCommandEncoder* operator->() {
-    return enc;
+    return enc_;
  }

  void set_input_array(const array& a, int idx, int64_t offset = 0);
@@ -70,18 +70,60 @@ struct CommandEncoder {
  ConcurrentContext start_concurrent() {
    return ConcurrentContext(*this);
  }
  ~CommandEncoder();

- private:
-  void maybe_split();
-
-  int num_dispatches{0};
-  MTL::CommandBuffer* cbuf;
-  MTL::ComputeCommandEncoder* enc;
-  bool concurrent{false};
-  std::unordered_set<MTL::Resource*> outputs;
-  std::unordered_set<MTL::Resource*> concurrent_outputs;
+  // Inputs to all kernels in the encoder including temporaries
+  std::unordered_set<const void*>& inputs() {
+    return all_inputs_;
+  };
+
+  // Outputs of all kernels in the encoder including temporaries
+  std::unordered_set<const void*>& outputs() {
+    return all_outputs_;
+  };
+
+ private:
+  void set_array(const array& a, int idx, int64_t offset);
+  MTL::ComputeCommandEncoder* enc_;
+  bool concurrent_{false};
+  std::unordered_set<MTL::Resource*> outputs_;
+  std::unordered_set<MTL::Resource*> concurrent_outputs_;
+  std::unordered_set<const void*> all_inputs_;
+  std::unordered_set<const void*> all_outputs_;
+};
+
+struct Fence {
+  Fence(MTL::Fence* fence) : fence(fence) {}
+  ~Fence() {
+    fence->release();
+  }
+  MTL::Fence* fence;
+};
+
+struct DeviceStream {
+  DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
+  ~DeviceStream() {
+    queue->release();
+    if (buffer != nullptr) {
+      buffer->release();
+    }
+  };
+  MTL::CommandQueue* queue;
+
+  // A map of prior command encoder outputs to their corresponding fence
+  std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
+  // Used to allow thread-safe access to the outputs map
+  std::mutex fence_mtx;
+
+  // The buffer and buffer op count are updated
+  // between command buffers
+  MTL::CommandBuffer* buffer{nullptr};
+  int buffer_ops{0};
+
+  // The command encoder, fence, and temporaries are updated between command
+  // encoders
+  std::unique_ptr<CommandEncoder> encoder{nullptr};
+  std::shared_ptr<Fence> fence;
+  std::vector<array> temporaries;
};

class Device {
@@ -136,7 +178,14 @@ class Device {
  MTL::ArgumentEncoder* argument_encoder(
      const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;

+  // Record temporary arrays for the given stream index
+  void add_temporary(array arr, int index);
+  void add_temporaries(std::vector<array> arrays, int index);
+
 private:
+  DeviceStream& get_stream_(int index) {
+    return stream_map_.find(index)->second;
+  }
  MTL::Library* get_library_cache_(const std::string& name);

  MTL::Library* get_library_(const std::string& name);
@@ -170,9 +219,7 @@ class Device {
      const std::vector<MTL::Function*>& linked_functions = {});

  MTL::Device* device_;
-  std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
-  std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
-  std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
+  std::unordered_map<int32_t, DeviceStream> stream_map_;
  std::shared_mutex kernel_mtx_;
  std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;

View File

@@ -575,10 +575,7 @@ void fft_op(
  auto plan = plan_fft(n);
  if (plan.four_step) {
    four_step_fft(in, out, axis, inverse, real, plan, copies, s);
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }
@@ -744,12 +741,7 @@ void fft_op(
    compute_encoder->dispatchThreads(grid_dims, group_dims);
  }

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

void fft_op(
@@ -792,8 +784,7 @@ void nd_fft_op(
  }

  auto& d = metal::device(s.device);
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
+  d.add_temporaries(std::move(temp_arrs), s.index);
}

void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -174,12 +174,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
    launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
  }

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

View File

@@ -226,13 +226,8 @@ void steel_matmul_regular(
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  // Record copies
+  d.add_temporaries(std::move(copies), s.index);
}

void steel_matmul(
@@ -382,12 +377,7 @@ void steel_matmul(
      compute_encoder.dispatchThreads(grid_dims, group_dims);
    }

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }
@@ -435,8 +425,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (a_pre.size() == 0 || b_pre.size() == 0) {
    array zero = array(0, a_pre.dtype());
    fill_gpu(zero, out, s);
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
+    d.add_temporary(std::move(zero), s.index);
    return;
  }
@@ -588,12 +577,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }

  /////////////////////////////////////////////////////////////////////////////
@@ -798,12 +782,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }
@@ -916,12 +895,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
      compute_encoder.dispatchThreads(grid_dims, group_dims);
    }

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }
@@ -1056,12 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1080,8 +1049,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (a_pre.size() == 0 || b_pre.size() == 0) {
    array zero = array(0, a_pre.dtype());
    fill_gpu(zero, out, s);
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
+    d.add_temporary(std::move(zero), s.index);
    return;
  }
@@ -1356,12 +1324,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }
@@ -1471,13 +1434,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -1496,8 +1453,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (a_pre.size() == 0 || b_pre.size() == 0) {
    array zero = array(0, a_pre.dtype());
    fill_gpu(zero, out, s);
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
+    d.add_temporary(std::move(zero), s.index);
    return;
  }
@@ -1703,12 +1659,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
    return;
  }
@@ -1847,13 +1798,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);

-  // Clear copies
-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

View File

@@ -91,12 +91,8 @@ void RMSNorm::eval_gpu(
    compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

void RMSNormVJP::eval_gpu(
@@ -204,10 +200,7 @@ void RMSNormVJP::eval_gpu(
  strided_reduce_general_dispatch(
      gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);

-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-        copies.clear();
-      });
+  d.add_temporaries(std::move(copies), s.index);
}

void LayerNorm::eval_gpu(
@@ -292,12 +285,8 @@ void LayerNorm::eval_gpu(
    compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

void LayerNormVJP::eval_gpu(
@@ -425,10 +414,7 @@ void LayerNormVJP::eval_gpu(
        gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
  }

-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-        copies.clear();
-      });
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core::fast

View File

@@ -145,6 +145,7 @@ void launch_qmm(
  }

  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
+  d.add_temporaries(std::move(copies), s.index);
}

void qmm_op(
@@ -350,12 +351,7 @@ void fast::AffineQuantize::eval_gpu(
      : MTL::Size(nthreads, 1, 1);
  compute_encoder.dispatchThreads(grid_dims, group_dims);

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

View File

@@ -660,12 +660,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
          in, out, op_name, plan, axes_, compute_encoder, d, s);
    }

-    if (!copies.empty()) {
-      d.get_command_buffer(s.index)->addCompletedHandler(
-          [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-            copies.clear();
-          });
-    }
+    d.add_temporaries(std::move(copies), s.index);
  }

  // Nothing to reduce just initialize the output

View File

@@ -262,12 +262,7 @@ void ScaledDotProductAttention::eval_gpu(
    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
  }

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core::fast

View File

@@ -107,13 +107,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }

-  if (!copies.empty()) {
-    auto command_buffer = d.get_command_buffer(s.index);
-    command_buffer->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

View File

@@ -88,12 +88,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
    compute_encoder->setBytes(&axis_size, sizeof(int), 2);
    compute_encoder.dispatchThreads(grid_dims, group_dims);
  }

-  if (!copies.empty()) {
-    d.get_command_buffer(s.index)->addCompletedHandler(
-        [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-          copies.clear();
-        });
-  }
+  d.add_temporaries(std::move(copies), s.index);
}

} // namespace mlx::core

View File

@@ -252,11 +252,7 @@ void multi_block_sort(
      (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
      s);

-  // Clear copies
-  d.get_command_buffer(s.index)->addCompletedHandler(
-      [copies = std::move(copies)](MTL::CommandBuffer*) mutable {
-        copies.clear();
-      });
+  d.add_temporaries(std::move(copies), s.index);
}

void gpu_merge_sort(

View File

@@ -769,7 +769,6 @@ class TestCompile(mlx_tests.MLXTestCase):
        out = mx.compile(fn)(a, b)
        expected = fn(a, b)
-        print((out - expected).abs().max())
        self.assertTrue(mx.allclose(out, expected))

    def test_compile_many_inputs(self):