mirror of https://github.com/ml-explore/mlx.git, synced 2025-09-18 10:26:56 +08:00

Commit: rebase
docs/build/html/_sources/dev/extensions.rst (vendored, 16 lines changed)
@@ -494,7 +494,7 @@ below.
 
 // Prepare to encode kernel
 auto& compute_encoder = d.get_command_encoder(s.index);
-compute_encoder->setComputePipelineState(kernel);
+compute_encoder.set_compute_pipeline_state(kernel);
 
 // Kernel parameters are registered with buffer indices corresponding to
 // those in the kernel declaration at axpby.metal
@@ -509,14 +509,14 @@ below.
 compute_encoder.set_output_array(out, 2);
 
 // Encode alpha and beta
-compute_encoder->setBytes(&alpha_, sizeof(float), 3);
-compute_encoder->setBytes(&beta_, sizeof(float), 4);
+compute_encoder.set_bytes(alpha_, 3);
+compute_encoder.set_bytes(beta_, 4);
 
 // Encode shape, strides and ndim
-compute_encoder->setBytes(x.shape().data(), ndim * sizeof(int), 5);
-compute_encoder->setBytes(x.strides().data(), ndim * sizeof(size_t), 6);
-compute_encoder->setBytes(y.strides().data(), ndim * sizeof(size_t), 7);
-compute_encoder->setBytes(&ndim, sizeof(int), 8);
+compute_encoder.set_vector_bytes(x.shape(), 5);
+compute_encoder.set_vector_bytes(x.strides(), 6);
+compute_encoder.set_bytes(y.strides(), 7);
+compute_encoder.set_bytes(ndim, 8);
 
 // We launch 1 thread for each input and make sure that the number of
 // threads in any given threadgroup is not higher than the max allowed
@@ -530,7 +530,7 @@ below.
 
 // Launch the grid with the given number of threads divided among
 // the given threadgroups
-compute_encoder.dispatchThreads(grid_dims, group_dims);
+compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
 We can now call the :meth:`axpby` operation on both the CPU and the GPU!
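Taken together, the hunks above migrate the encoder from the raw Metal camelCase API to the new snake_case wrappers. The sketch below assembles the post-change sequence in one place; it is assembled from the ``+`` lines of this diff, the surrounding variables (``d``, ``s``, ``kernel``, ``x``, ``y``, ``out``, ``alpha_``, ``beta_``, ``ndim``, ``grid_dims``, ``group_dims``) come from the axpby example in extensions.rst, and the two ``set_input_array`` bindings are an assumption made by analogy with the ``set_output_array`` context line.

    // Minimal sketch (C++): post-change encoding sequence for axpby.
    // The set_input_array bindings for x and y are assumed; everything
    // else is copied from the + lines of this diff.
    auto& compute_encoder = d.get_command_encoder(s.index);
    compute_encoder.set_compute_pipeline_state(kernel);

    compute_encoder.set_input_array(x, 0); // assumed binding
    compute_encoder.set_input_array(y, 1); // assumed binding
    compute_encoder.set_output_array(out, 2);

    // Scalars and vectors no longer take explicit byte sizes:
    compute_encoder.set_bytes(alpha_, 3);
    compute_encoder.set_bytes(beta_, 4);
    compute_encoder.set_vector_bytes(x.shape(), 5);
    compute_encoder.set_vector_bytes(x.strides(), 6);
    compute_encoder.set_bytes(y.strides(), 7);
    compute_encoder.set_bytes(ndim, 8);

    compute_encoder.dispatch_threads(grid_dims, group_dims);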
docs/build/html/_sources/install.rst (vendored, 2 lines changed)
@@ -209,7 +209,7 @@ Metal library by run-time compiling kernels the first time they are used in MLX
 on a given machine. Note run-time compilation incurs a cold-start cost which can
 be anwywhere from a few hundred millisecond to a few seconds depending on the
 application. Once a kernel is compiled, it will be cached by the system. The
-Metal kernel cache persists accross reboots.
+Metal kernel cache persists across reboots.
 
 Troubleshooting
 ^^^^^^^^^^^^^^^
@@ -1,6 +0,0 @@
-mlx.core.fast.affine\_quantize
-==============================
-
-.. currentmodule:: mlx.core.fast
-
-.. autofunction:: affine_quantize
docs/build/html/_sources/python/fast.rst (vendored, 1 line changed)
@@ -12,5 +12,4 @@ Fast
 layer_norm
 rope
 scaled_dot_product_attention
-affine_quantize
 metal_kernel
docs/build/html/_sources/python/nn/_autosummary/mlx.nn.AvgPool3d.rst (vendored, new file, 16 lines)
@@ -0,0 +1,16 @@
+mlx.nn.AvgPool3d
+================
+
+.. currentmodule:: mlx.nn
+
+.. autoclass:: AvgPool3d
+
+
+
+
+   .. rubric:: Methods
+
+   .. autosummary::
+
+
+
docs/build/html/_sources/python/nn/_autosummary/mlx.nn.MaxPool3d.rst (vendored, new file, 16 lines)
@@ -0,0 +1,16 @@
+mlx.nn.MaxPool3d
+================
+
+.. currentmodule:: mlx.nn
+
+.. autoclass:: MaxPool3d
+
+
+
+
+   .. rubric:: Methods
+
+   .. autosummary::
+
+
+
@@ -12,6 +12,7 @@ Layers
 ALiBi
 AvgPool1d
 AvgPool2d
+AvgPool3d
 BatchNorm
 CELU
 Conv1d
@@ -41,6 +42,7 @@ Layers
 LSTM
 MaxPool1d
 MaxPool2d
+MaxPool3d
 Mish
 MultiHeadAttention
 PReLU
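The stubs and layer-list entries above cover the new 3-D pooling layers. As a quick, hedged usage sketch, assuming they follow the ``AvgPool1d``/``AvgPool2d`` pattern of a ``kernel_size``/``stride``/``padding`` constructor and the channels-last input layout MLX pooling uses:

    import mlx.core as mx
    import mlx.nn as nn

    # Assumed constructor, by analogy with the 1-D and 2-D pooling layers:
    # AvgPool3d(kernel_size, stride=None, padding=0)
    pool = nn.AvgPool3d(kernel_size=2, stride=2)

    # Channels-last input: (batch, depth, height, width, channels).
    x = mx.random.uniform(shape=(1, 8, 8, 8, 3))
    y = pool(x)
    print(y.shape)  # expected: (1, 4, 4, 4, 3)

``MaxPool3d`` should be a drop-in swap in the same sketch.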
@@ -184,8 +184,8 @@ Let's time these two different versions:
 print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
 print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))
 
-On an M1 Max the naive version takes in total ``0.390`` seconds whereas the
-vectorized version takes only ``0.025`` seconds, more than ten times faster.
+On an M1 Max the naive version takes in total ``5.639`` seconds whereas the
+vectorized version takes only ``0.024`` seconds, more than 200 times faster.
 
 Of course, this operation is quite contrived. A better approach is to simply do
 ``xs + ys.T``, but for more complex functions :func:`vmap` can be quite handy.
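For context, the timing lines above compare a hand-written Python loop against :func:`vmap`. The sketch below shows the shape of that comparison, not the exact code from the tutorial: the array shapes and the ``in_axes=(0, 1)`` choice are assumed from the closing remark that ``xs + ys.T`` is the direct alternative.

    import timeit
    import mlx.core as mx

    # Assumed shapes: xs batched along axis 0, ys along axis 1.
    xs = mx.random.uniform(shape=(4096, 100))
    ys = mx.random.uniform(shape=(100, 4096))

    def naive_add(xs, ys):
        # One small elementwise add per row, driven from Python.
        return [xs[i] + ys[:, i] for i in range(xs.shape[0])]

    # vmap fuses the loop: map over axis 0 of xs and axis 1 of ys.
    vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))

    print(timeit.timeit(lambda: mx.eval(naive_add(xs, ys)), number=100))
    print(timeit.timeit(lambda: mx.eval(vmap_add(xs, ys)), number=100))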