mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
rebase
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
.. _custom_metal_kernels:
|
||||
|
||||
Custom Metal Kernels
|
||||
====================
|
||||
|
||||
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
|
||||
|
||||
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
|
||||
|
||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
|
||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
|
||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
|
||||
|
||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
|
||||
|
||||
Using Shape/Strides
|
||||
|
@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
|
||||
ys = mx.random.uniform(shape=(100, 4096))
|
||||
|
||||
def naive_add(xs, ys):
|
||||
return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
|
||||
return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
|
||||
|
||||
Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||
|
||||
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
|
||||
|
||||
# Vectorize over the second dimension of x and the
|
||||
# first dimension of y
|
||||
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
|
||||
vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
|
||||
|
||||
The ``in_axes`` parameter can be used to specify which dimensions of the
|
||||
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
|
||||
|
2
docs/build/html/_sources/usage/indexing.rst
vendored
2
docs/build/html/_sources/usage/indexing.rst
vendored
@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
|
||||
kernel would be extremely inefficient.
|
||||
|
||||
Indexing with boolean masks is something that MLX may support in the future. In
|
||||
general, MLX has limited support for operations for which outputs
|
||||
general, MLX has limited support for operations for which output
|
||||
*shapes* are dependent on input *data*. Other examples of these types of
|
||||
operations which MLX does not yet support include :func:`numpy.nonzero` and the
|
||||
single input version of :func:`numpy.where`.
|
||||
|
@@ -109,7 +109,7 @@ Here is a concrete example:
|
||||
|
||||
An important behavior to be aware of is when the graph will be implicitly
|
||||
evaluated. Anytime you ``print`` an array, convert it to an
|
||||
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
|
||||
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
|
||||
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
|
||||
saving functions) will also evaluate the array.
|
||||
|
||||
|
Reference in New Issue
Block a user