mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	rebase
This commit is contained in:
		@@ -1,3 +1,5 @@
 | 
			
		||||
.. _custom_metal_kernels:
 | 
			
		||||
 | 
			
		||||
Custom Metal Kernels
 | 
			
		||||
====================
 | 
			
		||||
 | 
			
		||||
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
 | 
			
		||||
 | 
			
		||||
  template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
 | 
			
		||||
 | 
			
		||||
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
 | 
			
		||||
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
 | 
			
		||||
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
 | 
			
		||||
 | 
			
		||||
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
 | 
			
		||||
 | 
			
		||||
Using Shape/Strides
 | 
			
		||||
 
 | 
			
		||||
@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
 | 
			
		||||
  ys = mx.random.uniform(shape=(100, 4096))
 | 
			
		||||
 | 
			
		||||
  def naive_add(xs, ys):
 | 
			
		||||
      return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
 | 
			
		||||
      return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
 | 
			
		||||
 | 
			
		||||
Instead you can use :func:`vmap` to automatically vectorize the addition:
 | 
			
		||||
 | 
			
		||||
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
 | 
			
		||||
 | 
			
		||||
   # Vectorize over the second dimension of x and the
 | 
			
		||||
   # first dimension of y
 | 
			
		||||
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
 | 
			
		||||
   vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
 | 
			
		||||
 | 
			
		||||
The ``in_axes`` parameter can be used to specify which dimensions of the
 | 
			
		||||
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								docs/build/html/_sources/usage/indexing.rst
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								docs/build/html/_sources/usage/indexing.rst
									
									
									
									
										vendored
									
									
								
							@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
 | 
			
		||||
kernel would be extremely inefficient.
 | 
			
		||||
 | 
			
		||||
Indexing with boolean masks is something that MLX may support in the future. In
 | 
			
		||||
general, MLX has limited support for operations for which outputs
 | 
			
		||||
general, MLX has limited support for operations for which output
 | 
			
		||||
*shapes* are dependent on input *data*. Other examples of these types of
 | 
			
		||||
operations which MLX does not yet support include :func:`numpy.nonzero` and the
 | 
			
		||||
single input version of :func:`numpy.where`.
 | 
			
		||||
 
 | 
			
		||||
@@ -109,7 +109,7 @@ Here is a concrete example:
 | 
			
		||||
 | 
			
		||||
An important behavior to be aware of is when the graph will be implicitly
 | 
			
		||||
evaluated. Anytime you ``print`` an array, convert it to an
 | 
			
		||||
:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
 | 
			
		||||
:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
 | 
			
		||||
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
 | 
			
		||||
saving functions) will also evaluate the array.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user