mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Remove Hazard tracking with Fences (#1509)
* remove hazard tracking * with fence map * no hazard tracking with fences * nits * fix fence retain * cleanup * fix quantized rebase
This commit is contained in:
parent
d15fa13daf
commit
c26208f67d
@ -33,12 +33,12 @@ Let's start with a simple example:
|
|||||||
# Compile the function
|
# Compile the function
|
||||||
compiled_fun = mx.compile(fun)
|
compiled_fun = mx.compile(fun)
|
||||||
|
|
||||||
# Prints: array(2.36788, dtype=float32)
|
# Prints: array(2.36788, dtype=float32)
|
||||||
print(compiled_fun(x, y))
|
print(compiled_fun(x, y))
|
||||||
|
|
||||||
The output of both the regular function and the compiled function is the same
|
The output of both the regular function and the compiled function is the same
|
||||||
up to numerical precision.
|
up to numerical precision.
|
||||||
|
|
||||||
The first time you call a compiled function, MLX will build the compute
|
The first time you call a compiled function, MLX will build the compute
|
||||||
graph, optimize it, and generate and compile code. This can be relatively
|
graph, optimize it, and generate and compile code. This can be relatively
|
||||||
slow. However, MLX will cache compiled functions, so calling a compiled
|
slow. However, MLX will cache compiled functions, so calling a compiled
|
||||||
@ -96,7 +96,7 @@ element-wise operations:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
||||||
|
|
||||||
If you use this function with small arrays, it will be overhead bound. If you
|
If you use this function with small arrays, it will be overhead bound. If you
|
||||||
@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
|
|||||||
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
|
||||||
five times faster.
|
five times faster.
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
|
|
||||||
functions can still be helpful, but won't typically result in as large a
|
|
||||||
speedup as compiling operations that run on the GPU.
|
|
||||||
|
|
||||||
|
|
||||||
Debugging
|
Debugging
|
||||||
---------
|
---------
|
||||||
|
|
||||||
@ -287,7 +280,7 @@ to the function. In some cases this can be pretty inconvenient. Hence,
|
|||||||
print(fun(mx.array(1.0)))
|
print(fun(mx.array(1.0)))
|
||||||
|
|
||||||
|
|
||||||
Compiling Training Graphs
|
Compiling Training Graphs
|
||||||
-------------------------
|
-------------------------
|
||||||
|
|
||||||
This section will step through how to use :func:`compile` with a simple example
|
This section will step through how to use :func:`compile` with a simple example
|
||||||
@ -297,7 +290,7 @@ full forward, backward, and update with :func:`compile`.
|
|||||||
|
|
||||||
To start, here is the simple example without any compilation:
|
To start, here is the simple example without any compilation:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -330,7 +323,7 @@ To start, here is the simple example without any compilation:
|
|||||||
To compile the update we can put it all in a function and compile it with the
|
To compile the update we can put it all in a function and compile it with the
|
||||||
appropriate input and output captures. Here's the same example but compiled:
|
appropriate input and output captures. Here's the same example but compiled:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -355,7 +348,7 @@ appropriate input and output captures. Here's the same example but compiled:
|
|||||||
|
|
||||||
# The state that will be captured as input and output
|
# The state that will be captured as input and output
|
||||||
state = [model.state, optimizer.state]
|
state = [model.state, optimizer.state]
|
||||||
|
|
||||||
@partial(mx.compile, inputs=state, outputs=state)
|
@partial(mx.compile, inputs=state, outputs=state)
|
||||||
def step(x, y):
|
def step(x, y):
|
||||||
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
|
||||||
@ -410,7 +403,7 @@ Compiling transformed functions works just as expected:
|
|||||||
|
|
||||||
In order to compile as much as possible, a transformation of a compiled
|
In order to compile as much as possible, a transformation of a compiled
|
||||||
function will not by default be compiled. To compile the transformed
|
function will not by default be compiled. To compile the transformed
|
||||||
function simply pass it through :func:`compile`.
|
function simply pass it through :func:`compile`.
|
||||||
|
|
||||||
You can also compile functions which themselves call compiled functions. A
|
You can also compile functions which themselves call compiled functions. A
|
||||||
good practice is to compile the outer most function to give :func:`compile`
|
good practice is to compile the outermost function to give :func:`compile`
|
||||||
|
@ -25,7 +25,7 @@ Here is a simple example:
|
|||||||
|
|
||||||
The output of :func:`grad` on :func:`sin` is simply another function. In this
|
The output of :func:`grad` on :func:`sin` is simply another function. In this
|
||||||
case it is the gradient of the sine function which is exactly the cosine
|
case it is the gradient of the sine function which is exactly the cosine
|
||||||
function. To get the second derivative you can do:
|
function. To get the second derivative you can do:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ Automatic Differentiation
|
|||||||
.. _auto diff:
|
.. _auto diff:
|
||||||
|
|
||||||
Automatic differentiation in MLX works on functions rather than on implicit
|
Automatic differentiation in MLX works on functions rather than on implicit
|
||||||
graphs.
|
graphs.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ way to do that is the following:
|
|||||||
|
|
||||||
def loss_fn(params, x, y):
|
def loss_fn(params, x, y):
|
||||||
w, b = params["weight"], params["bias"]
|
w, b = params["weight"], params["bias"]
|
||||||
h = w * x + b
|
h = w * x + b
|
||||||
return mx.mean(mx.square(h - y))
|
return mx.mean(mx.square(h - y))
|
||||||
|
|
||||||
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
|
params = {"weight": mx.array(1.0), "bias": mx.array(0.0)}
|
||||||
@ -132,7 +132,7 @@ way to do that is the following:
|
|||||||
|
|
||||||
Notice the tree structure of the parameters is preserved in the gradients.
|
Notice the tree structure of the parameters is preserved in the gradients.
|
||||||
|
|
||||||
In some cases you may want to stop gradients from propagating through a
|
In some cases you may want to stop gradients from propagating through a
|
||||||
part of the function. You can use the :func:`stop_gradient` for that.
|
part of the function. You can use :func:`stop_gradient` for that.
|
||||||
|
|
||||||
|
|
||||||
@ -166,14 +166,14 @@ A naive way to add the elements from two sets of vectors is with a loop:
|
|||||||
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# Vectorize over the second dimension of x and the
|
# Vectorize over the second dimension of x and the
|
||||||
# first dimension of y
|
# first dimension of y
|
||||||
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
|
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
|
||||||
|
|
||||||
The ``in_axes`` parameter can be used to specify which dimensions of the
|
The ``in_axes`` parameter can be used to specify which dimensions of the
|
||||||
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
||||||
where the vectorized axes should be in the outputs.
|
where the vectorized axes should be in the outputs.
|
||||||
|
|
||||||
Let's time these two different versions:
|
Let's time these two different versions:
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ You can also use an :obj:`array` to index another :obj:`array`:
|
|||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
|
||||||
>>> arr = mx.arange(10)
|
>>> arr = mx.arange(10)
|
||||||
>>> idx = mx.array([5, 7])
|
>>> idx = mx.array([5, 7])
|
||||||
>>> arr[idx]
|
>>> arr[idx]
|
||||||
array([5, 7], dtype=int32)
|
array([5, 7], dtype=int32)
|
||||||
|
|
||||||
@ -82,7 +82,7 @@ general, MLX has limited support for operations for which outputs
|
|||||||
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||||
single input version of :func:`numpy.where`.
|
single input version of :func:`numpy.where`.
|
||||||
|
|
||||||
In Place Updates
|
In Place Updates
|
||||||
----------------
|
----------------
|
||||||
|
|
||||||
In place updates to indexed arrays are possible in MLX. For example:
|
In place updates to indexed arrays are possible in MLX. For example:
|
||||||
|
@ -13,7 +13,7 @@ compute graph is recorded. The actual computation only happens if an
|
|||||||
:func:`eval` is performed.
|
:func:`eval` is performed.
|
||||||
|
|
||||||
MLX uses lazy evaluation because it has some nice features, some of which we
|
MLX uses lazy evaluation because it has some nice features, some of which we
|
||||||
describe below.
|
describe below.
|
||||||
|
|
||||||
Transforming Compute Graphs
|
Transforming Compute Graphs
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
@ -116,7 +116,7 @@ saving functions) will also evaluate the array.
|
|||||||
|
|
||||||
Calling :func:`array.item` on a scalar array will also evaluate it. In the
|
Calling :func:`array.item` on a scalar array will also evaluate it. In the
|
||||||
example above, printing the loss (``print(loss)``) or adding the loss scalar to
|
example above, printing the loss (``print(loss)``) or adding the loss scalar to
|
||||||
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
|
a list (``losses.append(loss.item())``) would cause a graph evaluation. If
|
||||||
these lines are before ``mx.eval(loss, model.parameters())`` then this
|
these lines are before ``mx.eval(loss, model.parameters())`` then this
|
||||||
will be a partial evaluation, computing only the forward pass.
|
will be a partial evaluation, computing only the forward pass.
|
||||||
|
|
||||||
|
@ -3,10 +3,10 @@
|
|||||||
Conversion to NumPy and Other Frameworks
|
Conversion to NumPy and Other Frameworks
|
||||||
========================================
|
========================================
|
||||||
|
|
||||||
MLX array supports conversion between other frameworks with either:
|
MLX arrays can be converted to and from other frameworks using either:
|
||||||
|
|
||||||
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
* The `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
|
||||||
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
|
* `DLPack <https://dmlc.github.io/dlpack/latest/>`_.
|
||||||
|
|
||||||
Let's convert an array to NumPy and back.
|
Let's convert an array to NumPy and back.
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ even though no in-place operations on MLX memory are executed.
|
|||||||
PyTorch
|
PyTorch
|
||||||
-------
|
-------
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
|
|
||||||
PyTorch Support for :obj:`memoryview` is experimental and can break for
|
PyTorch support for :obj:`memoryview` is experimental and can break for
|
||||||
multi-dimensional arrays. Casting to NumPy first is advised for now.
|
multi-dimensional arrays. Casting to NumPy first is advised for now.
|
||||||
|
@ -64,4 +64,4 @@ Other gradient transformations include :func:`vjp` for vector-Jacobian products
|
|||||||
and :func:`jvp` for Jacobian-vector products.
|
and :func:`jvp` for Jacobian-vector products.
|
||||||
|
|
||||||
Use :func:`value_and_grad` to efficiently compute both a function's output and
|
Use :func:`value_and_grad` to efficiently compute both a function's output and
|
||||||
gradient with respect to the function's input.
|
gradient with respect to the function's input.
|
||||||
|
@ -8,33 +8,33 @@ Saving and Loading Arrays
|
|||||||
MLX supports multiple array serialization formats.
|
MLX supports multiple array serialization formats.
|
||||||
|
|
||||||
.. list-table:: Serialization Formats
|
.. list-table:: Serialization Formats
|
||||||
:widths: 20 8 25 25
|
:widths: 20 8 25 25
|
||||||
:header-rows: 1
|
:header-rows: 1
|
||||||
|
|
||||||
* - Format
|
* - Format
|
||||||
- Extension
|
- Extension
|
||||||
- Function
|
- Function
|
||||||
- Notes
|
- Notes
|
||||||
* - NumPy
|
* - NumPy
|
||||||
- ``.npy``
|
- ``.npy``
|
||||||
- :func:`save`
|
- :func:`save`
|
||||||
- Single arrays only
|
- Single arrays only
|
||||||
* - NumPy archive
|
* - NumPy archive
|
||||||
- ``.npz``
|
- ``.npz``
|
||||||
- :func:`savez` and :func:`savez_compressed`
|
- :func:`savez` and :func:`savez_compressed`
|
||||||
- Multiple arrays
|
- Multiple arrays
|
||||||
* - Safetensors
|
* - Safetensors
|
||||||
- ``.safetensors``
|
- ``.safetensors``
|
||||||
- :func:`save_safetensors`
|
- :func:`save_safetensors`
|
||||||
- Multiple arrays
|
- Multiple arrays
|
||||||
* - GGUF
|
* - GGUF
|
||||||
- ``.gguf``
|
- ``.gguf``
|
||||||
- :func:`save_gguf`
|
- :func:`save_gguf`
|
||||||
- Multiple arrays
|
- Multiple arrays
|
||||||
|
|
||||||
The :func:`load` function will load any of the supported serialization
|
The :func:`load` function will load any of the supported serialization
|
||||||
formats. It determines the format from the extensions. The output of
|
formats. It determines the format from the extensions. The output of
|
||||||
:func:`load` depends on the format.
|
:func:`load` depends on the format.
|
||||||
|
|
||||||
Here's an example of saving a single array to a file:
|
Here's an example of saving a single array to a file:
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Both ``a`` and ``b`` live in unified memory.
|
|||||||
|
|
||||||
In MLX, rather than moving arrays to devices, you specify the device when you
|
In MLX, rather than moving arrays to devices, you specify the device when you
|
||||||
run the operation. Any device can perform any operation on ``a`` and ``b``
|
run the operation. Any device can perform any operation on ``a`` and ``b``
|
||||||
without needing to move them from one memory location to another. For example:
|
without needing to move them from one memory location to another. For example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -180,6 +180,7 @@ void array::move_shared_buffer(
|
|||||||
auto char_offset = sizeof(char) * itemsize() * offset;
|
auto char_offset = sizeof(char) * itemsize() * offset;
|
||||||
array_desc_->data_ptr = static_cast<void*>(
|
array_desc_->data_ptr = static_cast<void*>(
|
||||||
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
|
||||||
|
other.array_desc_->data_ptr = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void array::move_shared_buffer(array other) {
|
void array::move_shared_buffer(array other) {
|
||||||
|
@ -205,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
|
|
||||||
// Allocate new buffer if needed
|
// Allocate new buffer if needed
|
||||||
size_t res_opt = MTL::ResourceStorageModeShared;
|
size_t res_opt = MTL::ResourceStorageModeShared;
|
||||||
res_opt |= MTL::ResourceHazardTrackingModeTracked;
|
res_opt |= MTL::ResourceHazardTrackingModeUntracked;
|
||||||
lk.unlock();
|
lk.unlock();
|
||||||
buf = device_->newBuffer(size, res_opt);
|
buf = device_->newBuffer(size, res_opt);
|
||||||
lk.lock();
|
lk.lock();
|
||||||
|
@ -918,14 +918,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
|
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear copies
|
// Record copies
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
|
||||||
command_buffer->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -77,12 +77,7 @@ void CustomKernel::eval_gpu(
|
|||||||
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
MTL::Size grid_dims = MTL::Size(gx, gy, gz);
|
||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
@ -20,7 +20,6 @@ namespace {
|
|||||||
|
|
||||||
// TODO nicer way to set this or possibly expose as an environment variable
|
// TODO nicer way to set this or possibly expose as an environment variable
|
||||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||||
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
|
||||||
|
|
||||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||||
|
|
||||||
@ -121,33 +120,41 @@ MTL::Library* load_library(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
|
||||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||||
enc->retain();
|
enc_->retain();
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder::~CommandEncoder() {
|
CommandEncoder::~CommandEncoder() {
|
||||||
enc->endEncoding();
|
enc_->endEncoding();
|
||||||
enc->release();
|
enc_->release();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::set_array(
|
||||||
|
const array& a,
|
||||||
|
int idx,
|
||||||
|
int64_t offset /* = 0 */) {
|
||||||
|
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||||
|
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||||
|
// Insert a barrier
|
||||||
|
enc_->memoryBarrier(&r_buf, 1);
|
||||||
|
|
||||||
|
// Remove the output
|
||||||
|
outputs_.erase(it);
|
||||||
|
}
|
||||||
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
|
auto base_offset = a.data<char>() -
|
||||||
|
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||||
|
base_offset += offset;
|
||||||
|
enc_->setBuffer(a_buf, base_offset, idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_input_array(
|
void CommandEncoder::set_input_array(
|
||||||
const array& a,
|
const array& a,
|
||||||
int idx,
|
int idx,
|
||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
all_inputs_.insert(a.buffer().ptr());
|
||||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
set_array(a, idx, offset);
|
||||||
// Insert a barrier
|
|
||||||
enc->memoryBarrier(&r_buf, 1);
|
|
||||||
|
|
||||||
// Remove the output
|
|
||||||
outputs.erase(it);
|
|
||||||
}
|
|
||||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
|
||||||
auto base_offset = a.data<char>() -
|
|
||||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
|
||||||
base_offset += offset;
|
|
||||||
enc->setBuffer(a_buf, base_offset, idx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_output_array(
|
void CommandEncoder::set_output_array(
|
||||||
@ -155,40 +162,26 @@ void CommandEncoder::set_output_array(
|
|||||||
int idx,
|
int idx,
|
||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
// Add barriers before adding the output to the output set
|
// Add barriers before adding the output to the output set
|
||||||
set_input_array(a, idx, offset);
|
set_array(a, idx, offset);
|
||||||
|
all_outputs_.insert(a.buffer().ptr());
|
||||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||||
if (concurrent) {
|
if (concurrent_) {
|
||||||
concurrent_outputs.insert(buf);
|
concurrent_outputs_.insert(buf);
|
||||||
} else {
|
} else {
|
||||||
outputs.insert(buf);
|
outputs_.insert(buf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::dispatchThreadgroups(
|
void CommandEncoder::dispatchThreadgroups(
|
||||||
MTL::Size grid_dims,
|
MTL::Size grid_dims,
|
||||||
MTL::Size group_dims) {
|
MTL::Size group_dims) {
|
||||||
num_dispatches++;
|
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
enc->dispatchThreadgroups(grid_dims, group_dims);
|
|
||||||
maybe_split();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::dispatchThreads(
|
void CommandEncoder::dispatchThreads(
|
||||||
MTL::Size grid_dims,
|
MTL::Size grid_dims,
|
||||||
MTL::Size group_dims) {
|
MTL::Size group_dims) {
|
||||||
num_dispatches++;
|
enc_->dispatchThreads(grid_dims, group_dims);
|
||||||
enc->dispatchThreads(grid_dims, group_dims);
|
|
||||||
maybe_split();
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommandEncoder::maybe_split() {
|
|
||||||
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
|
|
||||||
enc->endEncoding();
|
|
||||||
enc->release();
|
|
||||||
num_dispatches = 0;
|
|
||||||
outputs.clear();
|
|
||||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
|
||||||
enc->retain();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::Device() {
|
Device::Device() {
|
||||||
@ -199,12 +192,6 @@ Device::Device() {
|
|||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
for (auto& q : queue_map_) {
|
|
||||||
q.second->release();
|
|
||||||
}
|
|
||||||
for (auto& b : buffer_map_) {
|
|
||||||
b.second.second->release();
|
|
||||||
}
|
|
||||||
for (auto& k : kernel_map_) {
|
for (auto& k : kernel_map_) {
|
||||||
k.second->release();
|
k.second->release();
|
||||||
}
|
}
|
||||||
@ -225,61 +212,125 @@ void Device::new_queue(int index) {
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[metal::Device] Failed to make new command queue.");
|
"[metal::Device] Failed to make new command queue.");
|
||||||
}
|
}
|
||||||
queue_map_.insert({index, q});
|
stream_map_.emplace(index, q);
|
||||||
buffer_map_.insert({index, {0, nullptr}});
|
|
||||||
encoder_map_.insert({index, nullptr});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int Device::get_command_buffer_ops(int index) {
|
int Device::get_command_buffer_ops(int index) {
|
||||||
return buffer_map_[index].first;
|
return get_stream_(index).buffer_ops;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::increment_command_buffer_ops(int index) {
|
void Device::increment_command_buffer_ops(int index) {
|
||||||
buffer_map_[index].first++;
|
get_stream_(index).buffer_ops++;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||||
auto bit = buffer_map_.find(index);
|
auto& stream = get_stream_(index);
|
||||||
if (bit->second.second == nullptr) {
|
if (stream.buffer == nullptr) {
|
||||||
auto qit = queue_map_.find(index);
|
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
|
||||||
if (qit == queue_map_.end()) {
|
if (!stream.buffer) {
|
||||||
throw std::runtime_error(
|
|
||||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
|
||||||
|
|
||||||
if (!cb) {
|
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[metal::Device] Unable to create new command buffer");
|
"[metal::Device] Unable to create new command buffer");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment ref count so the buffer is not garbage collected
|
// Increment ref count so the buffer is not garbage collected
|
||||||
cb->retain();
|
stream.buffer->retain();
|
||||||
|
|
||||||
bit->second = {0, cb};
|
|
||||||
}
|
}
|
||||||
return bit->second.second;
|
return stream.buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::commit_command_buffer(int index) {
|
void Device::commit_command_buffer(int index) {
|
||||||
auto bit = buffer_map_.find(index);
|
auto& stream = get_stream_(index);
|
||||||
bit->second.second->commit();
|
stream.buffer->commit();
|
||||||
bit->second.second->release();
|
stream.buffer->release();
|
||||||
bit->second = {0, nullptr};
|
stream.buffer = nullptr;
|
||||||
|
stream.buffer_ops = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::add_temporary(array arr, int index) {
|
||||||
|
get_stream_(index).temporaries.push_back(std::move(arr));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Device::add_temporaries(std::vector<array> arrays, int index) {
|
||||||
|
if (arrays.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto& stream = get_stream_(index);
|
||||||
|
stream.temporaries.insert(
|
||||||
|
stream.temporaries.end(),
|
||||||
|
std::make_move_iterator(arrays.begin()),
|
||||||
|
std::make_move_iterator(arrays.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::end_encoding(int index) {
|
void Device::end_encoding(int index) {
|
||||||
encoder_map_[index] = nullptr;
|
auto& stream = get_stream_(index);
|
||||||
|
if (stream.encoder != nullptr) {
|
||||||
|
// Each command encoder has a unique fence. We also store a map of
|
||||||
|
// all previous outputs of command encoders to their corresponding fence.
|
||||||
|
// - The command encoder records its inputs and outputs.
|
||||||
|
// - Wait on a fence if any inputs in the encoder are outputs of a previous
|
||||||
|
// encoder.
|
||||||
|
// - Update the map of outputs to include this command encoder's outputs.
|
||||||
|
// - Always signal this command encoder's fence.
|
||||||
|
// - Add a completion handler for this command encoder that removes outputs
|
||||||
|
// from the map to limit the growth of the map and avoid unnecessary waits.
|
||||||
|
// - Temporaries are a special case as they do not cross command encoder
|
||||||
|
// boundaries. These can be removed early from the encoder's inputs and
|
||||||
|
// outputs since they don't need synchronization.
|
||||||
|
auto& enc = *stream.encoder;
|
||||||
|
// Remove temporaries from inputs and outputs
|
||||||
|
for (auto& t : stream.temporaries) {
|
||||||
|
if (t.data<void>() != nullptr) {
|
||||||
|
enc.outputs().erase(t.buffer().ptr());
|
||||||
|
enc.inputs().erase(t.buffer().ptr());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep references to the fences we waited on and put them
|
||||||
|
// in the completion handler so they are not prematurely released
|
||||||
|
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||||
|
for (auto in : enc.inputs()) {
|
||||||
|
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
||||||
|
// If we've already waited on a fence, don't wait on it again.
|
||||||
|
if (waiting_on.find(it->second) == waiting_on.end()) {
|
||||||
|
enc->waitForFence(it->second->fence);
|
||||||
|
waiting_on.insert(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto out : enc.outputs()) {
|
||||||
|
stream.outputs[out] = stream.fence;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
enc->updateFence(stream.fence->fence);
|
||||||
|
stream.buffer->addCompletedHandler(
|
||||||
|
[&stream,
|
||||||
|
waiting_on = std::move(waiting_on),
|
||||||
|
fence = std::move(stream.fence),
|
||||||
|
outputs = std::move(enc.outputs()),
|
||||||
|
temporaries =
|
||||||
|
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
|
||||||
|
temporaries.clear();
|
||||||
|
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||||
|
for (auto o : outputs) {
|
||||||
|
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
|
||||||
|
if (it->second == fence) {
|
||||||
|
stream.outputs.erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
stream.encoder = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
CommandEncoder& Device::get_command_encoder(int index) {
|
CommandEncoder& Device::get_command_encoder(int index) {
|
||||||
auto eit = encoder_map_.find(index);
|
auto& stream = get_stream_(index);
|
||||||
if (eit->second == nullptr) {
|
if (stream.encoder == nullptr) {
|
||||||
auto cb = get_command_buffer(index);
|
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
|
||||||
eit->second = std::make_unique<CommandEncoder>(cb);
|
stream.fence = std::make_shared<Fence>(device_->newFence());
|
||||||
}
|
}
|
||||||
return *(eit->second);
|
return *stream.encoder;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::register_library(
|
void Device::register_library(
|
||||||
|
@ -45,13 +45,13 @@ struct CommandEncoder {
|
|||||||
|
|
||||||
struct ConcurrentContext {
|
struct ConcurrentContext {
|
||||||
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
|
ConcurrentContext(CommandEncoder& enc) : enc(enc) {
|
||||||
enc.concurrent = true;
|
enc.concurrent_ = true;
|
||||||
}
|
}
|
||||||
~ConcurrentContext() {
|
~ConcurrentContext() {
|
||||||
enc.concurrent = false;
|
enc.concurrent_ = false;
|
||||||
enc.outputs.insert(
|
enc.outputs_.insert(
|
||||||
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end());
|
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
|
||||||
enc.concurrent_outputs.clear();
|
enc.concurrent_outputs_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -59,7 +59,7 @@ struct CommandEncoder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
MTL::ComputeCommandEncoder* operator->() {
|
MTL::ComputeCommandEncoder* operator->() {
|
||||||
return enc;
|
return enc_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||||
@ -70,18 +70,60 @@ struct CommandEncoder {
|
|||||||
ConcurrentContext start_concurrent() {
|
ConcurrentContext start_concurrent() {
|
||||||
return ConcurrentContext(*this);
|
return ConcurrentContext(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
~CommandEncoder();
|
~CommandEncoder();
|
||||||
|
|
||||||
private:
|
// Inputs to all kernels in the encoder including temporaries
|
||||||
void maybe_split();
|
std::unordered_set<const void*>& inputs() {
|
||||||
|
return all_inputs_;
|
||||||
|
};
|
||||||
|
|
||||||
int num_dispatches{0};
|
// Outputs of all kernels in the encoder including temporaries
|
||||||
MTL::CommandBuffer* cbuf;
|
std::unordered_set<const void*> outputs() {
|
||||||
MTL::ComputeCommandEncoder* enc;
|
return all_outputs_;
|
||||||
bool concurrent{false};
|
};
|
||||||
std::unordered_set<MTL::Resource*> outputs;
|
|
||||||
std::unordered_set<MTL::Resource*> concurrent_outputs;
|
private:
|
||||||
|
void set_array(const array& a, int idx, int64_t offset);
|
||||||
|
MTL::ComputeCommandEncoder* enc_;
|
||||||
|
bool concurrent_{false};
|
||||||
|
std::unordered_set<MTL::Resource*> outputs_;
|
||||||
|
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
||||||
|
std::unordered_set<const void*> all_inputs_;
|
||||||
|
std::unordered_set<const void*> all_outputs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Fence {
|
||||||
|
Fence(MTL::Fence* fence) : fence(fence) {}
|
||||||
|
~Fence() {
|
||||||
|
fence->release();
|
||||||
|
}
|
||||||
|
MTL::Fence* fence;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DeviceStream {
|
||||||
|
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
|
||||||
|
~DeviceStream() {
|
||||||
|
queue->release();
|
||||||
|
if (buffer != nullptr) {
|
||||||
|
buffer->release();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MTL::CommandQueue* queue;
|
||||||
|
// A map of prior command encoder outputs to their corresponding fence
|
||||||
|
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
|
||||||
|
// Used to allow thread-safe access to the outputs map
|
||||||
|
std::mutex fence_mtx;
|
||||||
|
|
||||||
|
// The buffer and buffer op count are updated
|
||||||
|
// between command buffers
|
||||||
|
MTL::CommandBuffer* buffer{nullptr};
|
||||||
|
int buffer_ops{0};
|
||||||
|
|
||||||
|
// The command encoder, fence, and temporaries are updated between command
|
||||||
|
// encoders
|
||||||
|
std::unique_ptr<CommandEncoder> encoder{nullptr};
|
||||||
|
std::shared_ptr<Fence> fence;
|
||||||
|
std::vector<array> temporaries;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Device {
|
class Device {
|
||||||
@ -136,7 +178,14 @@ class Device {
|
|||||||
MTL::ArgumentEncoder* argument_encoder(
|
MTL::ArgumentEncoder* argument_encoder(
|
||||||
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
|
||||||
|
|
||||||
|
// Record temporary arrays for the given stream index
|
||||||
|
void add_temporary(array arr, int index);
|
||||||
|
void add_temporaries(std::vector<array> arrays, int index);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
DeviceStream& get_stream_(int index) {
|
||||||
|
return stream_map_.find(index)->second;
|
||||||
|
}
|
||||||
MTL::Library* get_library_cache_(const std::string& name);
|
MTL::Library* get_library_cache_(const std::string& name);
|
||||||
|
|
||||||
MTL::Library* get_library_(const std::string& name);
|
MTL::Library* get_library_(const std::string& name);
|
||||||
@ -170,9 +219,7 @@ class Device {
|
|||||||
const std::vector<MTL::Function*>& linked_functions = {});
|
const std::vector<MTL::Function*>& linked_functions = {});
|
||||||
|
|
||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
std::unordered_map<int32_t, DeviceStream> stream_map_;
|
||||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
|
||||||
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
|
||||||
|
|
||||||
std::shared_mutex kernel_mtx_;
|
std::shared_mutex kernel_mtx_;
|
||||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||||
|
@ -575,10 +575,7 @@ void fft_op(
|
|||||||
auto plan = plan_fft(n);
|
auto plan = plan_fft(n);
|
||||||
if (plan.four_step) {
|
if (plan.four_step) {
|
||||||
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
four_step_fft(in, out, axis, inverse, real, plan, copies, s);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -744,12 +741,7 @@ void fft_op(
|
|||||||
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
compute_encoder->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void fft_op(
|
void fft_op(
|
||||||
@ -792,8 +784,7 @@ void nd_fft_op(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto& d = metal::device(s.device);
|
auto& d = metal::device(s.device);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(temp_arrs), s.index);
|
||||||
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
@ -174,12 +174,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -226,13 +226,8 @@ void steel_matmul_regular(
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
// Clear copies
|
// Record copies
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void steel_matmul(
|
void steel_matmul(
|
||||||
@ -382,12 +377,7 @@ void steel_matmul(
|
|||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -435,8 +425,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||||
array zero = array(0, a_pre.dtype());
|
array zero = array(0, a_pre.dtype());
|
||||||
fill_gpu(zero, out, s);
|
fill_gpu(zero, out, s);
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
d.add_temporary(std::move(zero), s.index);
|
||||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -588,12 +577,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
@ -798,12 +782,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -916,12 +895,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1056,12 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -1080,8 +1049,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||||
array zero = array(0, a_pre.dtype());
|
array zero = array(0, a_pre.dtype());
|
||||||
fill_gpu(zero, out, s);
|
fill_gpu(zero, out, s);
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
d.add_temporary(std::move(zero), s.index);
|
||||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1356,12 +1324,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1471,13 +1434,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
// Clear copies
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
if (!copies.empty()) {
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@ -1496,8 +1453,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
if (a_pre.size() == 0 || b_pre.size() == 0) {
|
||||||
array zero = array(0, a_pre.dtype());
|
array zero = array(0, a_pre.dtype());
|
||||||
fill_gpu(zero, out, s);
|
fill_gpu(zero, out, s);
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
d.add_temporary(std::move(zero), s.index);
|
||||||
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1703,12 +1659,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1847,13 +1798,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
|
||||||
// Clear copies
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
if (!copies.empty()) {
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -91,12 +91,8 @@ void RMSNorm::eval_gpu(
|
|||||||
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
|
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
if (!copies.empty()) {
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void RMSNormVJP::eval_gpu(
|
void RMSNormVJP::eval_gpu(
|
||||||
@ -204,10 +200,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
strided_reduce_general_dispatch(
|
strided_reduce_general_dispatch(
|
||||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||||
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void LayerNorm::eval_gpu(
|
void LayerNorm::eval_gpu(
|
||||||
@ -292,12 +285,8 @@ void LayerNorm::eval_gpu(
|
|||||||
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
|
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
if (!copies.empty()) {
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void LayerNormVJP::eval_gpu(
|
void LayerNormVJP::eval_gpu(
|
||||||
@ -425,10 +414,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
@ -145,6 +145,7 @@ void launch_qmm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
|
||||||
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
}
|
}
|
||||||
|
|
||||||
void qmm_op(
|
void qmm_op(
|
||||||
@ -350,12 +351,7 @@ void fast::AffineQuantize::eval_gpu(
|
|||||||
: MTL::Size(nthreads, 1, 1);
|
: MTL::Size(nthreads, 1, 1);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -660,12 +660,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in, out, op_name, plan, axes_, compute_encoder, d, s);
|
in, out, op_name, plan, axes_, compute_encoder, d, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Nothing to reduce just initialize the output
|
// Nothing to reduce just initialize the output
|
||||||
|
@ -262,12 +262,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
|
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fast
|
} // namespace mlx::core::fast
|
||||||
|
@ -107,13 +107,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!copies.empty()) {
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
auto command_buffer = d.get_command_buffer(s.index);
|
|
||||||
command_buffer->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -88,12 +88,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
|
compute_encoder->setBytes(&axis_size, sizeof(int), 2);
|
||||||
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
compute_encoder.dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
if (!copies.empty()) {
|
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -252,11 +252,7 @@ void multi_block_sort(
|
|||||||
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
|
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
|
||||||
s);
|
s);
|
||||||
|
|
||||||
// Clear copies
|
d.add_temporaries(std::move(copies), s.index);
|
||||||
d.get_command_buffer(s.index)->addCompletedHandler(
|
|
||||||
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
|
|
||||||
copies.clear();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void gpu_merge_sort(
|
void gpu_merge_sort(
|
||||||
|
@ -769,7 +769,6 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
out = mx.compile(fn)(a, b)
|
out = mx.compile(fn)(a, b)
|
||||||
expected = fn(a, b)
|
expected = fn(a, b)
|
||||||
print((out - expected).abs().max())
|
|
||||||
self.assertTrue(mx.allclose(out, expected))
|
self.assertTrue(mx.allclose(out, expected))
|
||||||
|
|
||||||
def test_compile_many_inputs(self):
|
def test_compile_many_inputs(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user