CircleCI Docs 2024-11-05 19:54:16 +00:00
parent a5d741ec3b
commit e5e2ffe503
51 changed files with 2277 additions and 1802 deletions

@@ -1,3 +1,5 @@
.. _custom_metal_kernels:
Custom Metal Kernels
====================
@@ -76,6 +78,10 @@ Putting this all together, the generated function signature for ``myexp`` is as
template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
Note: ``grid`` and ``threadgroup`` are parameters to the Metal `dispatchThreads <https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads>`_ function.
This means we will launch ``mx.prod(grid)`` threads, subdivided into ``threadgroup`` size threadgroups.
For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
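To make the launch arithmetic concrete, here is a minimal sketch in plain Python (no MLX required). The helper name ``launch_stats`` is hypothetical and not part of MLX. It assumes the ``dispatchThreads`` behavior described above, where ``grid`` counts threads and edge threadgroups may be only partially filled, so the number of threadgroups per dimension is a ceiling division.

```python
import math


def launch_stats(grid, threadgroup):
    """Return (total_threads, threadgroups_per_dim) for a 3D launch.

    Hypothetical helper for illustration only; not part of MLX.
    """
    total_threads = math.prod(grid)
    # Ceiling division: edge threadgroups may be only partially filled.
    groups = tuple(-(-g // t) for g, t in zip(grid, threadgroup))
    return total_threads, groups


total, groups = launch_stats(grid=(1000, 1, 1), threadgroup=(256, 1, 1))
print(total, groups)  # 1000 (4, 1, 1)
```

With a grid of 1000 threads and threadgroups of 256, the last threadgroup covers only the remaining 232 threads.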
Passing ``verbose=True`` to ``mx.fast.metal_kernel.__call__`` will print the generated code for debugging purposes.
Using Shape/Strides

@@ -161,7 +161,7 @@ A naive way to add the elements from two sets of vectors is with a loop:
ys = mx.random.uniform(shape=(100, 4096))
def naive_add(xs, ys):
-    return [xs[i] + ys[:, i] for i in range(xs.shape[1])]
+    return [xs[i] + ys[:, i] for i in range(xs.shape[0])]
Instead you can use :func:`vmap` to automatically vectorize the addition:
@@ -169,7 +169,7 @@ Instead you can use :func:`vmap` to automatically vectorize the addition:
# Vectorize over the first dimension of x and the
# second dimension of y
-vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(1, 0))
+vmap_add = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))
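As a rough illustration of what ``in_axes=(0, 1)`` pairs up, the sketch below emulates the mapping with nested Python lists instead of MLX arrays. The names ``column`` and ``mapped_add`` are hypothetical stand-ins, not MLX functions. Entry ``i`` of the result combines row ``i`` of ``xs`` (axis 0) with column ``i`` of ``ys`` (axis 1), which is the same pairing as ``xs[i] + ys[:, i]`` in the naive loop.

```python
def column(matrix, i):
    """Return column i of a row-major nested list."""
    return [row[i] for row in matrix]


def mapped_add(xs, ys):
    """Pure-Python stand-in for mapping x + y over in_axes=(0, 1).

    Pairs row i of xs (axis 0) with column i of ys (axis 1).
    Illustration only; this is not how MLX implements vmap.
    """
    return [
        [a + b for a, b in zip(xs[i], column(ys, i))]
        for i in range(len(xs))
    ]


xs = [[1, 2, 3], [4, 5, 6]]          # shape (2, 3)
ys = [[10, 20], [30, 40], [50, 60]]  # shape (3, 2)
result = mapped_add(xs, ys)
print(result)  # [[11, 32, 53], [24, 45, 66]]
```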
The ``in_axes`` parameter can be used to specify which dimensions of the
corresponding input to vectorize over. Similarly, use ``out_axes`` to specify

@@ -77,7 +77,7 @@ from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.
Indexing with boolean masks is something that MLX may support in the future. In
-general, MLX has limited support for operations for which outputs
+general, MLX has limited support for operations for which output
*shapes* are dependent on input *data*. Other examples of these types of
operations which MLX does not yet support include :func:`numpy.nonzero` and the
single input version of :func:`numpy.where`.
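The distinction between data-dependent and fixed output shapes can be sketched with plain Python lists. The functions ``masked_select`` and ``where`` below are hypothetical stand-ins written for this illustration, not MLX or NumPy APIs. The first returns a result whose length depends on the values in the mask; the second, modeled on the three-input form of ``where``, always returns a result with the same length as its inputs.

```python
def masked_select(values, mask):
    """Boolean-mask indexing: output length depends on the mask's data."""
    return [v for v, keep in zip(values, mask) if keep]


def where(mask, a, b):
    """Three-input where: output length always equals the input length."""
    return [x if keep else y for keep, x, y in zip(mask, a, b)]


values = [1, 2, 3, 4]
two = masked_select(values, [True, False, True, False])
three = masked_select(values, [True, True, True, False])
fixed = where([True, False, True, False], values, [0, 0, 0, 0])

print(len(two), len(three))  # 2 3
print(fixed)                 # [1, 0, 3, 0]
```

Same input shapes, different output lengths: that is the property that makes the first kind of operation hard to plan ahead of time.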

@@ -109,7 +109,7 @@ Here is a concrete example:
An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you ``print`` an array, convert it to an
-:obj:`numpy.ndarray`, or otherwise access it's memory via :obj:`memoryview`,
+:obj:`numpy.ndarray`, or otherwise access its memory via :obj:`memoryview`,
the graph will be evaluated. Saving arrays via :func:`save` (or any other MLX
saving functions) will also evaluate the array.
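The idea of implicit evaluation can be shown with a toy deferred value in plain Python. The ``Lazy`` class below is a hypothetical illustration and says nothing about how MLX arrays are actually implemented. It only shows the pattern: nothing is computed until something, such as converting to a string for printing, needs the concrete value.

```python
class Lazy:
    """Toy deferred value: stores a computation and runs it on first use.

    Hypothetical illustration only; not how MLX arrays are implemented.
    """

    def __init__(self, compute):
        self._compute = compute
        self._value = None
        self.evaluated = False

    def eval(self):
        # Run the stored computation once and cache the result.
        if not self.evaluated:
            self._value = self._compute()
            self.evaluated = True
        return self._value

    def __str__(self):
        # Printing needs the concrete value, so it forces evaluation.
        return str(self.eval())


c = Lazy(lambda: 2 + 3)
before = c.evaluated  # nothing has been computed yet
text = str(c)         # what print(c) would call; triggers evaluation
after = c.evaluated
print(before, text, after)  # False 5 True
```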

@@ -149,7 +149,7 @@ $(function(){ initResizable(false); });
<div class="foldopen" id="foldopen00050" data-start="{" data-end="}">
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#a28bafec56edec3091e8716d8ccfb6ee1"> 50</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#a28bafec56edec3091e8716d8ccfb6ee1">~ConcurrentContext</a>() {</div>
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> enc.concurrent_ = <span class="keyword">false</span>;</div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> enc.outputs_.insert(</div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> enc.prev_outputs_.insert(</div>
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());</div>
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> enc.concurrent_outputs_.clear();</div>
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> }</div>
@@ -170,212 +170,215 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522"> 66</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">set_output_array</a>(<a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; a, <span class="keywordtype">int</span> idx, int64_t offset = 0);</div>
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e"> 67</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e">dispatchThreadgroups</a>(MTL::Size grid_dims, MTL::Size group_dims);</div>
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810"> 68</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810">dispatchThreads</a>(MTL::Size grid_dims, MTL::Size group_dims);</div>
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> </div>
<div class="foldopen" id="foldopen00070" data-start="{" data-end="}">
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034"> 70</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>() {</div>
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> <span class="keywordflow">return</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a>(*<span class="keyword">this</span>);</div>
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> }</div>
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991"> 69</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">maybeInsertBarrier</a>();</div>
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> </div>
<div class="foldopen" id="foldopen00071" data-start="{" data-end="}">
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034"> 71</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>() {</div>
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> <span class="keywordflow">return</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a>(*<span class="keyword">this</span>);</div>
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> }</div>
</div>
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198"> 73</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>();</div>
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> </div>
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <span class="comment">// Inputs to all kernels in the encoder including temporaries</span></div>
<div class="foldopen" id="foldopen00076" data-start="{" data-end="}">
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509"> 76</a></span> std::unordered_set&lt;const void*&gt;&amp; <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">inputs</a>() {</div>
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="keywordflow">return</span> all_inputs_;</div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> };</div>
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198"> 74</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>();</div>
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> </div>
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> <span class="comment">// Inputs to all kernels in the encoder including temporaries</span></div>
<div class="foldopen" id="foldopen00077" data-start="{" data-end="}">
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509"> 77</a></span> std::unordered_set&lt;const void*&gt;&amp; <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">inputs</a>() {</div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="keywordflow">return</span> all_inputs_;</div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> };</div>
</div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> </div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <span class="comment">// Outputs of all kernels in the encoder including temporaries</span></div>
<div class="foldopen" id="foldopen00081" data-start="{" data-end="}">
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f"> 81</a></span> std::unordered_set&lt;const void*&gt; <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>() {</div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keywordflow">return</span> all_outputs_;</div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> };</div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> </div>
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> <span class="comment">// Outputs of all kernels in the encoder including temporaries</span></div>
<div class="foldopen" id="foldopen00082" data-start="{" data-end="}">
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f"> 82</a></span> std::unordered_set&lt;const void*&gt; <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>() {</div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="keywordflow">return</span> all_outputs_;</div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> };</div>
</div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> </div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> <span class="keyword">private</span>:</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> MTL::ComputeCommandEncoder* enc_;</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <span class="keywordtype">bool</span> concurrent_{<span class="keyword">false</span>};</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> std::unordered_set&lt;MTL::Resource*&gt; outputs_;</div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> std::unordered_set&lt;MTL::Resource*&gt; concurrent_outputs_;</div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> std::unordered_set&lt;const void*&gt; all_inputs_;</div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> std::unordered_set&lt;const void*&gt; all_outputs_;</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span>};</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> </div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> <span class="keyword">private</span>:</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> MTL::ComputeCommandEncoder* enc_;</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="keywordtype">bool</span> needs_barrier_{<span class="keyword">false</span>};</div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keywordtype">bool</span> concurrent_{<span class="keyword">false</span>};</div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> std::unordered_set&lt;MTL::Resource*&gt; prev_outputs_;</div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> std::unordered_set&lt;MTL::Resource*&gt; next_outputs_;</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> std::unordered_set&lt;MTL::Resource*&gt; concurrent_outputs_;</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> std::unordered_set&lt;const void*&gt; all_inputs_;</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> std::unordered_set&lt;const void*&gt; all_outputs_;</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span>};</div>
</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> </div>
<div class="foldopen" id="foldopen00094" data-start="{" data-end="};">
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html"> 94</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_fence.html">Fence</a> {</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79"> 95</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">Fence</a>(MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) {}</div>
<div class="foldopen" id="foldopen00096" data-start="{" data-end="}">
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2"> 96</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">~Fence</a>() {</div>
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>-&gt;release();</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> }</div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> </div>
<div class="foldopen" id="foldopen00097" data-start="{" data-end="};">
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html"> 97</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_fence.html">Fence</a> {</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79"> 98</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">Fence</a>(MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>) {}</div>
<div class="foldopen" id="foldopen00099" data-start="{" data-end="}">
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2"> 99</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">~Fence</a>() {</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>-&gt;release();</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> }</div>
</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87"> 99</a></span> MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>;</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span>};</div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87"> 102</a></span> MTL::Fence* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">fence</a>;</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span>};</div>
</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> </div>
<div class="foldopen" id="foldopen00102" data-start="{" data-end="};">
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html"> 102</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a> {</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7"> 103</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">DeviceStream</a>(MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) {};</div>
<div class="foldopen" id="foldopen00104" data-start="{" data-end="}">
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e"> 104</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">~DeviceStream</a>() {</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>-&gt;release();</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordflow">if</span> (<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a> != <span class="keyword">nullptr</span>) {</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>-&gt;release();</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> }</div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> };</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> </div>
<div class="foldopen" id="foldopen00105" data-start="{" data-end="};">
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html"> 105</a></span><span class="keyword">struct </span><a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a> {</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7"> 106</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">DeviceStream</a>(MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) : <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>(<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>) {};</div>
<div class="foldopen" id="foldopen00107" data-start="{" data-end="}">
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e"> 107</a></span> <a class="code hl_function" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">~DeviceStream</a>() {</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>-&gt;release();</div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keywordflow">if</span> (<a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a> != <span class="keyword">nullptr</span>) {</div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>-&gt;release();</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> }</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> };</div>
</div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d"> 110</a></span> MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>;</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="comment">// A map of prior command encoder outputs to their corresponding fence</span></div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9"> 112</a></span> std::unordered_map&lt;const void*, std::shared_ptr&lt;Fence&gt;&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">outputs</a>;</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="comment">// Used to allow thread-safe access to the outputs map</span></div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd"> 114</a></span> std::mutex <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">fence_mtx</a>;</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> </div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="comment">// The buffer and buffer op count are updated</span></div>
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="comment">// between command buffers</span></div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb"> 118</a></span> MTL::CommandBuffer* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>{<span class="keyword">nullptr</span>};</div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782"> 119</a></span> <span class="keywordtype">int</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">buffer_ops</a>{0};</div>
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> </div>
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="comment">// The command encoder, fence, and temporaries are updated between command</span></div>
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="comment">// encoders</span></div>
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd"> 123</a></span> std::unique_ptr&lt;CommandEncoder&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">encoder</a>{<span class="keyword">nullptr</span>};</div>
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499"> 124</a></span> std::shared_ptr&lt;Fence&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">fence</a>;</div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de"> 125</a></span> std::vector&lt;array&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">temporaries</a>;</div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span>};</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d"> 113</a></span> MTL::CommandQueue* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">queue</a>;</div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="comment">// A map of prior command encoder outputs to their corresponding fence</span></div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9"> 115</a></span> std::unordered_map&lt;const void*, std::shared_ptr&lt;Fence&gt;&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">outputs</a>;</div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="comment">// Used to allow thread-safe access to the outputs map</span></div>
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd"> 117</a></span> std::mutex <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">fence_mtx</a>;</div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> </div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="comment">// The buffer and buffer op count are updated</span></div>
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="comment">// between command buffers</span></div>
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb"> 121</a></span> MTL::CommandBuffer* <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">buffer</a>{<span class="keyword">nullptr</span>};</div>
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782"> 122</a></span> <span class="keywordtype">int</span> <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">buffer_ops</a>{0};</div>
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> </div>
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> <span class="comment">// The command encoder, fence, and temporaries are updated between command</span></div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> <span class="comment">// encoders</span></div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd"> 126</a></span> std::unique_ptr&lt;CommandEncoder&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">encoder</a>{<span class="keyword">nullptr</span>};</div>
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499"> 127</a></span> std::shared_ptr&lt;Fence&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">fence</a>;</div>
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"><a class="line" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de"> 128</a></span> std::vector&lt;array&gt; <a class="code hl_variable" href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">temporaries</a>;</div>
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span>};</div>
</div>
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> </div>
<div class="foldopen" id="foldopen00128" data-start="{" data-end="};">
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html"> 128</a></span><span class="keyword">class </span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a> {</div>
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keyword">public</span>:</div>
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6"> 130</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6">Device</a>();</div>
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06"> 131</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06">Device</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp;) = <span class="keyword">delete</span>;</div>
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73"> 132</a></span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp; <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73">operator=</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp;) = <span class="keyword">delete</span>;</div>
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e"> 133</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e">~Device</a>();</div>
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> </div>
<div class="foldopen" id="foldopen00135" data-start="{" data-end="}">
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653"> 135</a></span> MTL::Device* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mtl_device</a>() {</div>
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keywordflow">return</span> device_;</div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> };</div>
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> </div>
<div class="foldopen" id="foldopen00131" data-start="{" data-end="};">
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html"> 131</a></span><span class="keyword">class </span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a> {</div>
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> <span class="keyword">public</span>:</div>
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6"> 133</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ae0db74570eb4b19d8cf19774db91bfd6">Device</a>();</div>
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06"> 134</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06">Device</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp;) = <span class="keyword">delete</span>;</div>
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73"> 135</a></span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp; <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73">operator=</a>(<span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp;) = <span class="keyword">delete</span>;</div>
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e"> 136</a></span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e">~Device</a>();</div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> </div>
<div class="foldopen" id="foldopen00138" data-start="{" data-end="}">
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653"> 138</a></span> MTL::Device* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mtl_device</a>() {</div>
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordflow">return</span> device_;</div>
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> };</div>
</div>
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> </div>
<div class="foldopen" id="foldopen00139" data-start="{" data-end="}">
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b"> 139</a></span> <span class="keyword">const</span> std::string&amp; <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">get_architecture</a>() {</div>
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> <span class="keywordflow">return</span> arch_;</div>
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> }</div>
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> </div>
<div class="foldopen" id="foldopen00142" data-start="{" data-end="}">
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b"> 142</a></span> <span class="keyword">const</span> std::string&amp; <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">get_architecture</a>() {</div>
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> <span class="keywordflow">return</span> arch_;</div>
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> }</div>
</div>
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> </div>
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67"> 143</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67">new_queue</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210"> 144</a></span> MTL::CommandBuffer* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210">get_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8"> 145</a></span> <span class="keywordtype">int</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8">get_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6"> 146</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6">increment_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c"> 147</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c">commit_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6"> 148</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6">get_command_encoder</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf"> 149</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf">end_encoding</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> </div>
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d"> 151</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(</div>
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> <span class="keyword">const</span> std::string&amp; lib_name,</div>
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keyword">const</span> std::string&amp; lib_path);</div>
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> </div>
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="comment">// Note, this should remain in the header so that it is not dynamically</span></div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="comment">// linked</span></div>
<div class="foldopen" id="foldopen00157" data-start="{" data-end="}">
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf"> 157</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">register_library</a>(<span class="keyword">const</span> std::string&amp; lib_name) {</div>
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="keywordflow">if</span> (<span class="keyword">auto</span> it = library_map_.find(lib_name); it == library_map_.end()) {</div>
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(lib_name, <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a5fd6ba2040e53a254b9d71ae7ebd315f">get_colocated_mtllib_path</a>(lib_name));</div>
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> }</div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> }</div>
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> </div>
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67"> 146</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67">new_queue</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210"> 147</a></span> MTL::CommandBuffer* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210">get_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8"> 148</a></span> <span class="keywordtype">int</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8">get_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6"> 149</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6">increment_command_buffer_ops</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c"> 150</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c">commit_command_buffer</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6"> 151</a></span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#affa682ef612def4890f5152f81ffb7e6">get_command_encoder</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf"> 152</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf">end_encoding</a>(<span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> </div>
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d"> 154</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(</div>
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="keyword">const</span> std::string&amp; lib_name,</div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keyword">const</span> std::string&amp; lib_path);</div>
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> </div>
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="comment">// Note, this should remain in the header so that it is not dynamically</span></div>
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <span class="comment">// linked</span></div>
<div class="foldopen" id="foldopen00160" data-start="{" data-end="}">
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf"> 160</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">register_library</a>(<span class="keyword">const</span> std::string&amp; lib_name) {</div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="keywordflow">if</span> (<span class="keyword">auto</span> it = library_map_.find(lib_name); it == library_map_.end()) {</div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">register_library</a>(lib_name, <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a5fd6ba2040e53a254b9d71ae7ebd315f">get_colocated_mtllib_path</a>(lib_name));</div>
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> }</div>
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> }</div>
</div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> </div>
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0"> 163</a></span> MTL::Library* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0">get_library</a>(</div>
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keyword">const</span> std::function&lt;std::string(<span class="keywordtype">void</span>)&gt;&amp; builder);</div>
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> </div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a"> 167</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a">get_kernel</a>(</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keyword">const</span> std::string&amp; base_name,</div>
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> MTL::Library* mtl_lib,</div>
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keyword">const</span> std::string&amp; hash_name = <span class="stringliteral">&quot;&quot;</span>,</div>
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts = {},</div>
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; linked_functions = {});</div>
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> </div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf"> 174</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf">get_kernel</a>(</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">const</span> std::string&amp; base_name,</div>
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> <span class="keyword">const</span> std::string&amp; lib_name = <span class="stringliteral">&quot;mlx&quot;</span>,</div>
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keyword">const</span> std::string&amp; hash_name = <span class="stringliteral">&quot;&quot;</span>,</div>
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts = {},</div>
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; linked_functions = {});</div>
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> </div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881"> 181</a></span> MTL::ArgumentEncoder* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881">argument_encoder</a>(</div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keyword">const</span> std::vector&lt;MTL::ArgumentDescriptor*&gt;&amp; arg_descs) <span class="keyword">const</span>;</div>
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> </div>
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0"> 166</a></span> MTL::Library* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a75ed55e73baf48013028796518723ff0">get_library</a>(</div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keyword">const</span> std::function&lt;std::string(<span class="keywordtype">void</span>)&gt;&amp; builder);</div>
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> </div>
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a"> 170</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a">get_kernel</a>(</div>
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keyword">const</span> std::string&amp; base_name,</div>
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> MTL::Library* mtl_lib,</div>
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keyword">const</span> std::string&amp; hash_name = <span class="stringliteral">&quot;&quot;</span>,</div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts = {},</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; linked_functions = {});</div>
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> </div>
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf"> 177</a></span> MTL::ComputePipelineState* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#afa0cac9d800c21a8a7f6cb224256abaf">get_kernel</a>(</div>
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keyword">const</span> std::string&amp; base_name,</div>
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keyword">const</span> std::string&amp; lib_name = <span class="stringliteral">&quot;mlx&quot;</span>,</div>
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keyword">const</span> std::string&amp; hash_name = <span class="stringliteral">&quot;&quot;</span>,</div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts = {},</div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; linked_functions = {});</div>
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> </div>
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> <span class="comment">// Record temporary arrays for the given stream index</span></div>
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576"> 185</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576">add_temporary</a>(<a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a> arr, <span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901"> 186</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901">add_temporaries</a>(std::vector&lt;array&gt; arrays, <span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> </div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f"> 188</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f">set_residency_set</a>(<span class="keyword">const</span> MTL::ResidencySet* residency_set);</div>
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> </div>
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> <span class="keyword">private</span>:</div>
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"> 191</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a>&amp; get_stream_(<span class="keywordtype">int</span> index) {</div>
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> <span class="keywordflow">return</span> stream_map_.find(index)-&gt;second;</div>
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> }</div>
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> MTL::Library* get_library_cache_(<span class="keyword">const</span> std::string&amp; name);</div>
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> </div>
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> MTL::Library* get_library_(<span class="keyword">const</span> std::string&amp; name);</div>
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> MTL::Library* build_library_(<span class="keyword">const</span> std::string&amp; source_string);</div>
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881"> 184</a></span> MTL::ArgumentEncoder* <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881">argument_encoder</a>(</div>
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> <span class="keyword">const</span> std::vector&lt;MTL::ArgumentDescriptor*&gt;&amp; arg_descs) <span class="keyword">const</span>;</div>
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> </div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> <span class="comment">// Record temporary arrays for the given stream index</span></div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576"> 188</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576">add_temporary</a>(<a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a> arr, <span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901"> 189</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901">add_temporaries</a>(std::vector&lt;array&gt; arrays, <span class="keywordtype">int</span> index);</div>
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> </div>
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"><a class="line" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f"> 191</a></span> <span class="keywordtype">void</span> <a class="code hl_function" href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f">set_residency_set</a>(<span class="keyword">const</span> MTL::ResidencySet* residency_set);</div>
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> </div>
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> <span class="keyword">private</span>:</div>
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_device_stream.html">DeviceStream</a>&amp; get_stream_(<span class="keywordtype">int</span> index) {</div>
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> <span class="keywordflow">return</span> stream_map_.find(index)-&gt;second;</div>
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> }</div>
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> MTL::Library* get_library_cache_(<span class="keyword">const</span> std::string&amp; name);</div>
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span> </div>
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span> MTL::Function* get_function_(<span class="keyword">const</span> std::string&amp; name, MTL::Library* mtl_lib);</div>
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> </div>
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> MTL::Function* get_function_(</div>
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> <span class="keyword">const</span> std::string&amp; specialized_name,</div>
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts,</div>
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> MTL::Library* mtl_lib);</div>
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> </div>
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> MTL::LinkedFunctions* get_linked_functions_(</div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; funcs);</div>
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span> MTL::Library* get_library_(<span class="keyword">const</span> std::string&amp; name);</div>
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> MTL::Library* build_library_(<span class="keyword">const</span> std::string&amp; source_string);</div>
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> </div>
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> MTL::Function* get_function_(<span class="keyword">const</span> std::string&amp; name, MTL::Library* mtl_lib);</div>
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> </div>
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> MTL::Function* get_function_(</div>
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keyword">const</span> std::string&amp; specialized_name,</div>
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts,</div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> MTL::Library* mtl_lib);</div>
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> </div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> MTL::ComputePipelineState* get_kernel_(</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> <span class="keyword">const</span> MTL::Function* mtl_function);</div>
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> </div>
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> MTL::ComputePipelineState* get_kernel_(</div>
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> <span class="keyword">const</span> MTL::Function* mtl_function,</div>
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> <span class="keyword">const</span> MTL::LinkedFunctions* linked_functions);</div>
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> </div>
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> MTL::ComputePipelineState* get_kernel_(</div>
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="keyword">const</span> std::string&amp; base_name,</div>
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> MTL::Library* mtl_lib,</div>
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">const</span> std::string&amp; hash_name,</div>
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts = {},</div>
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; linked_functions = {});</div>
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> </div>
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> MTL::Device* device_;</div>
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> std::unordered_map&lt;int32_t, DeviceStream&gt; stream_map_;</div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> MTL::LinkedFunctions* get_linked_functions_(</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; funcs);</div>
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> </div>
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> MTL::ComputePipelineState* get_kernel_(</div>
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> <span class="keyword">const</span> MTL::Function* mtl_function);</div>
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> </div>
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> MTL::ComputePipelineState* get_kernel_(</div>
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> <span class="keyword">const</span> std::string&amp; name,</div>
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> <span class="keyword">const</span> MTL::Function* mtl_function,</div>
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="keyword">const</span> MTL::LinkedFunctions* linked_functions);</div>
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> </div>
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> MTL::ComputePipelineState* get_kernel_(</div>
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keyword">const</span> std::string&amp; base_name,</div>
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> MTL::Library* mtl_lib,</div>
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> <span class="keyword">const</span> std::string&amp; hash_name,</div>
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">MTLFCList</a>&amp; func_consts = {},</div>
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> <span class="keyword">const</span> std::vector&lt;MTL::Function*&gt;&amp; linked_functions = {});</div>
<div class="line"><a id="l00228" name="l00228"></a><span class="lineno"> 228</span> </div>
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> std::shared_mutex kernel_mtx_;</div>
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> std::unordered_map&lt;std::string, MTL::ComputePipelineState*&gt; kernel_map_;</div>
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> MTL::Device* device_;</div>
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> std::unordered_map&lt;int32_t, DeviceStream&gt; stream_map_;</div>
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> </div>
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> std::shared_mutex library_mtx_;</div>
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> std::unordered_map&lt;std::string, MTL::Library*&gt; library_map_;</div>
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> <span class="keyword">const</span> MTL::ResidencySet* residency_set_{<span class="keyword">nullptr</span>};</div>
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> std::string arch_;</div>
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span>};</div>
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> std::shared_mutex kernel_mtx_;</div>
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> std::unordered_map&lt;std::string, MTL::ComputePipelineState*&gt; kernel_map_;</div>
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> </div>
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> std::shared_mutex library_mtx_;</div>
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> std::unordered_map&lt;std::string, MTL::Library*&gt; library_map_;</div>
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> <span class="keyword">const</span> MTL::ResidencySet* residency_set_{<span class="keyword">nullptr</span>};</div>
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> std::string arch_;</div>
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span>};</div>
</div>
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> </div>
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57"> 238</a></span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp; <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57">device</a>(<a class="code hl_struct" href="structmlx_1_1core_1_1_device.html">mlx::core::Device</a>);</div>
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> </div>
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span>} <span class="comment">// namespace mlx::core::metal</span></div>
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> </div>
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57"> 241</a></span><a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">Device</a>&amp; <a class="code hl_function" href="namespacemlx_1_1core_1_1metal.html#a910797b74824e6ee576fbb533dee8b57">device</a>(<a class="code hl_struct" href="structmlx_1_1core_1_1_device.html">mlx::core::Device</a>);</div>
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> </div>
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span>} <span class="comment">// namespace mlx::core::metal</span></div>
<div class="ttc" id="aarray_8h_html"><div class="ttname"><a href="array_8h.html">array.h</a></div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a03a2f0c712660a1bd437cb16e4aba79f"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a03a2f0c712660a1bd437cb16e4aba79f">mlx::core::metal::Device::set_residency_set</a></div><div class="ttdeci">void set_residency_set(const MTL::ResidencySet *residency_set)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a064e1cb6a16de7a0619f6447622350f8"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a064e1cb6a16de7a0619f6447622350f8">mlx::core::metal::Device::get_command_buffer_ops</a></div><div class="ttdeci">int get_command_buffer_ops(int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a31dba377f2be44a746db10d1b9367653"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mlx::core::metal::Device::mtl_device</a></div><div class="ttdeci">MTL::Device * mtl_device()</div><div class="ttdef"><b>Definition</b> device.h:135</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a31dba377f2be44a746db10d1b9367653"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653">mlx::core::metal::Device::mtl_device</a></div><div class="ttdeci">MTL::Device * mtl_device()</div><div class="ttdef"><b>Definition</b> device.h:138</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a45945f2efcd242d915ffa2171e92bf9d"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a45945f2efcd242d915ffa2171e92bf9d">mlx::core::metal::Device::register_library</a></div><div class="ttdeci">void register_library(const std::string &amp;lib_name, const std::string &amp;lib_path)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a4f39c28c6cdd1d2da1918f5871bcba6e"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a4f39c28c6cdd1d2da1918f5871bcba6e">mlx::core::metal::Device::~Device</a></div><div class="ttdeci">~Device()</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a5fe3970fbe92ccc55fce4241ffbe5210"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a5fe3970fbe92ccc55fce4241ffbe5210">mlx::core::metal::Device::get_command_buffer</a></div><div class="ttdeci">MTL::CommandBuffer * get_command_buffer(int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a60689f97347811b27e8c5ca23e0372bf"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a60689f97347811b27e8c5ca23e0372bf">mlx::core::metal::Device::end_encoding</a></div><div class="ttdeci">void end_encoding(int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a65f64dd8bafdc704d871fc5be5e7bc0b"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">mlx::core::metal::Device::get_architecture</a></div><div class="ttdeci">const std::string &amp; get_architecture()</div><div class="ttdef"><b>Definition</b> device.h:139</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a65f64dd8bafdc704d871fc5be5e7bc0b"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a65f64dd8bafdc704d871fc5be5e7bc0b">mlx::core::metal::Device::get_architecture</a></div><div class="ttdeci">const std::string &amp; get_architecture()</div><div class="ttdef"><b>Definition</b> device.h:142</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a6810c4dcbcfbf93fc51d42aa5ff0fc3a"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a6810c4dcbcfbf93fc51d42aa5ff0fc3a">mlx::core::metal::Device::get_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_kernel(const std::string &amp;base_name, MTL::Library *mtl_lib, const std::string &amp;hash_name=&quot;&quot;, const MTLFCList &amp;func_consts={}, const std::vector&lt; MTL::Function * &gt; &amp;linked_functions={})</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a6e33e2b1287324fb4a6575e0da5e5881"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a6e33e2b1287324fb4a6575e0da5e5881">mlx::core::metal::Device::argument_encoder</a></div><div class="ttdeci">MTL::ArgumentEncoder * argument_encoder(const std::vector&lt; MTL::ArgumentDescriptor * &gt; &amp;arg_descs) const</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a72ad17c96fc6ce825bc77f0bed657901"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a72ad17c96fc6ce825bc77f0bed657901">mlx::core::metal::Device::add_temporaries</a></div><div class="ttdeci">void add_temporaries(std::vector&lt; array &gt; arrays, int index)</div></div>
@ -383,7 +386,7 @@ $(function(){ initResizable(false); });
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a7a33d4d601423a3d3c23d5ad7072abb6"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a7a33d4d601423a3d3c23d5ad7072abb6">mlx::core::metal::Device::increment_command_buffer_ops</a></div><div class="ttdeci">void increment_command_buffer_ops(int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a8135ae2a8c1e6f3861e84d4e60c28b67"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a8135ae2a8c1e6f3861e84d4e60c28b67">mlx::core::metal::Device::new_queue</a></div><div class="ttdeci">void new_queue(int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a95248f1387824067fd4fed23ace5ac0c"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c">mlx::core::metal::Device::commit_command_buffer</a></div><div class="ttdeci">void commit_command_buffer(int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a99ff72689b7beb65ad4541391b0eeabf"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">mlx::core::metal::Device::register_library</a></div><div class="ttdeci">void register_library(const std::string &amp;lib_name)</div><div class="ttdef"><b>Definition</b> device.h:157</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_a99ff72689b7beb65ad4541391b0eeabf"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#a99ff72689b7beb65ad4541391b0eeabf">mlx::core::metal::Device::register_library</a></div><div class="ttdeci">void register_library(const std::string &amp;lib_name)</div><div class="ttdef"><b>Definition</b> device.h:160</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_abf59a4addb5473f9e814e3651ba85f06"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#abf59a4addb5473f9e814e3651ba85f06">mlx::core::metal::Device::Device</a></div><div class="ttdeci">Device(const Device &amp;)=delete</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_acb90010af0cffe27fd8cc6c253d3a576"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#acb90010af0cffe27fd8cc6c253d3a576">mlx::core::metal::Device::add_temporary</a></div><div class="ttdeci">void add_temporary(array arr, int index)</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html_ad1d6382fd18a46b1906e1b43e0bd2e73"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html#ad1d6382fd18a46b1906e1b43e0bd2e73">mlx::core::metal::Device::operator=</a></div><div class="ttdeci">Device &amp; operator=(const Device &amp;)=delete</div></div>
@ -402,31 +405,32 @@ $(function(){ initResizable(false); });
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></div><div class="ttdef"><b>Definition</b> device.h:41</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a1e41477f2f489e38499f7830a91c9810"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810">mlx::core::metal::CommandEncoder::dispatchThreads</a></div><div class="ttdeci">void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims)</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a2334774486f447213ee997e55c2e52a3"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3">mlx::core::metal::CommandEncoder::CommandEncoder</a></div><div class="ttdeci">CommandEncoder(MTL::CommandBuffer *cbuf)</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a27ded7e54bc1712063c874646b445509"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">mlx::core::metal::CommandEncoder::inputs</a></div><div class="ttdeci">std::unordered_set&lt; const void * &gt; &amp; inputs()</div><div class="ttdef"><b>Definition</b> device.h:76</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a27ded7e54bc1712063c874646b445509"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">mlx::core::metal::CommandEncoder::inputs</a></div><div class="ttdeci">std::unordered_set&lt; const void * &gt; &amp; inputs()</div><div class="ttdef"><b>Definition</b> device.h:77</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a3f42a1362b4a513fa89e7b3dcc570a8e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a3f42a1362b4a513fa89e7b3dcc570a8e">mlx::core::metal::CommandEncoder::operator=</a></div><div class="ttdeci">CommandEncoder &amp; operator=(const CommandEncoder &amp;)=delete</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a48b548a0b15f9d1279c938a1c6167034"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">mlx::core::metal::CommandEncoder::start_concurrent</a></div><div class="ttdeci">ConcurrentContext start_concurrent()</div><div class="ttdef"><b>Definition</b> device.h:70</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a48b548a0b15f9d1279c938a1c6167034"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">mlx::core::metal::CommandEncoder::start_concurrent</a></div><div class="ttdeci">ConcurrentContext start_concurrent()</div><div class="ttdef"><b>Definition</b> device.h:71</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a6a2e28e542eaa2886041bddd51ff6522"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">mlx::core::metal::CommandEncoder::set_output_array</a></div><div class="ttdeci">void set_output_array(array &amp;a, int idx, int64_t offset=0)</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a74bcd8e35f80f5a62db48c4a2bb0173e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e">mlx::core::metal::CommandEncoder::dispatchThreadgroups</a></div><div class="ttdeci">void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims)</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_a9b6dd221ccd2d939d544004cb6279198"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">mlx::core::metal::CommandEncoder::~CommandEncoder</a></div><div class="ttdeci">~CommandEncoder()</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_aac45ab0630ea32cf7d15c7ba3e229966"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966">mlx::core::metal::CommandEncoder::operator-&gt;</a></div><div class="ttdeci">MTL::ComputeCommandEncoder * operator-&gt;()</div><div class="ttdef"><b>Definition</b> device.h:61</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_ab69ff0d7f14b9b59db4df0608193dce4"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4">mlx::core::metal::CommandEncoder::set_input_array</a></div><div class="ttdeci">void set_input_array(const array &amp;a, int idx, int64_t offset=0)</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_ac68ca977b5bde5434284ce7979647f14"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14">mlx::core::metal::CommandEncoder::CommandEncoder</a></div><div class="ttdeci">CommandEncoder(const CommandEncoder &amp;)=delete</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_aefa48740fdee884f02e2d379bca4e78f"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">mlx::core::metal::CommandEncoder::outputs</a></div><div class="ttdeci">std::unordered_set&lt; const void * &gt; outputs()</div><div class="ttdef"><b>Definition</b> device.h:81</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html">mlx::core::metal::DeviceStream</a></div><div class="ttdef"><b>Definition</b> device.h:102</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a1c4397732f64f5811381dd01e30e020e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">mlx::core::metal::DeviceStream::~DeviceStream</a></div><div class="ttdeci">~DeviceStream()</div><div class="ttdef"><b>Definition</b> device.h:104</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a55a7a92c6abad369c99a5ede7a2521b9"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">mlx::core::metal::DeviceStream::outputs</a></div><div class="ttdeci">std::unordered_map&lt; const void *, std::shared_ptr&lt; Fence &gt; &gt; outputs</div><div class="ttdef"><b>Definition</b> device.h:112</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a573326bc8b48e39076850c7bf52ad0d7"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">mlx::core::metal::DeviceStream::DeviceStream</a></div><div class="ttdeci">DeviceStream(MTL::CommandQueue *queue)</div><div class="ttdef"><b>Definition</b> device.h:103</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a58e435217b9922f882507ebf48bfbbdd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">mlx::core::metal::DeviceStream::encoder</a></div><div class="ttdeci">std::unique_ptr&lt; CommandEncoder &gt; encoder</div><div class="ttdef"><b>Definition</b> device.h:123</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a6fa08cca881fc3798ae45994a11a4fcd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">mlx::core::metal::DeviceStream::fence_mtx</a></div><div class="ttdeci">std::mutex fence_mtx</div><div class="ttdef"><b>Definition</b> device.h:114</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a77c75a63c51ea56815a86bd882ed190d"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">mlx::core::metal::DeviceStream::queue</a></div><div class="ttdeci">MTL::CommandQueue * queue</div><div class="ttdef"><b>Definition</b> device.h:110</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a876199de8da1efa9a362451029638499"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">mlx::core::metal::DeviceStream::fence</a></div><div class="ttdeci">std::shared_ptr&lt; Fence &gt; fence</div><div class="ttdef"><b>Definition</b> device.h:124</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a99183c92599edfeb75f7fa0f37e1d9eb"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">mlx::core::metal::DeviceStream::buffer</a></div><div class="ttdeci">MTL::CommandBuffer * buffer</div><div class="ttdef"><b>Definition</b> device.h:118</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_ab6048b329e65a59033834f3bdd351782"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">mlx::core::metal::DeviceStream::buffer_ops</a></div><div class="ttdeci">int buffer_ops</div><div class="ttdef"><b>Definition</b> device.h:119</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_aee88009117dfff1ad121eabe28d5f3de"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">mlx::core::metal::DeviceStream::temporaries</a></div><div class="ttdeci">std::vector&lt; array &gt; temporaries</div><div class="ttdef"><b>Definition</b> device.h:125</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html">mlx::core::metal::Fence</a></div><div class="ttdef"><b>Definition</b> device.h:94</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a30bee4957ae595e04922952a8010fc79"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">mlx::core::metal::Fence::Fence</a></div><div class="ttdeci">Fence(MTL::Fence *fence)</div><div class="ttdef"><b>Definition</b> device.h:95</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a4940c1aece13814af7727de9abb511f2"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">mlx::core::metal::Fence::~Fence</a></div><div class="ttdeci">~Fence()</div><div class="ttdef"><b>Definition</b> device.h:96</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_aeccd8f2b81418ae9fc446ae2b6e15b87"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">mlx::core::metal::Fence::fence</a></div><div class="ttdeci">MTL::Fence * fence</div><div class="ttdef"><b>Definition</b> device.h:99</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_ad538ae88f90560063f9ba502e2795991"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">mlx::core::metal::CommandEncoder::maybeInsertBarrier</a></div><div class="ttdeci">void maybeInsertBarrier()</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_command_encoder_html_aefa48740fdee884f02e2d379bca4e78f"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">mlx::core::metal::CommandEncoder::outputs</a></div><div class="ttdeci">std::unordered_set&lt; const void * &gt; outputs()</div><div class="ttdef"><b>Definition</b> device.h:82</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html">mlx::core::metal::DeviceStream</a></div><div class="ttdef"><b>Definition</b> device.h:105</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a1c4397732f64f5811381dd01e30e020e"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a1c4397732f64f5811381dd01e30e020e">mlx::core::metal::DeviceStream::~DeviceStream</a></div><div class="ttdeci">~DeviceStream()</div><div class="ttdef"><b>Definition</b> device.h:107</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a55a7a92c6abad369c99a5ede7a2521b9"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a55a7a92c6abad369c99a5ede7a2521b9">mlx::core::metal::DeviceStream::outputs</a></div><div class="ttdeci">std::unordered_map&lt; const void *, std::shared_ptr&lt; Fence &gt; &gt; outputs</div><div class="ttdef"><b>Definition</b> device.h:115</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a573326bc8b48e39076850c7bf52ad0d7"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a573326bc8b48e39076850c7bf52ad0d7">mlx::core::metal::DeviceStream::DeviceStream</a></div><div class="ttdeci">DeviceStream(MTL::CommandQueue *queue)</div><div class="ttdef"><b>Definition</b> device.h:106</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a58e435217b9922f882507ebf48bfbbdd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a58e435217b9922f882507ebf48bfbbdd">mlx::core::metal::DeviceStream::encoder</a></div><div class="ttdeci">std::unique_ptr&lt; CommandEncoder &gt; encoder</div><div class="ttdef"><b>Definition</b> device.h:126</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a6fa08cca881fc3798ae45994a11a4fcd"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a6fa08cca881fc3798ae45994a11a4fcd">mlx::core::metal::DeviceStream::fence_mtx</a></div><div class="ttdeci">std::mutex fence_mtx</div><div class="ttdef"><b>Definition</b> device.h:117</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a77c75a63c51ea56815a86bd882ed190d"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d">mlx::core::metal::DeviceStream::queue</a></div><div class="ttdeci">MTL::CommandQueue * queue</div><div class="ttdef"><b>Definition</b> device.h:113</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a876199de8da1efa9a362451029638499"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a876199de8da1efa9a362451029638499">mlx::core::metal::DeviceStream::fence</a></div><div class="ttdeci">std::shared_ptr&lt; Fence &gt; fence</div><div class="ttdef"><b>Definition</b> device.h:127</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_a99183c92599edfeb75f7fa0f37e1d9eb"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#a99183c92599edfeb75f7fa0f37e1d9eb">mlx::core::metal::DeviceStream::buffer</a></div><div class="ttdeci">MTL::CommandBuffer * buffer</div><div class="ttdef"><b>Definition</b> device.h:121</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_ab6048b329e65a59033834f3bdd351782"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#ab6048b329e65a59033834f3bdd351782">mlx::core::metal::DeviceStream::buffer_ops</a></div><div class="ttdeci">int buffer_ops</div><div class="ttdef"><b>Definition</b> device.h:122</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_device_stream_html_aee88009117dfff1ad121eabe28d5f3de"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_device_stream.html#aee88009117dfff1ad121eabe28d5f3de">mlx::core::metal::DeviceStream::temporaries</a></div><div class="ttdeci">std::vector&lt; array &gt; temporaries</div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html">mlx::core::metal::Fence</a></div><div class="ttdef"><b>Definition</b> device.h:97</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a30bee4957ae595e04922952a8010fc79"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a30bee4957ae595e04922952a8010fc79">mlx::core::metal::Fence::Fence</a></div><div class="ttdeci">Fence(MTL::Fence *fence)</div><div class="ttdef"><b>Definition</b> device.h:98</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_a4940c1aece13814af7727de9abb511f2"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#a4940c1aece13814af7727de9abb511f2">mlx::core::metal::Fence::~Fence</a></div><div class="ttdeci">~Fence()</div><div class="ttdef"><b>Definition</b> device.h:99</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1metal_1_1_fence_html_aeccd8f2b81418ae9fc446ae2b6e15b87"><div class="ttname"><a href="structmlx_1_1core_1_1metal_1_1_fence.html#aeccd8f2b81418ae9fc446ae2b6e15b87">mlx::core::metal::Fence::fence</a></div><div class="ttdeci">MTL::Fence * fence</div><div class="ttdef"><b>Definition</b> device.h:102</div></div>
</div><!-- fragment --></div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>

View File

@ -864,7 +864,7 @@
<article class="bd-article">
<section id="custom-metal-kernels">
<h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<span id="id1"></span><h1>Custom Metal Kernels<a class="headerlink" href="#custom-metal-kernels" title="Link to this heading">#</a></h1>
<p>MLX supports writing custom Metal kernels through the Python and C++ APIs.</p>
<section id="simple-example">
<h2>Simple Example<a class="headerlink" href="#simple-example" title="Link to this heading">#</a></h2>
@ -947,6 +947,9 @@ All the attributes defined in Table 5.8 of the <a class="reference external" hre
<span class="k">template</span><span class="w"> </span><span class="p">[[</span><span class="n">host_name</span><span class="p">(</span><span class="s">&quot;custom_kernel_myexp_float&quot;</span><span class="p">)]]</span><span class="w"> </span><span class="p">[[</span><span class="n">kernel</span><span class="p">]]</span><span class="w"> </span><span class="k">decltype</span><span class="p">(</span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">)</span><span class="w"> </span><span class="n">custom_kernel_myexp_float</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">;</span>
</pre></div>
</div>
<p>Note: <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> are parameters to the Metal <a class="reference external" href="https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/2866532-dispatchthreads">dispatchThreads</a> function.
This means we will launch <code class="docutils literal notranslate"><span class="pre">mx.prod(grid)</span></code> threads, subdivided into threadgroups of size <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code>.
For optimal performance, each threadgroup dimension should be less than or equal to the corresponding grid dimension.</p>
<p>Passing <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="docutils literal notranslate"><span class="pre">mx.fast.metal_kernel.__call__</span></code> will print the generated code for debugging purposes.</p>
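<p>To make the relationship between <code class="docutils literal notranslate"><span class="pre">grid</span></code> and <code class="docutils literal notranslate"><span class="pre">threadgroup</span></code> concrete, the following is a minimal sketch in plain Python of the launch arithmetic only. It does not call MLX or Metal. The helper name <code class="docutils literal notranslate"><span class="pre">launch_summary</span></code> is hypothetical, and the sketch assumes that a grid dimension which is not a multiple of the threadgroup dimension is covered by one additional, partially filled threadgroup.</p>

```python
import math


def launch_summary(grid, threadgroup):
    """Hypothetical helper (not part of MLX): summarize a dispatchThreads-style launch.

    grid        -- total number of threads per dimension, e.g. (4096, 1, 1)
    threadgroup -- threads per threadgroup per dimension, e.g. (256, 1, 1)
    """
    # Total threads launched: the product of the grid dimensions.
    total_threads = math.prod(grid)
    # Threadgroups needed per dimension, rounding up so that a trailing
    # partial threadgroup is counted (an assumption of this sketch).
    groups_per_dim = tuple((g + t - 1) // t for g, t in zip(grid, threadgroup))
    return total_threads, groups_per_dim


total, groups = launch_summary(grid=(4096, 1, 1), threadgroup=(256, 1, 1))
print(total, groups)
```

<p>In this sketch a threadgroup dimension larger than the corresponding grid dimension still yields a single threadgroup along that axis, most of which would be idle, which is consistent with the guidance above to keep each threadgroup dimension no larger than the grid dimension.</p>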
</section>
<section id="using-shape-strides">

View File

@ -4330,9 +4330,9 @@
<a href="kernels_8h.html#a195b86cad5bb99aa1bcd23952305af6b"/>
<a href="kernels_8h.html#a1d4cffc3c78067b3d9a62d64f3fb686f"/>
<a href="kernels_8h.html#a35a412f688d79eb47e42d20a7c8650ee"/>
<a href="kernels_8h.html#a3bd386cb6db09f636963ce66ceaf8647"/>
<a href="kernels_8h.html#a4decd4a07d91487e6903f6e3c8b7513a"/>
<a href="kernels_8h.html#a4e809746f48e5dcf7fa63215d3f5e33e"/>
<a href="kernels_8h.html#a51c4bb09230348bd0252e22bfdc9bc89"/>
<a href="kernels_8h.html#a54eb3b65375022428aab5f810e40624b"/>
<a href="kernels_8h.html#a76f614e9956a6ca05a9be4db5a483446"/>
<a href="kernels_8h.html#a7aa91fcfe8b9caa42d60a957f11bfe6b"/>
@ -4443,9 +4443,9 @@
<a href="metal_2kernels_2unary_8h.html#a7c7690f0df9d2acc60b63be58d9c7777"/>
<a href="metal_2kernels_2unary_8h.html#ac965f8d3ed62f8580dbfb645e83d4ae5"/>
<a href="metal_2reduce_8h.html"/>
<a href="metal_2reduce_8h.html#a3ab0fd997d9a35782106ff083a72e098"/>
<a href="metal_2reduce_8h.html#aa0332c64ee9965f05026c30a0b778000"/>
<a href="metal_2reduce_8h.html#ab1eeca8ec6fa31819ee108fa6ed2c41b"/>
<a href="metal_2reduce_8h.html#af7b7ca7c6aa87558d9f98cee5c7a99a8"/>
<a href="metal_2slicing_8h.html"/>
<a href="metal_2slicing_8h.html#a050299d0d366ca5c9d09d1004dcc3e7d"/>
<a href="metal_2slicing_8h.html#a59048c5ff114c101a496bf33f62e3de9"/>
@ -4850,9 +4850,11 @@
<a href="namespacemlx_1_1core.html#a3a6f43c2485f0d42293184f1aecbeaee"/>
<a href="namespacemlx_1_1core.html#a3a8f6f0af477788c4f0aa98abfc5f1ab"/>
<a href="namespacemlx_1_1core.html#a3a8fe7ba84714dbb5fdc81e93a07abc8"/>
<a href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098"/>
<a href="namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782"/>
<a href="namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027"/>
<a href="namespacemlx_1_1core.html#a3ba20a804c306067b7023259429e0e48"/>
<a href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647"/>
<a href="namespacemlx_1_1core.html#a3c41a304126bc225bdc68062d1eb6e7e"/>
<a href="namespacemlx_1_1core.html#a3cc5c154e4ad9a83ad43da8513146fdc"/>
<a href="namespacemlx_1_1core.html#a3d2b2929ed4636e9e2b86e125b2e57d9"/>
@ -4900,7 +4902,6 @@
<a href="namespacemlx_1_1core.html#a514263e63f6825b490203ca586864687"/>
<a href="namespacemlx_1_1core.html#a514cf8b4e6f0a6af3a867e752f4338f7"/>
<a href="namespacemlx_1_1core.html#a517019d42d4e426b7b98e1c719bb47ce"/>
<a href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89"/>
<a href="namespacemlx_1_1core.html#a5287610200ff573730c9c92413f48881"/>
<a href="namespacemlx_1_1core.html#a54833be1d44bc3adfc9ea218fc3685bd"/>
<a href="namespacemlx_1_1core.html#a54863a54f258acf2b5c734950618e4e1"/>
@ -5272,7 +5273,6 @@
<a href="namespacemlx_1_1core.html#af69db7def588d7da430434a69456e29c"/>
<a href="namespacemlx_1_1core.html#af7577c91b8c43682f0ebc9eb9758aae4"/>
<a href="namespacemlx_1_1core.html#af776fd91dd60594dcfebbafd17f19068"/>
<a href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8"/>
<a href="namespacemlx_1_1core.html#af7eea1682a38d363c56a066321e6d526"/>
<a href="namespacemlx_1_1core.html#af810587a17e692f4eec256d3c3cd27de"/>
<a href="namespacemlx_1_1core.html#af84ed854132c1514dca5a524fdb7ed05"/>
@ -5933,11 +5933,11 @@
<a href="quantized_8h.html#a0386011c52d03e60885a31e6fbd903dd"/>
<a href="quantized_8h.html#a07b26d2d0b0d65dfe925c452c453fa42"/>
<a href="quantized_8h.html#a0ba59096494f1001c195312571523ae9"/>
<a href="quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a"/>
<a href="quantized_8h.html#a1a66b061c46383952a0f067c3848971f"/>
<a href="quantized_8h.html#a2ce135e392dbf9a3e5180fb083792ed7"/>
<a href="quantized_8h.html#a3ab400746ad77be89c30d25638e01698"/>
<a href="quantized_8h.html#a47bcf4a14566e01e14bd3c155811db59"/>
<a href="quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12"/>
<a href="quantized_8h.html#a530b720e123e59d73ea89a0a2d0946b7"/>
<a href="quantized_8h.html#a6076203615038eb06816158f7b3869c6"/>
<a href="quantized_8h.html#a62969a218d93680f5e35d0c61b160b99"/>
@ -5952,6 +5952,7 @@
<a href="quantized_8h.html#aa69e143d646fad332c1a53e8c9b337b7"/>
<a href="quantized_8h.html#ab1ae143eba2afceb8df63f38b26f9a84"/>
<a href="quantized_8h.html#ab364d58ab652e3ad87a8f80910556071"/>
<a href="quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8"/>
<a href="quantized_8h.html#aba7687e6f8f1d29c0a1b2a3db150bd81"/>
<a href="quantized_8h.html#abe2e3ef0ee4ec2cb61dc5330ad463d10"/>
<a href="quantized_8h.html#accab1f9e17a65242347c051f98e4c0be"/>
@ -6016,8 +6017,10 @@
<a href="reduce__all_8h.html"/>
<a href="reduce__all_8h.html#a99ef48ae72b3e715c5f4d7ea07cd213d"/>
<a href="reduce__col_8h.html"/>
<a href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"/>
<a href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385"/>
<a href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce"/>
<a href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb"/>
<a href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5"/>
<a href="reduce__init_8h.html"/>
<a href="reduce__init_8h.html#a0088604ac2eaa6940689ff12c4ba5fc2"/>
<a href="reduce__row_8h.html"/>
@ -6051,7 +6054,7 @@
<a href="scheduler_8h.html#aa2d4eacf5d5cbc778a51aafd4fd8e4d7"/>
<a href="scheduler_8h.html#ae856e468c2f7c8f8ec672522cc13730b"/>
<a href="sdpa__vector_8h.html"/>
<a href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927"/>
<a href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae"/>
<a href="sort_8h.html"/>
<a href="sort_8h.html#a0386011c52d03e60885a31e6fbd903dd"/>
<a href="sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2"/>
@ -6929,6 +6932,7 @@
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966"/>
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4"/>
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14"/>
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991"/>
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f"/>
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html"/>
<a href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#a28bafec56edec3091e8716d8ccfb6ee1"/>

View File

@ -93,6 +93,7 @@ $(function(){ initResizable(false); });
<li>Matmul()&#160;:&#160;<a class="el" href="classmlx_1_1core_1_1_matmul.html#adef92f30ab35e540ccb316ea6b94e6f7">mlx::core::Matmul</a></li>
<li>max()&#160;:&#160;<a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a92320d40a58218e40cc414986ac95c50">metal::_numeric_limits_impl&lt; bfloat16_t &gt;</a></li>
<li>Maximum()&#160;:&#160;<a class="el" href="classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816">mlx::core::Maximum</a></li>
<li>maybeInsertBarrier()&#160;:&#160;<a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">mlx::core::metal::CommandEncoder</a></li>
<li>merge_partition()&#160;:&#160;<a class="el" href="struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca">BlockMergeSort&lt; val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp &gt;</a>, <a class="el" href="struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe">KernelMultiBlockMergeSort&lt; val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp &gt;</a></li>
<li>merge_step()&#160;:&#160;<a class="el" href="struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c">BlockMergeSort&lt; val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp &gt;</a></li>
<li>min()&#160;:&#160;<a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f">metal::_numeric_limits_impl&lt; bfloat16_t &gt;</a></li>

View File

@ -102,6 +102,7 @@ $(function(){ initResizable(false); });
<li>max_exponent&#160;:&#160;<a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a61bb136f819fa392c50bdf3c38f3aad2">metal::_numeric_limits_impl&lt; bfloat16_t &gt;</a></li>
<li>max_exponent10&#160;:&#160;<a class="el" href="structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a76bfb2deb0e0afc011f77bf5a6d0ed94">metal::_numeric_limits_impl&lt; bfloat16_t &gt;</a></li>
<li>Maximum()&#160;:&#160;<a class="el" href="classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816">mlx::core::Maximum</a></li>
<li>maybeInsertBarrier()&#160;:&#160;<a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">mlx::core::metal::CommandEncoder</a></li>
<li>merge_partition()&#160;:&#160;<a class="el" href="struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca">BlockMergeSort&lt; val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp &gt;</a>, <a class="el" href="struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe">KernelMultiBlockMergeSort&lt; val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp &gt;</a></li>
<li>merge_step()&#160;:&#160;<a class="el" href="struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c">BlockMergeSort&lt; val_t, idx_t, ARG_SORT, BLOCK_THREADS, N_PER_THREAD, CompareOp &gt;</a></li>
<li>Min&#160;:&#160;<a class="el" href="classmlx_1_1core_1_1distributed_1_1_all_reduce.html#abb4560980e5d01aed14175ce8f6fc924a4f685dcd48e6614d6bb2ccda4f2686ef">mlx::core::distributed::AllReduce</a>, <a class="el" href="classmlx_1_1core_1_1_reduce.html#a0848518b16ae6d4043d6be247bdf31c9a0d3d1f5c94725bdc42fa692e2c074418">mlx::core::Reduce</a>, <a class="el" href="classmlx_1_1core_1_1_scan.html#a47bf2ec54ead4b8f00f9f188518630f1a7d2ee8f14f2e70a9d47170fecc6da898">mlx::core::Scan</a>, <a class="el" href="classmlx_1_1core_1_1_scatter.html#a614d19af11dc30644b2b4941033b613cad914e4c3475ce9858f2de4bf35dcfdbf">mlx::core::Scatter</a></li>

View File

@ -92,8 +92,10 @@ $(function(){ initResizable(false); });
<li>can_convert_to_bfloat&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2bf16_8h.html#aae77817d261452b2f001f4d947a3e04e">bf16.h</a></li>
<li>can_convert_to_complex64&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2complex_8h.html#a4f90ad54f4fae363e8d3cc41d539557b">complex.h</a></li>
<li>ceildiv()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#a8e5a4b0fb5d018d7b078d147efe4f1e3">utils.h</a></li>
<li>col_reduce_2pass()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">reduce_col.h</a></li>
<li>col_reduce_longcolumn()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">reduce_col.h</a></li>
<li>col_reduce_looped()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">reduce_col.h</a></li>
<li>col_reduce_small()&#160;:&#160;<a class="el" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">reduce_col.h</a></li>
<li>col_reduce_small()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">reduce_col.h</a></li>
<li>complex_binop&#160;:&#160;<a class="el" href="types_2complex_8h.html#a9c7995d495359894e1b30c0f1678d6bd">complex.h</a></li>
<li>complex_binop_helper&#160;:&#160;<a class="el" href="types_2complex_8h.html#ac6890f9852de12339b09b65757ebc8c4">complex.h</a></li>
<li>complex_mul()&#160;:&#160;<a class="el" href="radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6">radix.h</a></li>

View File

@ -88,8 +88,10 @@ $(function(){ initResizable(false); });
<h3><a id="index_c" name="index_c"></a>- c -</h3><ul>
<li>ceildiv()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#a8e5a4b0fb5d018d7b078d147efe4f1e3">utils.h</a></li>
<li>col_reduce_2pass()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">reduce_col.h</a></li>
<li>col_reduce_longcolumn()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">reduce_col.h</a></li>
<li>col_reduce_looped()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">reduce_col.h</a></li>
<li>col_reduce_small()&#160;:&#160;<a class="el" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">reduce_col.h</a></li>
<li>col_reduce_small()&#160;:&#160;<a class="el" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">reduce_col.h</a></li>
<li>complex_mul()&#160;:&#160;<a class="el" href="radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6">radix.h</a></li>
<li>complex_mul_conj()&#160;:&#160;<a class="el" href="radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3">radix.h</a></li>
<li>contiguous_scan()&#160;:&#160;<a class="el" href="scan_8h.html#a60d279b9add7d56639bb209408f09d79">scan.h</a></li>

View File

@ -101,7 +101,8 @@ $(function(){ initResizable(false); });
<li>qmv_quad_impl()&#160;:&#160;<a class="el" href="quantized_8h.html#ad5cf1cf63656bc1780685d22169cd4ef">quantized.h</a></li>
<li>qouter()&#160;:&#160;<a class="el" href="quantized_8h.html#ae756f6817b584c60f5dcdd1d9c6b4f58">quantized.h</a></li>
<li>qvm()&#160;:&#160;<a class="el" href="quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">quantized.h</a></li>
<li>qvm_impl()&#160;:&#160;<a class="el" href="quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12">quantized.h</a></li>
<li>qvm_impl()&#160;:&#160;<a class="el" href="quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a">quantized.h</a></li>
<li>qvm_split_k()&#160;:&#160;<a class="el" href="quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8">quantized.h</a></li>
</ul>
</div><!-- contents -->
<!-- start footer part -->

View File

@ -88,7 +88,7 @@ $(function(){ initResizable(false); });
<h3><a id="index_s" name="index_s"></a>- s -</h3><ul>
<li>scatter_impl()&#160;:&#160;<a class="el" href="scatter_8h.html#ad1ce39d0b6d733a95e739121fcc61bd1">scatter.h</a></li>
<li>sdpa_vector()&#160;:&#160;<a class="el" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector.h</a></li>
<li>sdpa_vector()&#160;:&#160;<a class="el" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector.h</a></li>
<li>simd_shuffle()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#a71986ecdd7d18f975dd22c3df7421ce2">utils.h</a></li>
<li>simd_shuffle_and_fill_up()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#a5862d5ea154c9b76cf56a630cf6385b4">utils.h</a></li>
<li>simd_shuffle_down()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#aba6279624b1d30c525efee856a222b5c">utils.h</a></li>

View File

@ -102,7 +102,8 @@ $(function(){ initResizable(false); });
<li>qouter()&#160;:&#160;<a class="el" href="quantized_8h.html#ae756f6817b584c60f5dcdd1d9c6b4f58">quantized.h</a></li>
<li>QUAD_SIZE&#160;:&#160;<a class="el" href="quantized_8h.html#a803e4d5a1459844ba647aea5b004e133">quantized.h</a></li>
<li>qvm()&#160;:&#160;<a class="el" href="quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">quantized.h</a></li>
<li>qvm_impl()&#160;:&#160;<a class="el" href="quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12">quantized.h</a></li>
<li>qvm_impl()&#160;:&#160;<a class="el" href="quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a">quantized.h</a></li>
<li>qvm_split_k()&#160;:&#160;<a class="el" href="quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8">quantized.h</a></li>
</ul>
</div><!-- contents -->
<!-- start footer part -->

View File

@ -89,7 +89,7 @@ $(function(){ initResizable(false); });
<h3><a id="index_s" name="index_s"></a>- s -</h3><ul>
<li>scatter_impl()&#160;:&#160;<a class="el" href="scatter_8h.html#ad1ce39d0b6d733a95e739121fcc61bd1">scatter.h</a></li>
<li>scatter_kernels&#160;:&#160;<a class="el" href="jit_2indexing_8h.html#a768c949cd650a44c6b402fc1440c1a56">indexing.h</a></li>
<li>sdpa_vector()&#160;:&#160;<a class="el" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector.h</a></li>
<li>sdpa_vector()&#160;:&#160;<a class="el" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector.h</a></li>
<li>simd_shuffle()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#a71986ecdd7d18f975dd22c3df7421ce2">utils.h</a></li>
<li>simd_shuffle_and_fill_up()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#a5862d5ea154c9b76cf56a630cf6385b4">utils.h</a></li>
<li>simd_shuffle_down()&#160;:&#160;<a class="el" href="backend_2metal_2kernels_2utils_8h.html#aba6279624b1d30c525efee856a222b5c">utils.h</a></li>

View File

@ -129,8 +129,8 @@ Functions</h2></td></tr>
<tr class="separator:a84ebe6275218070f0ea320f126f64e22"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:afb57825bb763050cc9a9d194aa41ac36" id="r_afb57825bb763050cc9a9d194aa41ac36"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#afb57825bb763050cc9a9d194aa41ac36">mlx::core::get_mb_sort_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;idx, int bn, int tn)</td></tr>
<tr class="separator:afb57825bb763050cc9a9d194aa41ac36"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a51c4bb09230348bd0252e22bfdc9bc89" id="r_a51c4bb09230348bd0252e22bfdc9bc89"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core::get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out)</td></tr>
<tr class="separator:a51c4bb09230348bd0252e22bfdc9bc89"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3bd386cb6db09f636963ce66ceaf8647" id="r_a3bd386cb6db09f636963ce66ceaf8647"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core::get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out)</td></tr>
<tr class="separator:a3bd386cb6db09f636963ce66ceaf8647"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a7aa91fcfe8b9caa42d60a957f11bfe6b" id="r_a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core::get_reduce_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, int ndim=-1, int bm=-1, int bn=-1)</td></tr>
<tr class="separator:a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a84fa8e0aee321a9d614433a0b933103b" id="r_a84fa8e0aee321a9d614433a0b933103b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">mlx::core::get_steel_gemm_fused_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;hash_name, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &amp;func_consts, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</td></tr>

View File

@ -169,152 +169,154 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="keywordtype">int</span> tn);</div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> </div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89"> 79</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">get_reduce_init_kernel</a>(</div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647"> 79</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">get_reduce_init_kernel</a>(</div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out);</div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> </div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b"> 84</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a>(</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <span class="keyword">const</span> std::string&amp; func_name,</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keywordtype">int</span> ndim = -1,</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keywordtype">int</span> bm = -1,</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> <span class="keywordtype">int</span> bn = -1);</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> </div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b"> 95</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a>(</div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keyword">const</span> std::string&amp; hash_name,</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>&amp; func_consts,</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keywordtype">bool</span> transpose_a,</div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keywordtype">bool</span> transpose_b,</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <span class="keywordtype">int</span> wn);</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> </div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5"> 109</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5">get_steel_gemm_splitk_kernel</a>(</div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keywordtype">bool</span> transpose_a,</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <span class="keywordtype">bool</span> transpose_b,</div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="keywordtype">int</span> wn,</div>
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="keywordtype">bool</span> mn_aligned,</div>
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="keywordtype">bool</span> k_aligned);</div>
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> </div>
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b"> 124</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b">get_steel_gemm_splitk_accum_kernel</a>(</div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keywordtype">bool</span> axbpy);</div>
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> </div>
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d"> 131</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d">get_steel_gemm_masked_kernel</a>(</div>
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"> 133</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_out,</div>
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_op,</div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> <span class="keywordtype">bool</span> transpose_a,</div>
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> <span class="keywordtype">bool</span> transpose_b,</div>
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> <span class="keywordtype">int</span> wn,</div>
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="keywordtype">bool</span> mn_aligned,</div>
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordtype">bool</span> k_aligned);</div>
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> </div>
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4"> 147</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4">get_steel_conv_kernel</a>(</div>
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"> 149</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"> 151</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="keywordtype">int</span> wn,</div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keywordtype">int</span> n_channel_specialization,</div>
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> <span class="keywordtype">bool</span> small_filter);</div>
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> </div>
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818"> 159</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818">get_gemv_masked_kernel</a>(</div>
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_out,</div>
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_op,</div>
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keywordtype">bool</span> transpose_mat,</div>
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keywordtype">int</span> sm,</div>
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> <span class="keywordtype">int</span> sn,</div>
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keywordtype">int</span> tm,</div>
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keywordtype">int</span> tn,</div>
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keywordtype">bool</span> contiguous);</div>
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> </div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d"> 174</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d">get_steel_conv_general_kernel</a>(</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordtype">int</span> wn);</div>
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> </div>
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f"> 184</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f">get_fft_kernel</a>(</div>
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> <span class="keyword">const</span> std::string&amp; hash_name,</div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>&amp; func_consts,</div>
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> <span class="keyword">const</span> std::string&amp; template_def);</div>
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> </div>
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e"> 191</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">get_quantized_kernel</a>(</div>
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> <span class="keyword">const</span> std::string&amp; template_def);</div>
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> </div>
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span><span class="comment">// Create a GPU kernel template definition for JIT compilation</span></div>
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span>... Args&gt;</div>
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span>std::string</div>
<div class="foldopen" id="foldopen00199" data-start="{" data-end="}">
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032"> 199</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">get_template_definition</a>(std::string name, std::string func, Args... args) {</div>
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> std::ostringstream s;</div>
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> s &lt;&lt; func &lt;&lt; <span class="stringliteral">&quot;&lt;&quot;</span>;</div>
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> <span class="keywordtype">bool</span> first = <span class="keyword">true</span>;</div>
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> <span class="keyword">auto</span> add_arg = [&amp;s, &amp;first](<span class="keyword">const</span> <span class="keyword">auto</span>&amp; arg) {</div>
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> <span class="keywordflow">if</span> (!first) {</div>
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> s &lt;&lt; <span class="stringliteral">&quot;, &quot;</span>;</div>
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> }</div>
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> first = <span class="keyword">false</span>;</div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> s &lt;&lt; arg;</div>
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> };</div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> (add_arg(args), ...);</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> s &lt;&lt; <span class="stringliteral">&quot;&gt;&quot;</span>;</div>
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> <span class="keywordflow">return</span> fmt::format(</div>
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> <span class="stringliteral">&quot;\ntemplate [[host_name(\&quot;{0}\&quot;)]] [[kernel]] decltype({1}) {1};\n&quot;</span>,</div>
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> name,</div>
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> s.str());</div>
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span>}</div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keyword">const</span> std::string&amp; func_name,</div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out);</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> </div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b"> 86</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a>(</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keyword">const</span> std::string&amp; func_name,</div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> <span class="keywordtype">int</span> ndim = -1,</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> <span class="keywordtype">int</span> bm = -1,</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> <span class="keywordtype">int</span> bn = -1);</div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> </div>
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b"> 97</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a>(</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> std::string&amp; hash_name,</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>&amp; func_consts,</div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordtype">bool</span> transpose_a,</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keywordtype">bool</span> transpose_b,</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keywordtype">int</span> wn);</div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> </div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5"> 111</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#af48c6f2f72b61dbd6766e4f5fea85df5">get_steel_gemm_splitk_kernel</a>(</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="keywordtype">bool</span> transpose_a,</div>
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> <span class="keywordtype">bool</span> transpose_b,</div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> <span class="keywordtype">int</span> wn,</div>
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> <span class="keywordtype">bool</span> mn_aligned,</div>
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> <span class="keywordtype">bool</span> k_aligned);</div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> </div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b"> 126</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b">get_steel_gemm_splitk_accum_kernel</a>(</div>
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"> 131</span> <span class="keywordtype">bool</span> axbpy);</div>
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> </div>
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d"> 133</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d">get_steel_gemm_masked_kernel</a>(</div>
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_out,</div>
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_op,</div>
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordtype">bool</span> transpose_a,</div>
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> <span class="keywordtype">bool</span> transpose_b,</div>
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordtype">int</span> wn,</div>
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> <span class="keywordtype">bool</span> mn_aligned,</div>
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"> 147</span> <span class="keywordtype">bool</span> k_aligned);</div>
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> </div>
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4"> 149</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4">get_steel_conv_kernel</a>(</div>
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"> 151</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> <span class="keywordtype">int</span> wn,</div>
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="keywordtype">int</span> n_channel_specialization,</div>
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <span class="keywordtype">bool</span> small_filter);</div>
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> </div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818"> 161</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818">get_gemv_masked_kernel</a>(</div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_out,</div>
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> <span class="keyword">const</span> std::optional&lt;array&gt;&amp; mask_op,</div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> <span class="keywordtype">bool</span> transpose_mat,</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> <span class="keywordtype">int</span> sm,</div>
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> <span class="keywordtype">int</span> sn,</div>
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keywordtype">int</span> tm,</div>
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keywordtype">int</span> tn,</div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keywordtype">bool</span> contiguous);</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> </div>
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d"> 176</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d">get_steel_conv_general_kernel</a>(</div>
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keywordtype">int</span> bm,</div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> <span class="keywordtype">int</span> bn,</div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordtype">int</span> bk,</div>
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> <span class="keywordtype">int</span> wm,</div>
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> <span class="keywordtype">int</span> wn);</div>
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> </div>
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f"> 186</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f">get_fft_kernel</a>(</div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> <span class="keyword">const</span> std::string&amp; hash_name,</div>
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> <span class="keyword">const</span> <a class="code hl_typedef" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a>&amp; func_consts,</div>
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"> 191</span> <span class="keyword">const</span> std::string&amp; template_def);</div>
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> </div>
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e"> 193</a></span>MTL::ComputePipelineState* <a class="code hl_function" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">get_quantized_kernel</a>(</div>
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> <span class="keyword">const</span> std::string&amp; kernel_name,</div>
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> <span class="keyword">const</span> std::string&amp; template_def);</div>
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> </div>
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span><span class="comment">// Create a GPU kernel template definition for JIT compilation</span></div>
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span>... Args&gt;</div>
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span>std::string</div>
<div class="foldopen" id="foldopen00201" data-start="{" data-end="}">
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032"> 201</a></span><a class="code hl_function" href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">get_template_definition</a>(std::string name, std::string func, Args... args) {</div>
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> std::ostringstream s;</div>
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> s &lt;&lt; func &lt;&lt; <span class="stringliteral">&quot;&lt;&quot;</span>;</div>
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> <span class="keywordtype">bool</span> first = <span class="keyword">true</span>;</div>
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> <span class="keyword">auto</span> add_arg = [&amp;s, &amp;first](<span class="keyword">const</span> <span class="keyword">auto</span>&amp; arg) {</div>
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keywordflow">if</span> (!first) {</div>
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> s &lt;&lt; <span class="stringliteral">&quot;, &quot;</span>;</div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> }</div>
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> first = <span class="keyword">false</span>;</div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> s &lt;&lt; arg;</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> };</div>
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> (add_arg(args), ...);</div>
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> s &lt;&lt; <span class="stringliteral">&quot;&gt;&quot;</span>;</div>
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> <span class="keywordflow">return</span> fmt::format(</div>
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> <span class="stringliteral">&quot;\ntemplate [[host_name(\&quot;{0}\&quot;)]] [[kernel]] decltype({1}) {1};\n&quot;</span>,</div>
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> name,</div>
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> s.str());</div>
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span>}</div>
</div>
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> </div>
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span>} <span class="comment">// namespace mlx::core</span></div>
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> </div>
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span>} <span class="comment">// namespace mlx::core</span></div>
<div class="ttc" id="aarray_8h_html"><div class="ttname"><a href="array_8h.html">array.h</a></div></div>
<div class="ttc" id="abackend_2metal_2device_8h_html"><div class="ttname"><a href="backend_2metal_2device_8h.html">device.h</a></div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
<div class="ttc" id="acommon_2binary_8h_html_a70228731d29946574b238d21fb4b360c"><div class="ttname"><a href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a></div><div class="ttdeci">Op op</div><div class="ttdef"><b>Definition</b> binary.h:129</div></div>
<div class="ttc" id="anamespacemlx_1_1core_1_1metal_html_a616e09a1ef321d527770721cef264c54"><div class="ttname"><a href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">mlx::core::metal::MTLFCList</a></div><div class="ttdeci">std::vector&lt; std::tuple&lt; const void *, MTL::DataType, NS::UInteger &gt; &gt; MTLFCList</div><div class="ttdef"><b>Definition</b> device.h:38</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
@ -322,9 +324,9 @@ $(function(){ initResizable(false); });
<div class="ttc" id="anamespacemlx_1_1core_html_a195b86cad5bb99aa1bcd23952305af6b"><div class="ttname"><a href="namespacemlx_1_1core.html#a195b86cad5bb99aa1bcd23952305af6b">mlx::core::get_steel_gemm_splitk_accum_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;in, const array &amp;out, bool axbpy)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a1d4cffc3c78067b3d9a62d64f3fb686f"><div class="ttname"><a href="namespacemlx_1_1core.html#a1d4cffc3c78067b3d9a62d64f3fb686f">mlx::core::get_fft_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_fft_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const std::string &amp;hash_name, const metal::MTLFCList &amp;func_consts, const std::string &amp;template_def)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a35a412f688d79eb47e42d20a7c8650ee"><div class="ttname"><a href="namespacemlx_1_1core.html#a35a412f688d79eb47e42d20a7c8650ee">mlx::core::get_softmax_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_softmax_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, bool precise, const array &amp;out)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a3bd386cb6db09f636963ce66ceaf8647"><div class="ttname"><a href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core::get_reduce_init_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const array &amp;out)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a4decd4a07d91487e6903f6e3c8b7513a"><div class="ttname"><a href="namespacemlx_1_1core.html#a4decd4a07d91487e6903f6e3c8b7513a">mlx::core::get_binary_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_binary_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, Dtype in_type, Dtype out_type, const std::string op)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a4e809746f48e5dcf7fa63215d3f5e33e"><div class="ttname"><a href="namespacemlx_1_1core.html#a4e809746f48e5dcf7fa63215d3f5e33e">mlx::core::get_binary_two_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, Dtype in_type, Dtype out_type, const std::string op)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a51c4bb09230348bd0252e22bfdc9bc89"><div class="ttname"><a href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core::get_reduce_init_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;out)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a54eb3b65375022428aab5f810e40624b"><div class="ttname"><a href="namespacemlx_1_1core.html#a54eb3b65375022428aab5f810e40624b">mlx::core::get_ternary_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_ternary_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, Dtype type, const std::string op)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a76f614e9956a6ca05a9be4db5a483446"><div class="ttname"><a href="namespacemlx_1_1core.html#a76f614e9956a6ca05a9be4db5a483446">mlx::core::get_arange_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_arange_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;out)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a7aa91fcfe8b9caa42d60a957f11bfe6b"><div class="ttname"><a href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core::get_reduce_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_reduce_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const array &amp;in, const array &amp;out, int ndim=-1, int bm=-1, int bn=-1)</div></div>
@ -332,7 +334,7 @@ $(function(){ initResizable(false); });
<div class="ttc" id="anamespacemlx_1_1core_html_a84fa8e0aee321a9d614433a0b933103b"><div class="ttname"><a href="namespacemlx_1_1core.html#a84fa8e0aee321a9d614433a0b933103b">mlx::core::get_steel_gemm_fused_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const std::string &amp;hash_name, const metal::MTLFCList &amp;func_consts, const array &amp;out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a90c24e0d0b99b68fad9deefcf4d3e818"><div class="ttname"><a href="namespacemlx_1_1core.html#a90c24e0d0b99b68fad9deefcf4d3e818">mlx::core::get_gemv_masked_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;out, const std::optional&lt; array &gt; &amp;mask_out, const std::optional&lt; array &gt; &amp;mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_aa3faeae5378bfaafe3ce3432a051e43e"><div class="ttname"><a href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">mlx::core::get_quantized_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_quantized_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const std::string &amp;template_def)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_aae0d19f0acdef2accd2428fb84c8a032"><div class="ttname"><a href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">mlx::core::get_template_definition</a></div><div class="ttdeci">std::string get_template_definition(std::string name, std::string func, Args... args)</div><div class="ttdef"><b>Definition</b> kernels.h:199</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_aae0d19f0acdef2accd2428fb84c8a032"><div class="ttname"><a href="namespacemlx_1_1core.html#aae0d19f0acdef2accd2428fb84c8a032">mlx::core::get_template_definition</a></div><div class="ttdeci">std::string get_template_definition(std::string name, std::string func, Args... args)</div><div class="ttdef"><b>Definition</b> kernels.h:201</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_ab5f60614e965144b451930fdf935e08d"><div class="ttname"><a href="namespacemlx_1_1core.html#ab5f60614e965144b451930fdf935e08d">mlx::core::get_steel_gemm_masked_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;out, const std::optional&lt; array &gt; &amp;mask_out, const std::optional&lt; array &gt; &amp;mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_abce2b67044ee06a7bbe7a91ec7c8c48d"><div class="ttname"><a href="namespacemlx_1_1core.html#abce2b67044ee06a7bbe7a91ec7c8c48d">mlx::core::get_steel_conv_general_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;out, int bm, int bn, int bk, int wm, int wn)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_adce79d220672f5f3c65cc31d145ca9c4"><div class="ttname"><a href="namespacemlx_1_1core.html#adce79d220672f5f3c65cc31d145ca9c4">mlx::core::get_steel_conv_kernel</a></div><div class="ttdeci">MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &amp;d, const std::string &amp;kernel_name, const array &amp;out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)</div></div>

@ -143,7 +143,7 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span>} <span class="comment">// namespace mlx::core</span></div>
<div class="ttc" id="abackend_2metal_2device_8h_html"><div class="ttname"><a href="backend_2metal_2device_8h.html">device.h</a></div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a227588758ccc9ee869dba147e830bb74"><div class="ttname"><a href="namespacemlx_1_1core.html#a227588758ccc9ee869dba147e830bb74">mlx::core::steel_matmul_regular</a></div><div class="ttdeci">void steel_matmul_regular(const Stream &amp;s, metal::Device &amp;d, const array &amp;a, const array &amp;b, array &amp;out, int M, int N, int K, int batch_size_out, int lda, int ldb, int ldd, bool transpose_a, bool transpose_b, std::vector&lt; int &gt; batch_shape, std::vector&lt; size_t &gt; batch_strides, size_t A_batch_stride, size_t B_batch_stride, size_t matrix_stride_out, std::vector&lt; array &gt; &amp;copies)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_ab43a7633794498e1c6775cca829eb886"><div class="ttname"><a href="namespacemlx_1_1core.html#ab43a7633794498e1c6775cca829eb886">mlx::core::steel_matmul</a></div><div class="ttdeci">void steel_matmul(const Stream &amp;s, metal::Device &amp;d, const array &amp;a, const array &amp;b, array &amp;out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector&lt; array &gt; &amp;copies, std::vector&lt; int &gt; batch_shape={}, std::vector&lt; size_t &gt; A_batch_stride={}, std::vector&lt; size_t &gt; B_batch_stride={})</div></div>

@ -109,8 +109,8 @@ Namespaces</h2></td></tr>
</table><table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
Functions</h2></td></tr>
<tr class="memitem:af7b7ca7c6aa87558d9f98cee5c7a99a8" id="r_af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core::all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s, std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;copies)</td></tr>
<tr class="separator:af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3ab0fd997d9a35782106ff083a72e098" id="r_a3ab0fd997d9a35782106ff083a72e098"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core::all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
<tr class="separator:a3ab0fd997d9a35782106ff083a72e098"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab1eeca8ec6fa31819ee108fa6ed2c41b" id="r_ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">mlx::core::row_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &amp;plan, const std::vector&lt; int &gt; &amp;axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
<tr class="separator:ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:aa0332c64ee9965f05026c30a0b778000" id="r_aa0332c64ee9965f05026c30a0b778000"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">mlx::core::strided_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &amp;plan, const std::vector&lt; int &gt; &amp;axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>

@ -103,44 +103,43 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> </div>
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span><span class="keyword">using </span>metal::CommandEncoder;</div>
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span> </div>
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8"> 13</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">all_reduce_dispatch</a>(</div>
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098"> 13</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">all_reduce_dispatch</a>(</div>
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; compute_encoder,</div>
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>&amp; s,</div>
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> std::vector&lt;array&gt;&amp; copies);</div>
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> </div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b"> 22</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a>(</div>
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>&amp; plan,</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> <span class="keyword">const</span> std::vector&lt;int&gt;&amp; axes,</div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; compute_encoder,</div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>&amp; s);</div>
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> </div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000"> 32</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a>(</div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>&amp; plan,</div>
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="keyword">const</span> std::vector&lt;int&gt;&amp; axes,</div>
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; compute_encoder,</div>
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>&amp; s);</div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> </div>
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span>} <span class="comment">// namespace mlx::core</span></div>
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>&amp; s);</div>
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> </div>
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b"> 21</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a>(</div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>&amp; plan,</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keyword">const</span> std::vector&lt;int&gt;&amp; axes,</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; compute_encoder,</div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>&amp; s);</div>
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> </div>
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"><a class="line" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000"> 31</a></span><span class="keywordtype">void</span> <a class="code hl_function" href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a>(</div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> <span class="keyword">const</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; in,</div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> <a class="code hl_class" href="classmlx_1_1core_1_1array.html">array</a>&amp; out,</div>
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <span class="keyword">const</span> std::string&amp; op_name,</div>
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a>&amp; plan,</div>
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="keyword">const</span> std::vector&lt;int&gt;&amp; axes,</div>
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <a class="code hl_struct" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a>&amp; compute_encoder,</div>
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <a class="code hl_class" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a>&amp; d,</div>
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <span class="keyword">const</span> <a class="code hl_struct" href="structmlx_1_1core_1_1_stream.html">Stream</a>&amp; s);</div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> </div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span>} <span class="comment">// namespace mlx::core</span></div>
<div class="ttc" id="abackend_2metal_2device_8h_html"><div class="ttname"><a href="backend_2metal_2device_8h.html">device.h</a></div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1array_html"><div class="ttname"><a href="classmlx_1_1core_1_1array.html">mlx::core::array</a></div><div class="ttdef"><b>Definition</b> array.h:20</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:128</div></div>
<div class="ttc" id="aclassmlx_1_1core_1_1metal_1_1_device_html"><div class="ttname"><a href="classmlx_1_1core_1_1metal_1_1_device.html">mlx::core::metal::Device</a></div><div class="ttdef"><b>Definition</b> device.h:131</div></div>
<div class="ttc" id="acommon_2reduce_8h_html"><div class="ttname"><a href="common_2reduce_8h.html">reduce.h</a></div></div>
<div class="ttc" id="anamespacemlx_1_1core_html"><div class="ttname"><a href="namespacemlx_1_1core.html">mlx::core</a></div><div class="ttdef"><b>Definition</b> allocator.h:7</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_a3ab0fd997d9a35782106ff083a72e098"><div class="ttname"><a href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core::all_reduce_dispatch</a></div><div class="ttdeci">void all_reduce_dispatch(const array &amp;in, array &amp;out, const std::string &amp;op_name, CommandEncoder &amp;compute_encoder, metal::Device &amp;d, const Stream &amp;s)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_aa0332c64ee9965f05026c30a0b778000"><div class="ttname"><a href="namespacemlx_1_1core.html#aa0332c64ee9965f05026c30a0b778000">mlx::core::strided_reduce_general_dispatch</a></div><div class="ttdeci">void strided_reduce_general_dispatch(const array &amp;in, array &amp;out, const std::string &amp;op_name, const ReductionPlan &amp;plan, const std::vector&lt; int &gt; &amp;axes, CommandEncoder &amp;compute_encoder, metal::Device &amp;d, const Stream &amp;s)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_ab1eeca8ec6fa31819ee108fa6ed2c41b"><div class="ttname"><a href="namespacemlx_1_1core.html#ab1eeca8ec6fa31819ee108fa6ed2c41b">mlx::core::row_reduce_general_dispatch</a></div><div class="ttdeci">void row_reduce_general_dispatch(const array &amp;in, array &amp;out, const std::string &amp;op_name, const ReductionPlan &amp;plan, const std::vector&lt; int &gt; &amp;axes, CommandEncoder &amp;compute_encoder, metal::Device &amp;d, const Stream &amp;s)</div></div>
<div class="ttc" id="anamespacemlx_1_1core_html_af7b7ca7c6aa87558d9f98cee5c7a99a8"><div class="ttname"><a href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core::all_reduce_dispatch</a></div><div class="ttdeci">void all_reduce_dispatch(const array &amp;in, array &amp;out, const std::string &amp;op_name, CommandEncoder &amp;compute_encoder, metal::Device &amp;d, const Stream &amp;s, std::vector&lt; array &gt; &amp;copies)</div></div>
<div class="ttc" id="astream_8h_html"><div class="ttname"><a href="stream_8h.html">stream.h</a></div></div>
<div class="ttc" id="astructmlx_1_1core_1_1_reduction_plan_html"><div class="ttname"><a href="structmlx_1_1core_1_1_reduction_plan.html">mlx::core::ReductionPlan</a></div><div class="ttdef"><b>Definition</b> reduce.h:39</div></div>
<div class="ttc" id="astructmlx_1_1core_1_1_stream_html"><div class="ttname"><a href="structmlx_1_1core_1_1_stream.html">mlx::core::Stream</a></div><div class="ttdef"><b>Definition</b> stream.h:9</div></div>

@ -99,7 +99,7 @@ $(function(){ initResizable(false); });
<li>aligned_dealloc()&#160;:&#160;<a class="el" href="namespacepocketfft_1_1detail.html#aec7820e36a33e0a8bb83aa03b04b81e8">pocketfft::detail</a></li>
<li>all()&#160;:&#160;<a class="el" href="group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68">mlx::core</a></li>
<li>all_gather()&#160;:&#160;<a class="el" href="namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04">mlx::core::distributed::detail</a></li>
<li>all_reduce_dispatch()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core</a></li>
<li>all_reduce_dispatch()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core</a></li>
<li>all_sum()&#160;:&#160;<a class="el" href="namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960">mlx::core::distributed::detail</a></li>
<li>allclose()&#160;:&#160;<a class="el" href="group__ops.html#gaf0cd4257de7542daf9faf5e605e31020">mlx::core</a></li>
<li>alloc_tmp()&#160;:&#160;<a class="el" href="namespacepocketfft_1_1detail.html#a4db03cbcd9d43d9e0b0b9067713c80e9">pocketfft::detail</a></li>

@ -98,7 +98,7 @@ $(function(){ initResizable(false); });
<li>aligned_dealloc()&#160;:&#160;<a class="el" href="namespacepocketfft_1_1detail.html#aec7820e36a33e0a8bb83aa03b04b81e8">pocketfft::detail</a></li>
<li>all()&#160;:&#160;<a class="el" href="group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68">mlx::core</a></li>
<li>all_gather()&#160;:&#160;<a class="el" href="namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04">mlx::core::distributed::detail</a></li>
<li>all_reduce_dispatch()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8">mlx::core</a></li>
<li>all_reduce_dispatch()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098">mlx::core</a></li>
<li>all_sum()&#160;:&#160;<a class="el" href="namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980">mlx::core::distributed</a>, <a class="el" href="namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960">mlx::core::distributed::detail</a></li>
<li>allclose()&#160;:&#160;<a class="el" href="group__ops.html#gaf0cd4257de7542daf9faf5e605e31020">mlx::core</a></li>
<li>alloc_tmp()&#160;:&#160;<a class="el" href="namespacepocketfft_1_1detail.html#a4db03cbcd9d43d9e0b0b9067713c80e9">pocketfft::detail</a></li>

@ -112,7 +112,7 @@ $(function(){ initResizable(false); });
<li>get_pool()&#160;:&#160;<a class="el" href="namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a">pocketfft::detail::threading</a></li>
<li>get_primitive_string()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60">mlx::core</a></li>
<li>get_quantized_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">mlx::core</a></li>
<li>get_reduce_init_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core</a></li>
<li>get_reduce_init_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core</a></li>
<li>get_reduce_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core</a></li>
<li>get_reduction_plan()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba">mlx::core</a></li>
<li>get_scan_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f">mlx::core</a></li>

@ -116,7 +116,7 @@ $(function(){ initResizable(false); });
<li>get_pool()&#160;:&#160;<a class="el" href="namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a">pocketfft::detail::threading</a></li>
<li>get_primitive_string()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60">mlx::core</a></li>
<li>get_quantized_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e">mlx::core</a></li>
<li>get_reduce_init_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89">mlx::core</a></li>
<li>get_reduce_init_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647">mlx::core</a></li>
<li>get_reduce_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b">mlx::core</a></li>
<li>get_reduction_plan()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba">mlx::core</a></li>
<li>get_scan_kernel()&#160;:&#160;<a class="el" href="namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f">mlx::core</a></li>

@ -534,8 +534,8 @@ Functions</h2></td></tr>
<tr class="separator:a84ebe6275218070f0ea320f126f64e22"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:afb57825bb763050cc9a9d194aa41ac36" id="r_afb57825bb763050cc9a9d194aa41ac36"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#afb57825bb763050cc9a9d194aa41ac36">get_mb_sort_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;idx, int bn, int tn)</td></tr>
<tr class="separator:afb57825bb763050cc9a9d194aa41ac36"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a51c4bb09230348bd0252e22bfdc9bc89" id="r_a51c4bb09230348bd0252e22bfdc9bc89"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a51c4bb09230348bd0252e22bfdc9bc89">get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out)</td></tr>
<tr class="separator:a51c4bb09230348bd0252e22bfdc9bc89"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3bd386cb6db09f636963ce66ceaf8647" id="r_a3bd386cb6db09f636963ce66ceaf8647"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a3bd386cb6db09f636963ce66ceaf8647">get_reduce_init_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out)</td></tr>
<tr class="separator:a3bd386cb6db09f636963ce66ceaf8647"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a7aa91fcfe8b9caa42d60a957f11bfe6b" id="r_a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a7aa91fcfe8b9caa42d60a957f11bfe6b">get_reduce_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;func_name, const std::string &amp;op_name, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, int ndim=-1, int bm=-1, int bn=-1)</td></tr>
<tr class="separator:a7aa91fcfe8b9caa42d60a957f11bfe6b"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a84fa8e0aee321a9d614433a0b933103b" id="r_a84fa8e0aee321a9d614433a0b933103b"><td class="memItemLeft" align="right" valign="top">MTL::ComputePipelineState *&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a84fa8e0aee321a9d614433a0b933103b">get_steel_gemm_fused_kernel</a> (<a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const std::string &amp;kernel_name, const std::string &amp;hash_name, const <a class="el" href="namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54">metal::MTLFCList</a> &amp;func_consts, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)</td></tr>
@ -563,8 +563,8 @@ Functions</h2></td></tr>
<tr class="separator:a227588758ccc9ee869dba147e830bb74"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab43a7633794498e1c6775cca829eb886" id="r_ab43a7633794498e1c6775cca829eb886"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#ab43a7633794498e1c6775cca829eb886">steel_matmul</a> (const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;a, const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;b, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, int M, int N, int K, int batch_size_out, int lda, int ldb, bool transpose_a, bool transpose_b, std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;copies, std::vector&lt; int &gt; batch_shape={}, std::vector&lt; size_t &gt; A_batch_stride={}, std::vector&lt; size_t &gt; B_batch_stride={})</td></tr>
<tr class="separator:ab43a7633794498e1c6775cca829eb886"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:af7b7ca7c6aa87558d9f98cee5c7a99a8" id="r_af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#af7b7ca7c6aa87558d9f98cee5c7a99a8">all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s, std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;copies)</td></tr>
<tr class="separator:af7b7ca7c6aa87558d9f98cee5c7a99a8"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a3ab0fd997d9a35782106ff083a72e098" id="r_a3ab0fd997d9a35782106ff083a72e098"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a3ab0fd997d9a35782106ff083a72e098">all_reduce_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
<tr class="separator:a3ab0fd997d9a35782106ff083a72e098"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab1eeca8ec6fa31819ee108fa6ed2c41b" id="r_ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#ab1eeca8ec6fa31819ee108fa6ed2c41b">row_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &amp;plan, const std::vector&lt; int &gt; &amp;axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
<tr class="separator:ab1eeca8ec6fa31819ee108fa6ed2c41b"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:aa0332c64ee9965f05026c30a0b778000" id="r_aa0332c64ee9965f05026c30a0b778000"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#aa0332c64ee9965f05026c30a0b778000">strided_reduce_general_dispatch</a> (const <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;in, <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &amp;out, const std::string &amp;op_name, const <a class="el" href="structmlx_1_1core_1_1_reduction_plan.html">ReductionPlan</a> &amp;plan, const std::vector&lt; int &gt; &amp;axes, <a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">CommandEncoder</a> &amp;compute_encoder, <a class="el" href="classmlx_1_1core_1_1metal_1_1_device.html">metal::Device</a> &amp;d, const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;s)</td></tr>
@ -2634,8 +2634,8 @@ template&lt;typename... T&gt; </div>
</div>
</div>
<h2 class="groupheader">Function Documentation</h2>
<a id="af7b7ca7c6aa87558d9f98cee5c7a99a8" name="af7b7ca7c6aa87558d9f98cee5c7a99a8"></a>
<h2 class="memtitle"><span class="permalink"><a href="#af7b7ca7c6aa87558d9f98cee5c7a99a8">&#9670;&#160;</a></span>all_reduce_dispatch()</h2>
<a id="a3ab0fd997d9a35782106ff083a72e098" name="a3ab0fd997d9a35782106ff083a72e098"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a3ab0fd997d9a35782106ff083a72e098">&#9670;&#160;</a></span>all_reduce_dispatch()</h2>
<div class="memitem">
<div class="memproto">
@ -2668,12 +2668,7 @@ template&lt;typename... T&gt; </div>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;</td> <td class="paramname"><span class="paramname"><em>s</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">std::vector&lt; <a class="el" href="classmlx_1_1core_1_1array.html">array</a> &gt; &amp;</td> <td class="paramname"><span class="paramname"><em>copies</em></span>&#160;)</td>
<td class="paramtype">const <a class="el" href="structmlx_1_1core_1_1_stream.html">Stream</a> &amp;</td> <td class="paramname"><span class="paramname"><em>s</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">
@ -4418,8 +4413,8 @@ template&lt;typename... Arrays, typename = enable_for_arrays_t&lt;Arrays...&gt;
</div>
</div>
<a id="a51c4bb09230348bd0252e22bfdc9bc89" name="a51c4bb09230348bd0252e22bfdc9bc89"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a51c4bb09230348bd0252e22bfdc9bc89">&#9670;&#160;</a></span>get_reduce_init_kernel()</h2>
<a id="a3bd386cb6db09f636963ce66ceaf8647" name="a3bd386cb6db09f636963ce66ceaf8647"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a3bd386cb6db09f636963ce66ceaf8647">&#9670;&#160;</a></span>get_reduce_init_kernel()</h2>
<div class="memitem">
<div class="memproto">
@ -4434,6 +4429,16 @@ template&lt;typename... Arrays, typename = enable_for_arrays_t&lt;Arrays...&gt;
<td></td>
<td class="paramtype">const std::string &amp;</td> <td class="paramname"><span class="paramname"><em>kernel_name</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const std::string &amp;</td> <td class="paramname"><span class="paramname"><em>func_name</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const std::string &amp;</td> <td class="paramname"><span class="paramname"><em>op_name</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>

Binary file not shown.

@ -867,6 +867,7 @@
<dt class="sig sig-object py" id="mlx.core.fast.metal_kernel">
<span class="sig-name descname"><span class="pre">metal_kernel</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.13)"><span class="pre">Sequence</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">output_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence" title="(in Python v3.13)"><span class="pre">Sequence</span></a><span class="p"><span class="pre">[</span></span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">source</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a></span></em>, <em 
class="sig-param"><span class="n"><span class="pre">header</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">''</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ensure_row_contiguous</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">atomic_outputs</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><a class="reference external" href="https://docs.python.org/3/library/functions.html#object" title="(in Python v3.13)"><span class="pre">object</span></a></span></span><a class="headerlink" href="#mlx.core.fast.metal_kernel" title="Link to this definition">#</a></dt>
<dd><p>A jit-compiled custom Metal kernel defined from a source string.</p>
<p>Full documentation: <a class="reference internal" href="../../dev/custom_metal_kernels.html#custom-metal-kernels"><span class="std std-ref">Custom Metal Kernels</span></a>.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters<span class="colon">:</span></dt>
<dd class="field-odd"><ul class="simple">


@ -140,9 +140,9 @@ Functions</h2></td></tr>
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd" id="r_a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplParams" colspan="2">template&lt;typename T , int group_size, int bits&gt; </td></tr>
<tr class="memitem:a8e13c7d895624f738d2a6d9893b687fd"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a8e13c7d895624f738d2a6d9893b687fd">qmv_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a8e13c7d895624f738d2a6d9893b687fd"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12" id="r_a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits&gt; </td></tr>
<tr class="memitem:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a4a8c8db7d5d480733726fd6d1a645e12">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a4a8c8db7d5d480733726fd6d1a645e12"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a" id="r_a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits&gt; </td></tr>
<tr class="memitem:a1546533c5b925b2fbb3bec870ec7487a"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a1546533c5b925b2fbb3bec870ec7487a">qvm_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const int in_vec_size, const int out_vec_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a1546533c5b925b2fbb3bec870ec7487a"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8" id="r_af5750a35e8f5462218effba719f7f5b8"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, const bool aligned_N, const int BM = 32, const int BK = 32, const int BN = 32&gt; </td></tr>
<tr class="memitem:af5750a35e8f5462218effba719f7f5b8"><td class="memTemplItemLeft" align="right" valign="top">METAL_FUNC void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#af5750a35e8f5462218effba719f7f5b8">qmm_t_impl</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, threadgroup T *Xs, threadgroup T *Ws, const constant int &amp;K, const constant int &amp;N, const constant int &amp;M, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:af5750a35e8f5462218effba719f7f5b8"><td class="memSeparator" colspan="2">&#160;</td></tr>
@ -167,6 +167,9 @@ Functions</h2></td></tr>
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5" id="r_ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, bool batched&gt; </td></tr>
<tr class="memitem:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5">qvm</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &amp;w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:ad84f7d5ab9e32dbbe3ca759ae5d5d5c5"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8" id="r_ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, int split_k = 32&gt; </td></tr>
<tr class="memitem:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#ab8243818512d6078d23e6ffb65fd7bb8">qvm_split_k</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;in_vec_size, const constant int &amp;out_vec_size, const constant int &amp;x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &amp;w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, const constant int &amp;final_block_size, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:ab8243818512d6078d23e6ffb65fd7bb8"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10" id="r_abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplParams" colspan="2">template&lt;typename T , const int group_size, const int bits, const bool aligned_N, const bool batched, const int BM = 32, const int BK = 32, const int BN = 32&gt; </td></tr>
<tr class="memitem:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#abe2e3ef0ee4ec2cb61dc5330ad463d10">qmm_t</a> (const device uint32_t *w, const device T *scales, const device T *biases, const device T *x, device T *y, const constant int &amp;K, const constant int &amp;N, const constant int &amp;M, const constant int &amp;x_batch_ndims, const constant int *x_shape, const constant size_t *x_strides, const constant int &amp;w_batch_ndims, const constant int *w_shape, const constant size_t *w_strides, const constant size_t *s_strides, const constant size_t *b_strides, uint3 tid, uint lid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:abe2e3ef0ee4ec2cb61dc5330ad463d10"><td class="memSeparator" colspan="2">&#160;</td></tr>
@ -2485,8 +2488,8 @@ template&lt;typename T , const int group_size, const int bits, bool batched&gt;
</div>
</div>
<a id="a4a8c8db7d5d480733726fd6d1a645e12" name="a4a8c8db7d5d480733726fd6d1a645e12"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a4a8c8db7d5d480733726fd6d1a645e12">&#9670;&#160;</a></span>qvm_impl()</h2>
<a id="a1546533c5b925b2fbb3bec870ec7487a" name="a1546533c5b925b2fbb3bec870ec7487a"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a1546533c5b925b2fbb3bec870ec7487a">&#9670;&#160;</a></span>qvm_impl()</h2>
<div class="memitem">
<div class="memproto">
@ -2518,6 +2521,69 @@ template&lt;typename T , const int group_size, const int bits&gt; </div>
<td></td>
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>in_vec_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const int</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_gid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lid</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">
</div>
</div>
<a id="ab8243818512d6078d23e6ffb65fd7bb8" name="ab8243818512d6078d23e6ffb65fd7bb8"></a>
<h2 class="memtitle"><span class="permalink"><a href="#ab8243818512d6078d23e6ffb65fd7bb8">&#9670;&#160;</a></span>qvm_split_k()</h2>
<div class="memitem">
<div class="memproto">
<div class="memtemplate">
template&lt;typename T , const int group_size, const int bits, int split_k = 32&gt; </div>
<table class="memname">
<tr>
<td class="memname">void qvm_split_k </td>
<td>(</td>
<td class="paramtype">const device uint32_t *</td> <td class="paramname"><span class="paramname"><em>w</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>scales</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>biases</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>x</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">device T *</td> <td class="paramname"><span class="paramname"><em>y</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
@ -2528,6 +2594,51 @@ template&lt;typename T , const int group_size, const int bits&gt; </div>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>out_vec_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>x_batch_ndims</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>x_shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>x_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>w_batch_ndims</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>w_shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>w_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>s_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>b_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>final_block_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>

File diff suppressed because it is too large


@ -98,15 +98,207 @@ $(function(){ initResizable(false); });
<table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
Functions</h2></td></tr>
<tr class="memitem:adf7aeb18cd1d5042cf6d9b46b582d8ce" id="r_adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memTemplParams" colspan="2">template&lt;typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS&gt; </td></tr>
<tr class="memitem:adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a> (const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)</td></tr>
<tr class="separator:adf7aeb18cd1d5042cf6d9b46b582d8ce"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a7c378443a2b6f4d9210db8a21a9ac4f5" id="r_a7c378443a2b6f4d9210db8a21a9ac4f5"><td class="memTemplParams" colspan="2">template&lt;typename T , typename U , typename Op , int NDIMS&gt; </td></tr>
<tr class="memitem:a7c378443a2b6f4d9210db8a21a9ac4f5"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a7c378443a2b6f4d9210db8a21a9ac4f5">col_reduce_small</a> (const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</td></tr>
<tr class="separator:a7c378443a2b6f4d9210db8a21a9ac4f5"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a5b4f4c4c247ad341ff8d31dcbbbce0eb" id="r_a5b4f4c4c247ad341ff8d31dcbbbce0eb"><td class="memTemplParams" colspan="2">template&lt;typename T , typename U , typename Op , int NDIMS&gt; </td></tr>
<tr class="memitem:a5b4f4c4c247ad341ff8d31dcbbbce0eb"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a5b4f4c4c247ad341ff8d31dcbbbce0eb">col_reduce_longcolumn</a> (const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, const constant size_t &amp;out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</td></tr>
<tr class="separator:a5b4f4c4c247ad341ff8d31dcbbbce0eb"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a11bfc6112ae2386ac03f5ea7b7d93385" id="r_a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memTemplParams" colspan="2">template&lt;typename T , typename U , typename Op , int NDIMS, int BM, int BN&gt; </td></tr>
<tr class="memitem:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a> (const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</td></tr>
<tr class="memdesc:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="mdescLeft">&#160;</td><td class="mdescRight">Our approach is the following simple looped approach: <br /></td></tr>
<tr class="separator:a11bfc6112ae2386ac03f5ea7b7d93385"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a0e92fc74eeaa8ee2ceb83bafc6eb1d7d" id="r_a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><td class="memTemplParams" colspan="2">template&lt;typename T , typename U , typename Op , int NDIMS, int BM, int BN&gt; </td></tr>
<tr class="memitem:a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">col_reduce_2pass</a> (const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, const constant size_t &amp;out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</td></tr>
<tr class="separator:a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table>
<h2 class="groupheader">Function Documentation</h2>
<a id="a0e92fc74eeaa8ee2ceb83bafc6eb1d7d" name="a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">&#9670;&#160;</a></span>col_reduce_2pass()</h2>
<div class="memitem">
<div class="memproto">
<div class="memtemplate">
template&lt;typename T , typename U , typename Op , int NDIMS, int BM, int BN&gt; </div>
<table class="memname">
<tr>
<td class="memname">void col_reduce_2pass </td>
<td>(</td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">device U *</td> <td class="paramname"><span class="paramname"><em>out</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>reduction_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>ndim</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>reduce_shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>reduce_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>reduce_ndim</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>non_col_reductions</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>out_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lane_id</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_group_id</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">
</div>
</div>
<a id="a5b4f4c4c247ad341ff8d31dcbbbce0eb" name="a5b4f4c4c247ad341ff8d31dcbbbce0eb"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a5b4f4c4c247ad341ff8d31dcbbbce0eb">&#9670;&#160;</a></span>col_reduce_longcolumn()</h2>
<div class="memitem">
<div class="memproto">
<div class="memtemplate">
template&lt;typename T , typename U , typename Op , int NDIMS&gt; </div>
<table class="memname">
<tr>
<td class="memname">void col_reduce_longcolumn </td>
<td>(</td>
<td class="paramtype">const device T *</td> <td class="paramname"><span class="paramname"><em>in</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">device U *</td> <td class="paramname"><span class="paramname"><em>out</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>reduction_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>reduction_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>ndim</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int *</td> <td class="paramname"><span class="paramname"><em>reduce_shape</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t *</td> <td class="paramname"><span class="paramname"><em>reduce_strides</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant int &amp;</td> <td class="paramname"><span class="paramname"><em>reduce_ndim</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>non_col_reductions</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>out_size</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>gsize</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lsize</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">
</div>
</div>
<a id="a11bfc6112ae2386ac03f5ea7b7d93385" name="a11bfc6112ae2386ac03f5ea7b7d93385"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a11bfc6112ae2386ac03f5ea7b7d93385">&#9670;&#160;</a></span>col_reduce_looped()</h2>
@ -204,13 +396,13 @@ template&lt;typename T , typename U , typename Op , int NDIMS, int BM, int BN&gt
</div>
</div>
<a id="adf7aeb18cd1d5042cf6d9b46b582d8ce" name="adf7aeb18cd1d5042cf6d9b46b582d8ce"></a>
<h2 class="memtitle"><span class="permalink"><a href="#adf7aeb18cd1d5042cf6d9b46b582d8ce">&#9670;&#160;</a></span>col_reduce_small()</h2>
<a id="a7c378443a2b6f4d9210db8a21a9ac4f5" name="a7c378443a2b6f4d9210db8a21a9ac4f5"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a7c378443a2b6f4d9210db8a21a9ac4f5">&#9670;&#160;</a></span>col_reduce_small()</h2>
<div class="memitem">
<div class="memproto">
<div class="memtemplate">
template&lt;typename T , typename U , typename Op , int NDIMS, int N_READS = REDUCE_N_READS&gt; </div>
template&lt;typename T , typename U , typename Op , int NDIMS&gt; </div>
<table class="memname">
<tr>
<td class="memname">void col_reduce_small </td>
@ -280,22 +472,12 @@ template&lt;typename T , typename U , typename Op , int NDIMS, int N_READS = RED
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_lane_id</em></span>, </td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint</td> <td class="paramname"><span class="paramname"><em>simd_group_id</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tid</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>tsize</em></span>&#160;)</td>
<td class="paramtype">uint3</td> <td class="paramname"><span class="paramname"><em>lsize</em></span>&#160;)</td>
</tr>
</table>
</div><div class="memdoc">


@ -93,334 +93,392 @@ $(function(){ initResizable(false); });
<div class="contents">
<a href="reduce__col_8h.html">Go to the documentation of this file.</a><div class="fragment"><div class="line"><a id="l00001" name="l00001"></a><span class="lineno"> 1</span><span class="comment">// Copyright © 2023-2024 Apple Inc.</span></div>
<div class="line"><a id="l00002" name="l00002"></a><span class="lineno"> 2</span> </div>
<div class="line"><a id="l00003" name="l00003"></a><span class="lineno"> 3</span><span class="keyword">template</span> &lt;</div>
<div class="line"><a id="l00004" name="l00004"></a><span class="lineno"> 4</span> <span class="keyword">typename</span> T,</div>
<div class="line"><a id="l00005" name="l00005"></a><span class="lineno"> 5</span> <span class="keyword">typename</span> U,</div>
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span> <span class="keyword">typename</span> Op,</div>
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span> <span class="keywordtype">int</span> NDIMS,</div>
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"> 8</span> <span class="keywordtype">int</span> N_READS = <a class="code hl_variable" href="defines_8h.html#a2ad505864a2ab786147766900bc18c21">REDUCE_N_READS</a>&gt;</div>
<div class="foldopen" id="foldopen00009" data-start="{" data-end="}">
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce"> 9</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a>(</div>
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> device U* out [[buffer(1)]],</div>
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_size [[buffer(2)]],</div>
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_stride [[buffer(3)]],</div>
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; ndim [[buffer(6)]],</div>
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; reduce_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; non_col_reductions [[buffer(10)]],</div>
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> uint3 gid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> uint3 gsize [[threadgroups_per_grid]],</div>
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> uint simd_group_id [[simdgroup_index_in_threadgroup]],</div>
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> uint3 tid [[thread_position_in_grid]],</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> uint3 tsize [[threads_per_grid]]) {</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc&lt;NDIMS&gt;</a> loop;</div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> <span class="keyword">const</span> device T* row;</div>
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> </div>
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> <span class="comment">// Case 1: Small row small column</span></div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> <span class="keywordflow">if</span> (reduction_size * non_col_reductions &lt; 64 &amp;&amp; reduction_stride &lt; 32) {</div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> U totals[31];</div>
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; 31; i++) {</div>
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> }</div>
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> </div>
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keywordtype">short</span> stride = reduction_stride;</div>
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <span class="keywordtype">short</span> size = reduction_size;</div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> <span class="keywordtype">short</span> blocks = stride / N_READS;</div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> <span class="keywordtype">short</span> extra = stride - blocks * N_READS;</div>
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> </div>
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> <span class="keywordtype">size_t</span> out_idx = tid.x + tsize.y * size_t(tid.y);</div>
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> in += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> </div>
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> <span class="keywordflow">for</span> (uint r = 0; r &lt; non_col_reductions; r++) {</div>
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> </div>
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i &lt; size; i++) {</div>
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j &lt; blocks; j++) {</div>
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> k = 0; k &lt; N_READS; k++) {</div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> totals[j * N_READS + k] =</div>
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(totals[j * N_READS + k],</div>
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> <span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i * stride + j * N_READS + k]));</div>
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> }</div>
<div class="line"><a id="l00003" name="l00003"></a><span class="lineno"> 3</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS&gt;</div>
<div class="foldopen" id="foldopen00004" data-start="{" data-end="}">
<div class="line"><a id="l00004" name="l00004"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5"> 4</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">col_reduce_small</a>(</div>
<div class="line"><a id="l00005" name="l00005"></a><span class="lineno"> 5</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span> device U* out [[buffer(1)]],</div>
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_size [[buffer(2)]],</div>
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"> 8</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_stride [[buffer(3)]],</div>
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"> 9</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; ndim [[buffer(6)]],</div>
<div class="line"><a id="l00012" name="l00012"></a><span class="lineno"> 12</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; reduce_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; non_col_reductions [[buffer(10)]],</div>
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> uint3 gid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> uint3 gsize [[threadgroups_per_grid]],</div>
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> uint3 lid [[thread_position_in_threadgroup]],</div>
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> uint3 lsize [[threads_per_threadgroup]]) {</div>
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_reads = 4;</div>
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc&lt;NDIMS&gt;</a> loop;</div>
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <span class="keyword">const</span> device T* row;</div>
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> </div>
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> U totals[n_reads];</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> }</div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> </div>
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> <span class="keywordtype">size_t</span> column = size_t(gid.x) * lsize.x * n_reads + lid.x * n_reads;</div>
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> <span class="keywordflow">if</span> (column &gt;= reduction_stride) {</div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> <span class="keywordflow">return</span>;</div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> }</div>
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> <span class="keywordtype">bool</span> safe = column + n_reads &lt;= reduction_stride;</div>
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> </div>
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> in += in_idx + column;</div>
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> </div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> <span class="keywordtype">size_t</span> total_rows = non_col_reductions * reduction_size;</div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(lid.y, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = lid.y; r &lt; total_rows; r += lsize.y) {</div>
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i]), totals[i]);</div>
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> }</div>
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> U vals[n_reads];</div>
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> vals[i] =</div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> (column + i &lt; reduction_stride) ? static_cast&lt;U&gt;(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> }</div>
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> }</div>
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> k = 0; k &lt; extra; k++) {</div>
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> totals[blocks * N_READS + k] =</div>
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(totals[blocks * N_READS + k],</div>
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> <span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i * stride + blocks * N_READS + k]));</div>
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> }</div>
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> }</div>
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> </div>
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> }</div>
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> out += out_idx * reduction_stride;</div>
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j &lt; stride; j++) {</div>
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> out[j] = totals[j];</div>
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> }</div>
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> }</div>
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> </div>
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> <span class="comment">// Case 2: Long row small column</span></div>
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> <span class="keywordflow">else</span> <span class="keywordflow">if</span> (reduction_size * non_col_reductions &lt; 32) {</div>
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> U totals[N_READS];</div>
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> }</div>
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(lsize.y, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> }</div>
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> </div>
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> <span class="keywordflow">if</span> (lsize.y &gt; 1) {</div>
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> <span class="comment">// lsize.y should be &lt;= 8</span></div>
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> threadgroup U shared_vals[32 * 8 * n_reads];</div>
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];</div>
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> }</div>
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> <span class="keywordflow">if</span> (lid.y == 0) {</div>
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> totals[i] = shared_vals[lid.x * n_reads + i];</div>
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> }</div>
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> <span class="keywordflow">for</span> (uint j = 1; j &lt; lsize.y; j++) {</div>
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> totals[i] =</div>
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],</div>
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> totals[i]);</div>
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> }</div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> </div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> <span class="keywordtype">short</span> size = reduction_size;</div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> <span class="keywordtype">size_t</span> offset = size_t(tid.x) * N_READS;</div>
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> <span class="keywordtype">bool</span> safe = offset + N_READS &lt;= reduction_stride;</div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keywordtype">short</span> extra = reduction_stride - offset;</div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> </div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keywordtype">size_t</span> out_idx = tid.y + tsize.z * size_t(tid.z);</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> in += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim) + offset;</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> </div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> <span class="keywordflow">for</span> (uint r = 0; r &lt; non_col_reductions; r++) {</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> </div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i &lt; size; i++) {</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j &lt; N_READS; j++) {</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> totals[j] =</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i * reduction_stride + j]), totals[j]);</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> }</div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> }</div>
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i &lt; size; i++) {</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> j = 0; j &lt; extra; j++) {</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> totals[j] =</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i * reduction_stride + j]), totals[j]);</div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> }</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> }</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> }</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> </div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> }</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> out += out_idx * reduction_stride + offset;</div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> out[i] = totals[i];</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> }</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> <span class="keywordflow">for</span> (<span class="keywordtype">short</span> i = 0; i &lt; extra; i++) {</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> out[i] = totals[i];</div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> }</div>
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> }</div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> }</div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> </div>
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> <span class="comment">// Case 3: Long row medium column</span></div>
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> threadgroup U shared_vals[1024];</div>
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> U totals[N_READS];</div>
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> }</div>
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> </div>
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> <span class="keywordtype">short</span> stride = reduction_stride;</div>
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> short2 tile((stride + N_READS - 1) / N_READS, 32);</div>
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"> 131</span> short2 offset((lid % tile.x) * N_READS, lid / tile.x);</div>
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> <span class="keywordtype">short</span> sm_stride = tile.x * N_READS;</div>
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"> 133</span> <span class="keywordtype">bool</span> safe = offset.x + N_READS &lt;= stride;</div>
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> </div>
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> in += <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim) + offset.x;</div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> </div>
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> <span class="comment">// Read cooperatively and contiguously and aggregate the partial results.</span></div>
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y; r &lt; total; r += <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>) {</div>
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00143" name="l00143"></a><span class="lineno"> 143</span> </div>
<div class="line"><a id="l00144" name="l00144"></a><span class="lineno"> 144</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00145" name="l00145"></a><span class="lineno"> 145</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00146" name="l00146"></a><span class="lineno"> 146</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i]), totals[i]);</div>
<div class="line"><a id="l00147" name="l00147"></a><span class="lineno"> 147</span> }</div>
<div class="line"><a id="l00148" name="l00148"></a><span class="lineno"> 148</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00149" name="l00149"></a><span class="lineno"> 149</span> U vals[N_READS];</div>
<div class="line"><a id="l00150" name="l00150"></a><span class="lineno"> 150</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00151" name="l00151"></a><span class="lineno"> 151</span> vals[i] = (offset.x + i &lt; stride) ? static_cast&lt;U&gt;(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
<div class="line"><a id="l00152" name="l00152"></a><span class="lineno"> 152</span> }</div>
<div class="line"><a id="l00153" name="l00153"></a><span class="lineno"> 153</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"> 155</span> }</div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> }</div>
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> </div>
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(<a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> }</div>
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> </div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="comment">// Each thread holds N_READS partial results but the simdgroups are not</span></div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <span class="comment">// aligned to do the reduction across the simdgroup so we write our results</span></div>
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="comment">// in the shared memory and read them back according to the simdgroup.</span></div>
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> shared_vals[offset.y * sm_stride + offset.x + i] = totals[i];</div>
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> }</div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(</div>
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> shared_vals[simd_lane_id * sm_stride + simd_group_id * N_READS + i]);</div>
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> }</div>
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> </div>
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="comment">// Write the output.</span></div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keywordtype">short</span> column = simd_group_id * N_READS;</div>
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> out += out_idx * reduction_stride + column;</div>
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> <span class="keywordflow">if</span> (column + N_READS &lt;= stride) {</div>
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; N_READS; i++) {</div>
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> out[i] = totals[i];</div>
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> }</div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i &lt; stride; i++) {</div>
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> out[i] = totals[i];</div>
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> }</div>
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> }</div>
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> }</div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> }</div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span>}</div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> }</div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> }</div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> }</div>
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> </div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="keywordflow">if</span> (lid.y == 0) {</div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> out += out_idx * reduction_stride + column;</div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> out[i] = totals[i];</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> }</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i &lt; reduction_stride; i++) {</div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> out[i] = totals[i];</div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> }</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> }</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> }</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span>}</div>
</div>
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> </div>
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS, <span class="keywordtype">int</span> BM, <span class="keywordtype">int</span> BN&gt;</div>
<div class="foldopen" id="foldopen00202" data-start="{" data-end="}">
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385"> 202</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a>(</div>
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> device U* out [[buffer(1)]],</div>
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_size [[buffer(2)]],</div>
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_stride [[buffer(3)]],</div>
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; ndim [[buffer(6)]],</div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; reduce_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; non_col_reductions [[buffer(10)]],</div>
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> uint3 gid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> uint3 gsize [[threadgroups_per_grid]],</div>
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> uint simd_group_id [[simdgroup_index_in_threadgroup]]) {</div>
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_simdgroups = 4;</div>
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> tgp_size = n_simdgroups * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>;</div>
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_reads = (BM * BN) / tgp_size;</div>
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_read_blocks = BN / n_reads;</div>
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> </div>
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> threadgroup U shared_vals[BN * BM];</div>
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> U totals[n_reads];</div>
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc&lt;NDIMS&gt;</a> loop;</div>
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> <span class="keyword">const</span> device T* row;</div>
<div class="line"><a id="l00228" name="l00228"></a><span class="lineno"> 228</span> </div>
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> }</div>
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> </div>
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);</div>
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> <span class="keywordtype">size_t</span> column = BN * gid.x + offset.x;</div>
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> <span class="keywordtype">bool</span> safe = column + n_reads &lt;= reduction_stride;</div>
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> </div>
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> in += in_idx + column;</div>
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"> 241</span> </div>
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00244" name="l00244"></a><span class="lineno"> 244</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y; r &lt; total; r += BM) {</div>
<div class="line"><a id="l00245" name="l00245"></a><span class="lineno"> 245</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00246" name="l00246"></a><span class="lineno"> 246</span> </div>
<div class="line"><a id="l00247" name="l00247"></a><span class="lineno"> 247</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00248" name="l00248"></a><span class="lineno"> 248</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00249" name="l00249"></a><span class="lineno"> 249</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i]), totals[i]);</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> </div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS&gt;</div>
<div class="foldopen" id="foldopen00097" data-start="{" data-end="}">
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb"> 97</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">col_reduce_longcolumn</a>(</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> device U* out [[buffer(1)]],</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_size [[buffer(2)]],</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_stride [[buffer(3)]],</div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; ndim [[buffer(6)]],</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; reduce_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; non_col_reductions [[buffer(10)]],</div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; out_size [[buffer(11)]],</div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> uint3 gid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> uint3 gsize [[threadgroups_per_grid]],</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> uint3 lid [[thread_position_in_threadgroup]],</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> uint3 lsize [[threads_per_threadgroup]]) {</div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc&lt;NDIMS&gt;</a> loop;</div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span> <span class="keyword">const</span> device T* row;</div>
<div class="line"><a id="l00117" name="l00117"></a><span class="lineno"> 117</span> </div>
<div class="line"><a id="l00118" name="l00118"></a><span class="lineno"> 118</span> <span class="keywordtype">size_t</span> out_idx = gid.x + gsize.x * size_t(gid.y);</div>
<div class="line"><a id="l00119" name="l00119"></a><span class="lineno"> 119</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
<div class="line"><a id="l00120" name="l00120"></a><span class="lineno"> 120</span> in += in_idx + lid.x;</div>
<div class="line"><a id="l00121" name="l00121"></a><span class="lineno"> 121</span> </div>
<div class="line"><a id="l00122" name="l00122"></a><span class="lineno"> 122</span> U total = Op::init;</div>
<div class="line"><a id="l00123" name="l00123"></a><span class="lineno"> 123</span> <span class="keywordtype">size_t</span> total_rows = non_col_reductions * reduction_size;</div>
<div class="line"><a id="l00124" name="l00124"></a><span class="lineno"> 124</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00125" name="l00125"></a><span class="lineno"> 125</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = gid.z * lsize.y + lid.y; r &lt; total_rows;</div>
<div class="line"><a id="l00126" name="l00126"></a><span class="lineno"> 126</span> r += lsize.y * gsize.z) {</div>
<div class="line"><a id="l00127" name="l00127"></a><span class="lineno"> 127</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00128" name="l00128"></a><span class="lineno"> 128</span> total = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(*row), total);</div>
<div class="line"><a id="l00129" name="l00129"></a><span class="lineno"> 129</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(lsize.y * gsize.z, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00130" name="l00130"></a><span class="lineno"> 130</span> }</div>
<div class="line"><a id="l00131" name="l00131"></a><span class="lineno"> 131</span> </div>
<div class="line"><a id="l00132" name="l00132"></a><span class="lineno"> 132</span> threadgroup U shared_vals[32 * 32];</div>
<div class="line"><a id="l00133" name="l00133"></a><span class="lineno"> 133</span> shared_vals[lid.y * lsize.x + lid.x] = total;</div>
<div class="line"><a id="l00134" name="l00134"></a><span class="lineno"> 134</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00135" name="l00135"></a><span class="lineno"> 135</span> <span class="keywordflow">if</span> (lid.y == 0) {</div>
<div class="line"><a id="l00136" name="l00136"></a><span class="lineno"> 136</span> <span class="keywordflow">for</span> (uint i = 1; i &lt; lsize.y; i++) {</div>
<div class="line"><a id="l00137" name="l00137"></a><span class="lineno"> 137</span> total = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(total, shared_vals[i * lsize.x + lid.x]);</div>
<div class="line"><a id="l00138" name="l00138"></a><span class="lineno"> 138</span> }</div>
<div class="line"><a id="l00139" name="l00139"></a><span class="lineno"> 139</span> out[gid.z * out_size + out_idx * reduction_stride + lid.x] = total;</div>
<div class="line"><a id="l00140" name="l00140"></a><span class="lineno"> 140</span> }</div>
<div class="line"><a id="l00141" name="l00141"></a><span class="lineno"> 141</span>}</div>
</div>
<div class="line"><a id="l00142" name="l00142"></a><span class="lineno"> 142</span> </div>
<div class="line"><a id="l00154" name="l00154"></a><span class="lineno"> 154</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS, <span class="keywordtype">int</span> BM, <span class="keywordtype">int</span> BN&gt;</div>
<div class="foldopen" id="foldopen00155" data-start="{" data-end="}">
<div class="line"><a id="l00155" name="l00155"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385"> 155</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a>(</div>
<div class="line"><a id="l00156" name="l00156"></a><span class="lineno"> 156</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
<div class="line"><a id="l00157" name="l00157"></a><span class="lineno"> 157</span> device U* out [[buffer(1)]],</div>
<div class="line"><a id="l00158" name="l00158"></a><span class="lineno"> 158</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_size [[buffer(2)]],</div>
<div class="line"><a id="l00159" name="l00159"></a><span class="lineno"> 159</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_stride [[buffer(3)]],</div>
<div class="line"><a id="l00160" name="l00160"></a><span class="lineno"> 160</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
<div class="line"><a id="l00161" name="l00161"></a><span class="lineno"> 161</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
<div class="line"><a id="l00162" name="l00162"></a><span class="lineno"> 162</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; ndim [[buffer(6)]],</div>
<div class="line"><a id="l00163" name="l00163"></a><span class="lineno"> 163</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
<div class="line"><a id="l00164" name="l00164"></a><span class="lineno"> 164</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
<div class="line"><a id="l00165" name="l00165"></a><span class="lineno"> 165</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; reduce_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00166" name="l00166"></a><span class="lineno"> 166</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; non_col_reductions [[buffer(10)]],</div>
<div class="line"><a id="l00167" name="l00167"></a><span class="lineno"> 167</span> uint3 gid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00168" name="l00168"></a><span class="lineno"> 168</span> uint3 gsize [[threadgroups_per_grid]],</div>
<div class="line"><a id="l00169" name="l00169"></a><span class="lineno"> 169</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
<div class="line"><a id="l00170" name="l00170"></a><span class="lineno"> 170</span> uint simd_group_id [[simdgroup_index_in_threadgroup]]) {</div>
<div class="line"><a id="l00171" name="l00171"></a><span class="lineno"> 171</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
<div class="line"><a id="l00172" name="l00172"></a><span class="lineno"> 172</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_simdgroups = 8;</div>
<div class="line"><a id="l00173" name="l00173"></a><span class="lineno"> 173</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> tgp_size = n_simdgroups * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>;</div>
<div class="line"><a id="l00174" name="l00174"></a><span class="lineno"> 174</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_reads = (BM * BN) / tgp_size;</div>
<div class="line"><a id="l00175" name="l00175"></a><span class="lineno"> 175</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_read_blocks = BN / n_reads;</div>
<div class="line"><a id="l00176" name="l00176"></a><span class="lineno"> 176</span> </div>
<div class="line"><a id="l00177" name="l00177"></a><span class="lineno"> 177</span> threadgroup U shared_vals[BN * BM];</div>
<div class="line"><a id="l00178" name="l00178"></a><span class="lineno"> 178</span> U totals[n_reads];</div>
<div class="line"><a id="l00179" name="l00179"></a><span class="lineno"> 179</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc&lt;NDIMS&gt;</a> loop;</div>
<div class="line"><a id="l00180" name="l00180"></a><span class="lineno"> 180</span> <span class="keyword">const</span> device T* row;</div>
<div class="line"><a id="l00181" name="l00181"></a><span class="lineno"> 181</span> </div>
<div class="line"><a id="l00182" name="l00182"></a><span class="lineno"> 182</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00183" name="l00183"></a><span class="lineno"> 183</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00184" name="l00184"></a><span class="lineno"> 184</span> }</div>
<div class="line"><a id="l00185" name="l00185"></a><span class="lineno"> 185</span> </div>
<div class="line"><a id="l00186" name="l00186"></a><span class="lineno"> 186</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
<div class="line"><a id="l00187" name="l00187"></a><span class="lineno"> 187</span> short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);</div>
<div class="line"><a id="l00188" name="l00188"></a><span class="lineno"> 188</span> <span class="keywordtype">size_t</span> column = BN * gid.x + offset.x;</div>
<div class="line"><a id="l00189" name="l00189"></a><span class="lineno"> 189</span> <span class="keywordtype">bool</span> safe = column + n_reads &lt;= reduction_stride;</div>
<div class="line"><a id="l00190" name="l00190"></a><span class="lineno"> 190</span> </div>
<div class="line"><a id="l00191" name="l00191"></a><span class="lineno"> 191</span> <span class="keywordtype">size_t</span> out_idx = gid.y + gsize.y * size_t(gid.z);</div>
<div class="line"><a id="l00192" name="l00192"></a><span class="lineno"> 192</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
<div class="line"><a id="l00193" name="l00193"></a><span class="lineno"> 193</span> in += in_idx + column;</div>
<div class="line"><a id="l00194" name="l00194"></a><span class="lineno"> 194</span> </div>
<div class="line"><a id="l00195" name="l00195"></a><span class="lineno"> 195</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
<div class="line"><a id="l00196" name="l00196"></a><span class="lineno"> 196</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00197" name="l00197"></a><span class="lineno"> 197</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y; r &lt; total; r += BM) {</div>
<div class="line"><a id="l00198" name="l00198"></a><span class="lineno"> 198</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00199" name="l00199"></a><span class="lineno"> 199</span> </div>
<div class="line"><a id="l00200" name="l00200"></a><span class="lineno"> 200</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00201" name="l00201"></a><span class="lineno"> 201</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00202" name="l00202"></a><span class="lineno"> 202</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i]), totals[i]);</div>
<div class="line"><a id="l00203" name="l00203"></a><span class="lineno"> 203</span> }</div>
<div class="line"><a id="l00204" name="l00204"></a><span class="lineno"> 204</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00205" name="l00205"></a><span class="lineno"> 205</span> U vals[n_reads];</div>
<div class="line"><a id="l00206" name="l00206"></a><span class="lineno"> 206</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00207" name="l00207"></a><span class="lineno"> 207</span> vals[i] =</div>
<div class="line"><a id="l00208" name="l00208"></a><span class="lineno"> 208</span> (column + i &lt; reduction_stride) ? static_cast&lt;U&gt;(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
<div class="line"><a id="l00209" name="l00209"></a><span class="lineno"> 209</span> }</div>
<div class="line"><a id="l00210" name="l00210"></a><span class="lineno"> 210</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00211" name="l00211"></a><span class="lineno"> 211</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
<div class="line"><a id="l00212" name="l00212"></a><span class="lineno"> 212</span> }</div>
<div class="line"><a id="l00213" name="l00213"></a><span class="lineno"> 213</span> }</div>
<div class="line"><a id="l00214" name="l00214"></a><span class="lineno"> 214</span> </div>
<div class="line"><a id="l00215" name="l00215"></a><span class="lineno"> 215</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(BM, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00216" name="l00216"></a><span class="lineno"> 216</span> }</div>
<div class="line"><a id="l00217" name="l00217"></a><span class="lineno"> 217</span> </div>
<div class="line"><a id="l00218" name="l00218"></a><span class="lineno"> 218</span> <span class="comment">// We can use a simd reduction to accumulate across BM, so each thread writes</span></div>
<div class="line"><a id="l00219" name="l00219"></a><span class="lineno"> 219</span> <span class="comment">// its partial output to shared memory and then each simdgroup does</span></div>
<div class="line"><a id="l00220" name="l00220"></a><span class="lineno"> 220</span> <span class="comment">// BN / n_simdgroups accumulations.</span></div>
<div class="line"><a id="l00221" name="l00221"></a><span class="lineno"> 221</span> <span class="keywordflow">if</span> (BM == 32) {</div>
<div class="line"><a id="l00222" name="l00222"></a><span class="lineno"> 222</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_outputs = BN / n_simdgroups;</div>
<div class="line"><a id="l00223" name="l00223"></a><span class="lineno"> 223</span> <span class="keyword">static_assert</span>(</div>
<div class="line"><a id="l00224" name="l00224"></a><span class="lineno"> 224</span> BM != 32 || n_outputs == n_reads,</div>
<div class="line"><a id="l00225" name="l00225"></a><span class="lineno"> 225</span> <span class="stringliteral">&quot;The tile should be selected such that n_outputs == n_reads&quot;</span>);</div>
<div class="line"><a id="l00226" name="l00226"></a><span class="lineno"> 226</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00227" name="l00227"></a><span class="lineno"> 227</span> shared_vals[offset.y * BN + offset.x + i] = totals[i];</div>
<div class="line"><a id="l00228" name="l00228"></a><span class="lineno"> 228</span> }</div>
<div class="line"><a id="l00229" name="l00229"></a><span class="lineno"> 229</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00230" name="l00230"></a><span class="lineno"> 230</span> short2 out_offset(simd_group_id * n_outputs, simd_lane_id);</div>
<div class="line"><a id="l00231" name="l00231"></a><span class="lineno"> 231</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_outputs; i++) {</div>
<div class="line"><a id="l00232" name="l00232"></a><span class="lineno"> 232</span> totals[i] =</div>
<div class="line"><a id="l00233" name="l00233"></a><span class="lineno"> 233</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);</div>
<div class="line"><a id="l00234" name="l00234"></a><span class="lineno"> 234</span> }</div>
<div class="line"><a id="l00235" name="l00235"></a><span class="lineno"> 235</span> </div>
<div class="line"><a id="l00236" name="l00236"></a><span class="lineno"> 236</span> <span class="comment">// Write the output.</span></div>
<div class="line"><a id="l00237" name="l00237"></a><span class="lineno"> 237</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
<div class="line"><a id="l00238" name="l00238"></a><span class="lineno"> 238</span> <span class="keywordtype">size_t</span> out_column = BN * gid.x + out_offset.x;</div>
<div class="line"><a id="l00239" name="l00239"></a><span class="lineno"> 239</span> out += out_idx * reduction_stride + out_column;</div>
<div class="line"><a id="l00240" name="l00240"></a><span class="lineno"> 240</span> <span class="keywordflow">if</span> (out_column + n_outputs &lt;= reduction_stride) {</div>
<div class="line"><a id="l00241" name="l00241"></a><span class="lineno"> 241</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_outputs; i++) {</div>
<div class="line"><a id="l00242" name="l00242"></a><span class="lineno"> 242</span> out[i] = totals[i];</div>
<div class="line"><a id="l00243" name="l00243"></a><span class="lineno"> 243</span> }</div>
<div class="line"><a id="l00244" name="l00244"></a><span class="lineno"> 244</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00245" name="l00245"></a><span class="lineno"> 245</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; out_column + i &lt; reduction_stride; i++) {</div>
<div class="line"><a id="l00246" name="l00246"></a><span class="lineno"> 246</span> out[i] = totals[i];</div>
<div class="line"><a id="l00247" name="l00247"></a><span class="lineno"> 247</span> }</div>
<div class="line"><a id="l00248" name="l00248"></a><span class="lineno"> 248</span> }</div>
<div class="line"><a id="l00249" name="l00249"></a><span class="lineno"> 249</span> }</div>
<div class="line"><a id="l00250" name="l00250"></a><span class="lineno"> 250</span> }</div>
<div class="line"><a id="l00251" name="l00251"></a><span class="lineno"> 251</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00252" name="l00252"></a><span class="lineno"> 252</span> U vals[n_reads];</div>
<div class="line"><a id="l00253" name="l00253"></a><span class="lineno"> 253</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00254" name="l00254"></a><span class="lineno"> 254</span> vals[i] =</div>
<div class="line"><a id="l00255" name="l00255"></a><span class="lineno"> 255</span> (column + i &lt; reduction_stride) ? static_cast&lt;U&gt;(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
<div class="line"><a id="l00256" name="l00256"></a><span class="lineno"> 256</span> }</div>
<div class="line"><a id="l00251" name="l00251"></a><span class="lineno"> 251</span> </div>
<div class="line"><a id="l00252" name="l00252"></a><span class="lineno"> 252</span> <span class="comment">// Each thread holds n_reads partial results. We write them all out to shared</span></div>
<div class="line"><a id="l00253" name="l00253"></a><span class="lineno"> 253</span> <span class="comment">// memory and threads with offset.y == 0 aggregate the columns and write the</span></div>
<div class="line"><a id="l00254" name="l00254"></a><span class="lineno"> 254</span> <span class="comment">// outputs.</span></div>
<div class="line"><a id="l00255" name="l00255"></a><span class="lineno"> 255</span> <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00256" name="l00256"></a><span class="lineno"> 256</span> <span class="keywordtype">short</span> x_block = offset.x / n_reads;</div>
<div class="line"><a id="l00257" name="l00257"></a><span class="lineno"> 257</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00258" name="l00258"></a><span class="lineno"> 258</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
<div class="line"><a id="l00258" name="l00258"></a><span class="lineno"> 258</span> shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];</div>
<div class="line"><a id="l00259" name="l00259"></a><span class="lineno"> 259</span> }</div>
<div class="line"><a id="l00260" name="l00260"></a><span class="lineno"> 260</span> }</div>
<div class="line"><a id="l00261" name="l00261"></a><span class="lineno"> 261</span> </div>
<div class="line"><a id="l00262" name="l00262"></a><span class="lineno"> 262</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(BM, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00263" name="l00263"></a><span class="lineno"> 263</span> }</div>
<div class="line"><a id="l00264" name="l00264"></a><span class="lineno"> 264</span> </div>
<div class="line"><a id="l00265" name="l00265"></a><span class="lineno"> 265</span> <span class="comment">// We can use a simd reduction to accumulate across BM, so each thread writes</span></div>
<div class="line"><a id="l00266" name="l00266"></a><span class="lineno"> 266</span> <span class="comment">// its partial output to shared memory and then each simdgroup does</span></div>
<div class="line"><a id="l00267" name="l00267"></a><span class="lineno"> 267</span> <span class="comment">// BN / n_simdgroups accumulations.</span></div>
<div class="line"><a id="l00268" name="l00268"></a><span class="lineno"> 268</span> <span class="keywordflow">if</span> (BM == 32) {</div>
<div class="line"><a id="l00269" name="l00269"></a><span class="lineno"> 269</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_outputs = BN / n_simdgroups;</div>
<div class="line"><a id="l00270" name="l00270"></a><span class="lineno"> 270</span> <span class="keyword">static_assert</span>(</div>
<div class="line"><a id="l00271" name="l00271"></a><span class="lineno"> 271</span> BM != 32 || n_outputs == n_reads,</div>
<div class="line"><a id="l00272" name="l00272"></a><span class="lineno"> 272</span> <span class="stringliteral">&quot;The tile should be selected such that n_outputs == n_reads&quot;</span>);</div>
<div class="line"><a id="l00273" name="l00273"></a><span class="lineno"> 273</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00274" name="l00274"></a><span class="lineno"> 274</span> shared_vals[offset.y * BN + offset.x + i] = totals[i];</div>
<div class="line"><a id="l00275" name="l00275"></a><span class="lineno"> 275</span> }</div>
<div class="line"><a id="l00276" name="l00276"></a><span class="lineno"> 276</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00277" name="l00277"></a><span class="lineno"> 277</span> short2 out_offset(simd_group_id * n_outputs, simd_lane_id);</div>
<div class="line"><a id="l00278" name="l00278"></a><span class="lineno"> 278</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_outputs; i++) {</div>
<div class="line"><a id="l00279" name="l00279"></a><span class="lineno"> 279</span> totals[i] =</div>
<div class="line"><a id="l00280" name="l00280"></a><span class="lineno"> 280</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);</div>
<div class="line"><a id="l00260" name="l00260"></a><span class="lineno"> 260</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00261" name="l00261"></a><span class="lineno"> 261</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
<div class="line"><a id="l00262" name="l00262"></a><span class="lineno"> 262</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00263" name="l00263"></a><span class="lineno"> 263</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 1; j &lt; BM; j++) {</div>
<div class="line"><a id="l00264" name="l00264"></a><span class="lineno"> 264</span> totals[i] =</div>
<div class="line"><a id="l00265" name="l00265"></a><span class="lineno"> 265</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);</div>
<div class="line"><a id="l00266" name="l00266"></a><span class="lineno"> 266</span> }</div>
<div class="line"><a id="l00267" name="l00267"></a><span class="lineno"> 267</span> }</div>
<div class="line"><a id="l00268" name="l00268"></a><span class="lineno"> 268</span> }</div>
<div class="line"><a id="l00269" name="l00269"></a><span class="lineno"> 269</span> </div>
<div class="line"><a id="l00270" name="l00270"></a><span class="lineno"> 270</span> <span class="comment">// Write the output.</span></div>
<div class="line"><a id="l00271" name="l00271"></a><span class="lineno"> 271</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
<div class="line"><a id="l00272" name="l00272"></a><span class="lineno"> 272</span> out += out_idx * reduction_stride + column;</div>
<div class="line"><a id="l00273" name="l00273"></a><span class="lineno"> 273</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00274" name="l00274"></a><span class="lineno"> 274</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00275" name="l00275"></a><span class="lineno"> 275</span> out[i] = totals[i];</div>
<div class="line"><a id="l00276" name="l00276"></a><span class="lineno"> 276</span> }</div>
<div class="line"><a id="l00277" name="l00277"></a><span class="lineno"> 277</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00278" name="l00278"></a><span class="lineno"> 278</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i &lt; reduction_stride; i++) {</div>
<div class="line"><a id="l00279" name="l00279"></a><span class="lineno"> 279</span> out[i] = totals[i];</div>
<div class="line"><a id="l00280" name="l00280"></a><span class="lineno"> 280</span> }</div>
<div class="line"><a id="l00281" name="l00281"></a><span class="lineno"> 281</span> }</div>
<div class="line"><a id="l00282" name="l00282"></a><span class="lineno"> 282</span> </div>
<div class="line"><a id="l00283" name="l00283"></a><span class="lineno"> 283</span> <span class="comment">// Write the output.</span></div>
<div class="line"><a id="l00284" name="l00284"></a><span class="lineno"> 284</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
<div class="line"><a id="l00285" name="l00285"></a><span class="lineno"> 285</span> <span class="keywordtype">size_t</span> out_column = BN * gid.x + out_offset.x;</div>
<div class="line"><a id="l00286" name="l00286"></a><span class="lineno"> 286</span> out += out_idx * reduction_stride + out_column;</div>
<div class="line"><a id="l00287" name="l00287"></a><span class="lineno"> 287</span> <span class="keywordflow">if</span> (out_column + n_outputs &lt;= reduction_stride) {</div>
<div class="line"><a id="l00288" name="l00288"></a><span class="lineno"> 288</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_outputs; i++) {</div>
<div class="line"><a id="l00289" name="l00289"></a><span class="lineno"> 289</span> out[i] = totals[i];</div>
<div class="line"><a id="l00290" name="l00290"></a><span class="lineno"> 290</span> }</div>
<div class="line"><a id="l00291" name="l00291"></a><span class="lineno"> 291</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00292" name="l00292"></a><span class="lineno"> 292</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; out_column + i &lt; reduction_stride; i++) {</div>
<div class="line"><a id="l00293" name="l00293"></a><span class="lineno"> 293</span> out[i] = totals[i];</div>
<div class="line"><a id="l00294" name="l00294"></a><span class="lineno"> 294</span> }</div>
<div class="line"><a id="l00295" name="l00295"></a><span class="lineno"> 295</span> }</div>
<div class="line"><a id="l00296" name="l00296"></a><span class="lineno"> 296</span> }</div>
<div class="line"><a id="l00297" name="l00297"></a><span class="lineno"> 297</span> }</div>
<div class="line"><a id="l00298" name="l00298"></a><span class="lineno"> 298</span> </div>
<div class="line"><a id="l00299" name="l00299"></a><span class="lineno"> 299</span> <span class="comment">// Each thread holds n_reads partial results. We write them all out to shared</span></div>
<div class="line"><a id="l00300" name="l00300"></a><span class="lineno"> 300</span> <span class="comment">// memory and threads with offset.y == 0 aggregate the columns and write the</span></div>
<div class="line"><a id="l00301" name="l00301"></a><span class="lineno"> 301</span> <span class="comment">// outputs.</span></div>
<div class="line"><a id="l00302" name="l00302"></a><span class="lineno"> 302</span> <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00303" name="l00303"></a><span class="lineno"> 303</span> <span class="keywordtype">short</span> x_block = offset.x / n_reads;</div>
<div class="line"><a id="l00304" name="l00304"></a><span class="lineno"> 304</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00305" name="l00305"></a><span class="lineno"> 305</span> shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];</div>
<div class="line"><a id="l00306" name="l00306"></a><span class="lineno"> 306</span> }</div>
<div class="line"><a id="l00307" name="l00307"></a><span class="lineno"> 307</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00308" name="l00308"></a><span class="lineno"> 308</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
<div class="line"><a id="l00309" name="l00309"></a><span class="lineno"> 309</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00310" name="l00310"></a><span class="lineno"> 310</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> j = 1; j &lt; BM; j++) {</div>
<div class="line"><a id="l00311" name="l00311"></a><span class="lineno"> 311</span> totals[i] =</div>
<div class="line"><a id="l00312" name="l00312"></a><span class="lineno"> 312</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);</div>
<div class="line"><a id="l00313" name="l00313"></a><span class="lineno"> 313</span> }</div>
<div class="line"><a id="l00314" name="l00314"></a><span class="lineno"> 314</span> }</div>
<div class="line"><a id="l00315" name="l00315"></a><span class="lineno"> 315</span> }</div>
<div class="line"><a id="l00316" name="l00316"></a><span class="lineno"> 316</span> </div>
<div class="line"><a id="l00317" name="l00317"></a><span class="lineno"> 317</span> <span class="comment">// Write the output.</span></div>
<div class="line"><a id="l00318" name="l00318"></a><span class="lineno"> 318</span> <span class="keywordflow">if</span> (offset.y == 0) {</div>
<div class="line"><a id="l00319" name="l00319"></a><span class="lineno"> 319</span> out += out_idx * reduction_stride + column;</div>
<div class="line"><a id="l00320" name="l00320"></a><span class="lineno"> 320</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00321" name="l00321"></a><span class="lineno"> 321</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00322" name="l00322"></a><span class="lineno"> 322</span> out[i] = totals[i];</div>
<div class="line"><a id="l00323" name="l00323"></a><span class="lineno"> 323</span> }</div>
<div class="line"><a id="l00324" name="l00324"></a><span class="lineno"> 324</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00325" name="l00325"></a><span class="lineno"> 325</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; column + i &lt; reduction_stride; i++) {</div>
<div class="line"><a id="l00326" name="l00326"></a><span class="lineno"> 326</span> out[i] = totals[i];</div>
<div class="line"><a id="l00327" name="l00327"></a><span class="lineno"> 327</span> }</div>
<div class="line"><a id="l00328" name="l00328"></a><span class="lineno"> 328</span> }</div>
<div class="line"><a id="l00329" name="l00329"></a><span class="lineno"> 329</span> }</div>
<div class="line"><a id="l00330" name="l00330"></a><span class="lineno"> 330</span> }</div>
<div class="line"><a id="l00331" name="l00331"></a><span class="lineno"> 331</span>}</div>
<div class="line"><a id="l00282" name="l00282"></a><span class="lineno"> 282</span> }</div>
<div class="line"><a id="l00283" name="l00283"></a><span class="lineno"> 283</span> }</div>
<div class="line"><a id="l00284" name="l00284"></a><span class="lineno"> 284</span>}</div>
</div>
<div class="line"><a id="l00285" name="l00285"></a><span class="lineno"> 285</span> </div>
<div class="line"><a id="l00286" name="l00286"></a><span class="lineno"> 286</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> T, <span class="keyword">typename</span> U, <span class="keyword">typename</span> Op, <span class="keywordtype">int</span> NDIMS, <span class="keywordtype">int</span> BM, <span class="keywordtype">int</span> BN&gt;</div>
<div class="foldopen" id="foldopen00287" data-start="{" data-end="}">
<div class="line"><a id="l00287" name="l00287"></a><span class="lineno"><a class="line" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"> 287</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">col_reduce_2pass</a>(</div>
<div class="line"><a id="l00288" name="l00288"></a><span class="lineno"> 288</span> <span class="keyword">const</span> device T* in [[buffer(0)]],</div>
<div class="line"><a id="l00289" name="l00289"></a><span class="lineno"> 289</span> device U* out [[buffer(1)]],</div>
<div class="line"><a id="l00290" name="l00290"></a><span class="lineno"> 290</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_size [[buffer(2)]],</div>
<div class="line"><a id="l00291" name="l00291"></a><span class="lineno"> 291</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; reduction_stride [[buffer(3)]],</div>
<div class="line"><a id="l00292" name="l00292"></a><span class="lineno"> 292</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* shape [[buffer(4)]],</div>
<div class="line"><a id="l00293" name="l00293"></a><span class="lineno"> 293</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* strides [[buffer(5)]],</div>
<div class="line"><a id="l00294" name="l00294"></a><span class="lineno"> 294</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; ndim [[buffer(6)]],</div>
<div class="line"><a id="l00295" name="l00295"></a><span class="lineno"> 295</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>* reduce_shape [[buffer(7)]],</div>
<div class="line"><a id="l00296" name="l00296"></a><span class="lineno"> 296</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>* reduce_strides [[buffer(8)]],</div>
<div class="line"><a id="l00297" name="l00297"></a><span class="lineno"> 297</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; reduce_ndim [[buffer(9)]],</div>
<div class="line"><a id="l00298" name="l00298"></a><span class="lineno"> 298</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; non_col_reductions [[buffer(10)]],</div>
<div class="line"><a id="l00299" name="l00299"></a><span class="lineno"> 299</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; out_size [[buffer(11)]],</div>
<div class="line"><a id="l00300" name="l00300"></a><span class="lineno"> 300</span> uint3 gid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00301" name="l00301"></a><span class="lineno"> 301</span> uint3 gsize [[threadgroups_per_grid]],</div>
<div class="line"><a id="l00302" name="l00302"></a><span class="lineno"> 302</span> uint simd_lane_id [[thread_index_in_simdgroup]],</div>
<div class="line"><a id="l00303" name="l00303"></a><span class="lineno"> 303</span> uint simd_group_id [[simdgroup_index_in_threadgroup]]) {</div>
<div class="line"><a id="l00304" name="l00304"></a><span class="lineno"> 304</span> Op <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>;</div>
<div class="line"><a id="l00305" name="l00305"></a><span class="lineno"> 305</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_simdgroups = 8;</div>
<div class="line"><a id="l00306" name="l00306"></a><span class="lineno"> 306</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> tgp_size = n_simdgroups * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a>;</div>
<div class="line"><a id="l00307" name="l00307"></a><span class="lineno"> 307</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_reads = (BM * BN) / tgp_size;</div>
<div class="line"><a id="l00308" name="l00308"></a><span class="lineno"> 308</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> n_read_blocks = BN / n_reads;</div>
<div class="line"><a id="l00309" name="l00309"></a><span class="lineno"> 309</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> n_outputs = BN / n_simdgroups;</div>
<div class="line"><a id="l00310" name="l00310"></a><span class="lineno"> 310</span> <span class="keyword">constexpr</span> <span class="keywordtype">short</span> outer_blocks = 32;</div>
<div class="line"><a id="l00311" name="l00311"></a><span class="lineno"> 311</span> <span class="keyword">static_assert</span>(BM == 32, <span class="stringliteral">&quot;BM should be equal to 32&quot;</span>);</div>
<div class="line"><a id="l00312" name="l00312"></a><span class="lineno"> 312</span> </div>
<div class="line"><a id="l00313" name="l00313"></a><span class="lineno"> 313</span> threadgroup U shared_vals[BN * BM];</div>
<div class="line"><a id="l00314" name="l00314"></a><span class="lineno"> 314</span> U totals[n_reads];</div>
<div class="line"><a id="l00315" name="l00315"></a><span class="lineno"> 315</span> <a class="code hl_struct" href="structlooped__elem__to__loc.html">looped_elem_to_loc&lt;NDIMS&gt;</a> loop;</div>
<div class="line"><a id="l00316" name="l00316"></a><span class="lineno"> 316</span> <span class="keyword">const</span> device T* row;</div>
<div class="line"><a id="l00317" name="l00317"></a><span class="lineno"> 317</span> </div>
<div class="line"><a id="l00318" name="l00318"></a><span class="lineno"> 318</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00319" name="l00319"></a><span class="lineno"> 319</span> totals[i] = Op::init;</div>
<div class="line"><a id="l00320" name="l00320"></a><span class="lineno"> 320</span> }</div>
<div class="line"><a id="l00321" name="l00321"></a><span class="lineno"> 321</span> </div>
<div class="line"><a id="l00322" name="l00322"></a><span class="lineno"> 322</span> <span class="keywordtype">short</span> lid = simd_group_id * <a class="code hl_variable" href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a> + simd_lane_id;</div>
<div class="line"><a id="l00323" name="l00323"></a><span class="lineno"> 323</span> short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);</div>
<div class="line"><a id="l00324" name="l00324"></a><span class="lineno"> 324</span> <span class="keywordtype">size_t</span> column = BN * gid.x + offset.x;</div>
<div class="line"><a id="l00325" name="l00325"></a><span class="lineno"> 325</span> <span class="keywordtype">bool</span> safe = column + n_reads &lt;= reduction_stride;</div>
<div class="line"><a id="l00326" name="l00326"></a><span class="lineno"> 326</span> </div>
<div class="line"><a id="l00327" name="l00327"></a><span class="lineno"> 327</span> <span class="keywordtype">size_t</span> full_idx = gid.y + gsize.y * size_t(gid.z);</div>
<div class="line"><a id="l00328" name="l00328"></a><span class="lineno"> 328</span> <span class="keywordtype">size_t</span> block_idx = full_idx / out_size;</div>
<div class="line"><a id="l00329" name="l00329"></a><span class="lineno"> 329</span> <span class="keywordtype">size_t</span> out_idx = full_idx % out_size;</div>
<div class="line"><a id="l00330" name="l00330"></a><span class="lineno"> 330</span> <span class="keywordtype">size_t</span> in_idx = <a class="code hl_function" href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a>(out_idx, shape, strides, ndim);</div>
<div class="line"><a id="l00331" name="l00331"></a><span class="lineno"> 331</span> in += in_idx + column;</div>
<div class="line"><a id="l00332" name="l00332"></a><span class="lineno"> 332</span> </div>
<div class="line"><a id="l00333" name="l00333"></a><span class="lineno"> 333</span> <span class="keywordtype">size_t</span> total = non_col_reductions * reduction_size;</div>
<div class="line"><a id="l00334" name="l00334"></a><span class="lineno"> 334</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(offset.y + block_idx * BM, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00335" name="l00335"></a><span class="lineno"> 335</span> <span class="keywordflow">for</span> (<span class="keywordtype">size_t</span> r = offset.y + block_idx * BM; r &lt; total;</div>
<div class="line"><a id="l00336" name="l00336"></a><span class="lineno"> 336</span> r += outer_blocks * BM) {</div>
<div class="line"><a id="l00337" name="l00337"></a><span class="lineno"> 337</span> row = in + loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">location</a>(r, reduce_shape, reduce_strides, reduce_ndim);</div>
<div class="line"><a id="l00338" name="l00338"></a><span class="lineno"> 338</span> </div>
<div class="line"><a id="l00339" name="l00339"></a><span class="lineno"> 339</span> <span class="keywordflow">if</span> (safe) {</div>
<div class="line"><a id="l00340" name="l00340"></a><span class="lineno"> 340</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00341" name="l00341"></a><span class="lineno"> 341</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(<span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(row[i]), totals[i]);</div>
<div class="line"><a id="l00342" name="l00342"></a><span class="lineno"> 342</span> }</div>
<div class="line"><a id="l00343" name="l00343"></a><span class="lineno"> 343</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00344" name="l00344"></a><span class="lineno"> 344</span> U vals[n_reads];</div>
<div class="line"><a id="l00345" name="l00345"></a><span class="lineno"> 345</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00346" name="l00346"></a><span class="lineno"> 346</span> vals[i] =</div>
<div class="line"><a id="l00347" name="l00347"></a><span class="lineno"> 347</span> (column + i &lt; reduction_stride) ? static_cast&lt;U&gt;(row[i]) : <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.init;</div>
<div class="line"><a id="l00348" name="l00348"></a><span class="lineno"> 348</span> }</div>
<div class="line"><a id="l00349" name="l00349"></a><span class="lineno"> 349</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00350" name="l00350"></a><span class="lineno"> 350</span> totals[i] = <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>(vals[i], totals[i]);</div>
<div class="line"><a id="l00351" name="l00351"></a><span class="lineno"> 351</span> }</div>
<div class="line"><a id="l00352" name="l00352"></a><span class="lineno"> 352</span> }</div>
<div class="line"><a id="l00353" name="l00353"></a><span class="lineno"> 353</span> </div>
<div class="line"><a id="l00354" name="l00354"></a><span class="lineno"> 354</span> loop.<a class="code hl_function" href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">next</a>(outer_blocks * BM, reduce_shape, reduce_strides);</div>
<div class="line"><a id="l00355" name="l00355"></a><span class="lineno"> 355</span> }</div>
<div class="line"><a id="l00356" name="l00356"></a><span class="lineno"> 356</span> </div>
<div class="line"><a id="l00357" name="l00357"></a><span class="lineno"> 357</span> <span class="comment">// We can use a simd reduction to accumulate across BM so each thread writes</span></div>
<div class="line"><a id="l00358" name="l00358"></a><span class="lineno"> 358</span> <span class="comment">// the partial output to SM and then each simdgroup does BN / n_simdgroups</span></div>
<div class="line"><a id="l00359" name="l00359"></a><span class="lineno"> 359</span> <span class="comment">// accumulations.</span></div>
<div class="line"><a id="l00360" name="l00360"></a><span class="lineno"> 360</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_reads; i++) {</div>
<div class="line"><a id="l00361" name="l00361"></a><span class="lineno"> 361</span> shared_vals[offset.y * BN + offset.x + i] = totals[i];</div>
<div class="line"><a id="l00362" name="l00362"></a><span class="lineno"> 362</span> }</div>
<div class="line"><a id="l00363" name="l00363"></a><span class="lineno"> 363</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00364" name="l00364"></a><span class="lineno"> 364</span> short2 out_offset(simd_group_id * n_outputs, simd_lane_id);</div>
<div class="line"><a id="l00365" name="l00365"></a><span class="lineno"> 365</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_outputs; i++) {</div>
<div class="line"><a id="l00366" name="l00366"></a><span class="lineno"> 366</span> totals[i] =</div>
<div class="line"><a id="l00367" name="l00367"></a><span class="lineno"> 367</span> <a class="code hl_variable" href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a>.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);</div>
<div class="line"><a id="l00368" name="l00368"></a><span class="lineno"> 368</span> }</div>
<div class="line"><a id="l00369" name="l00369"></a><span class="lineno"> 369</span> </div>
<div class="line"><a id="l00370" name="l00370"></a><span class="lineno"> 370</span> <span class="comment">// Write the output.</span></div>
<div class="line"><a id="l00371" name="l00371"></a><span class="lineno"> 371</span> <span class="keywordflow">if</span> (simd_lane_id == 0) {</div>
<div class="line"><a id="l00372" name="l00372"></a><span class="lineno"> 372</span> <span class="keywordtype">size_t</span> out_column = BN * gid.x + out_offset.x;</div>
<div class="line"><a id="l00373" name="l00373"></a><span class="lineno"> 373</span> out += full_idx * reduction_stride + out_column;</div>
<div class="line"><a id="l00374" name="l00374"></a><span class="lineno"> 374</span> <span class="keywordflow">if</span> (out_column + n_outputs &lt;= reduction_stride) {</div>
<div class="line"><a id="l00375" name="l00375"></a><span class="lineno"> 375</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; n_outputs; i++) {</div>
<div class="line"><a id="l00376" name="l00376"></a><span class="lineno"> 376</span> out[i] = totals[i];</div>
<div class="line"><a id="l00377" name="l00377"></a><span class="lineno"> 377</span> }</div>
<div class="line"><a id="l00378" name="l00378"></a><span class="lineno"> 378</span> } <span class="keywordflow">else</span> {</div>
<div class="line"><a id="l00379" name="l00379"></a><span class="lineno"> 379</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; out_column + i &lt; reduction_stride; i++) {</div>
<div class="line"><a id="l00380" name="l00380"></a><span class="lineno"> 380</span> out[i] = totals[i];</div>
<div class="line"><a id="l00381" name="l00381"></a><span class="lineno"> 381</span> }</div>
<div class="line"><a id="l00382" name="l00382"></a><span class="lineno"> 382</span> }</div>
<div class="line"><a id="l00383" name="l00383"></a><span class="lineno"> 383</span> }</div>
<div class="line"><a id="l00384" name="l00384"></a><span class="lineno"> 384</span>}</div>
</div>
<div class="ttc" id="abackend_2metal_2kernels_2reduction_2ops_8h_html_a515b75d563a93d3c09ee677948dc83e3"><div class="ttname"><a href="backend_2metal_2kernels_2reduction_2ops_8h.html#a515b75d563a93d3c09ee677948dc83e3">simd_size</a></div><div class="ttdeci">static constant constexpr const uint8_t simd_size</div><div class="ttdef"><b>Definition</b> ops.h:22</div></div>
<div class="ttc" id="abackend_2metal_2kernels_2utils_8h_html_a8fd0c8fc6058e650fc99bca8b6acd7d1"><div class="ttname"><a href="backend_2metal_2kernels_2utils_8h.html#a8fd0c8fc6058e650fc99bca8b6acd7d1">elem_to_loc</a></div><div class="ttdeci">METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)</div><div class="ttdef"><b>Definition</b> utils.h:87</div></div>
<div class="ttc" id="acommon_2binary_8h_html_a70228731d29946574b238d21fb4b360c"><div class="ttname"><a href="common_2binary_8h.html#a70228731d29946574b238d21fb4b360c">op</a></div><div class="ttdeci">Op op</div><div class="ttdef"><b>Definition</b> binary.h:129</div></div>
<div class="ttc" id="adefines_8h_html_a2ad505864a2ab786147766900bc18c21"><div class="ttname"><a href="defines_8h.html#a2ad505864a2ab786147766900bc18c21">REDUCE_N_READS</a></div><div class="ttdeci">static constexpr int REDUCE_N_READS</div><div class="ttdef"><b>Definition</b> defines.h:12</div></div>
<div class="ttc" id="areduce__col_8h_html_a11bfc6112ae2386ac03f5ea7b7d93385"><div class="ttname"><a href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a></div><div class="ttdeci">void col_reduce_looped(const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</div><div class="ttdoc">Our approach is the following simple looped approach:</div><div class="ttdef"><b>Definition</b> reduce_col.h:202</div></div>
<div class="ttc" id="areduce__col_8h_html_adf7aeb18cd1d5042cf6d9b46b582d8ce"><div class="ttname"><a href="reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce">col_reduce_small</a></div><div class="ttdeci">void col_reduce_small(const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id, uint3 tid, uint3 tsize)</div><div class="ttdef"><b>Definition</b> reduce_col.h:9</div></div>
<div class="ttc" id="areduce__col_8h_html_a0e92fc74eeaa8ee2ceb83bafc6eb1d7d"><div class="ttname"><a href="reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d">col_reduce_2pass</a></div><div class="ttdeci">void col_reduce_2pass(const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, const constant size_t &amp;out_size, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</div><div class="ttdef"><b>Definition</b> reduce_col.h:287</div></div>
<div class="ttc" id="areduce__col_8h_html_a11bfc6112ae2386ac03f5ea7b7d93385"><div class="ttname"><a href="reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385">col_reduce_looped</a></div><div class="ttdeci">void col_reduce_looped(const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint simd_lane_id, uint simd_group_id)</div><div class="ttdoc">Our approach is the following simple looped approach:</div><div class="ttdef"><b>Definition</b> reduce_col.h:155</div></div>
<div class="ttc" id="areduce__col_8h_html_a5b4f4c4c247ad341ff8d31dcbbbce0eb"><div class="ttname"><a href="reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb">col_reduce_longcolumn</a></div><div class="ttdeci">void col_reduce_longcolumn(const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, const constant size_t &amp;out_size, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</div><div class="ttdef"><b>Definition</b> reduce_col.h:97</div></div>
<div class="ttc" id="areduce__col_8h_html_a7c378443a2b6f4d9210db8a21a9ac4f5"><div class="ttname"><a href="reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5">col_reduce_small</a></div><div class="ttdeci">void col_reduce_small(const device T *in, device U *out, const constant size_t &amp;reduction_size, const constant size_t &amp;reduction_stride, const constant int *shape, const constant size_t *strides, const constant int &amp;ndim, const constant int *reduce_shape, const constant size_t *reduce_strides, const constant int &amp;reduce_ndim, const constant size_t &amp;non_col_reductions, uint3 gid, uint3 gsize, uint3 lid, uint3 lsize)</div><div class="ttdef"><b>Definition</b> reduce_col.h:4</div></div>
<div class="ttc" id="astructlooped__elem__to__loc_html"><div class="ttname"><a href="structlooped__elem__to__loc.html">looped_elem_to_loc</a></div><div class="ttdef"><b>Definition</b> utils.h:197</div></div>
<div class="ttc" id="astructlooped__elem__to__loc_html_a05558dabba889ee0d80ed4b567d901ca"><div class="ttname"><a href="structlooped__elem__to__loc.html#a05558dabba889ee0d80ed4b567d901ca">looped_elem_to_loc::next</a></div><div class="ttdeci">void next(const constant int *shape, const constant size_t *strides)</div><div class="ttdef"><b>Definition</b> utils.h:202</div></div>
<div class="ttc" id="astructlooped__elem__to__loc_html_accc6d4957a8aeb38f5062754793b74d2"><div class="ttname"><a href="structlooped__elem__to__loc.html#accc6d4957a8aeb38f5062754793b74d2">looped_elem_to_loc::location</a></div><div class="ttdeci">offset_t location(offset_t, const constant int *, const constant size_t *, int)</div><div class="ttdef"><b>Definition</b> utils.h:229</div></div>
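The tail of the column-reduction listing above reads `n_reads` adjacent columns per thread. When a block of columns would run past `reduction_stride`, it substitutes `op.init` for the out-of-range reads and writes back only the in-range outputs. A minimal CPU sketch of that guarded-read pattern follows. It assumes a plain sum over a row-major matrix; the function name `col_sum_reference` and its parameters are hypothetical and are not part of MLX.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference, not the MLX kernel: sums every column of a
// row-major (rows x cols) matrix, handling n_reads adjacent columns at a time.
template <std::size_t n_reads = 4>
std::vector<float> col_sum_reference(
    const std::vector<float>& in, std::size_t rows, std::size_t cols) {
  assert(in.size() == rows * cols);
  const float init = 0.0f; // identity of the sum; plays the role of op.init
  std::vector<float> out(cols, init);

  for (std::size_t column = 0; column < cols; column += n_reads) {
    // "safe" means the whole block of n_reads columns lies inside the matrix.
    const bool safe = column + n_reads <= cols;

    float totals[n_reads];
    for (std::size_t i = 0; i < n_reads; i++) {
      totals[i] = init;
    }

    for (std::size_t r = 0; r < rows; r++) {
      const float* row = in.data() + r * cols + column;
      for (std::size_t i = 0; i < n_reads; i++) {
        // Out-of-range lanes contribute the identity instead of reading memory.
        const float v = (safe || column + i < cols) ? row[i] : init;
        totals[i] += v;
      }
    }

    // Write back only the columns that actually exist.
    for (std::size_t i = 0; i < n_reads && column + i < cols; i++) {
      out[column + i] = totals[i];
    }
  }
  return out;
}
```

Padding with the identity keeps the accumulation loop free of special cases, so only the read and the final write need a bounds check.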


@ -99,13 +99,13 @@ $(function(){ initResizable(false); });
<table class="memberdecls">
<tr class="heading"><td colspan="2"><h2 class="groupheader"><a id="func-members" name="func-members"></a>
Functions</h2></td></tr>
<tr class="memitem:a6f0d7918430064bab910bdaa6c64e927" id="r_a6f0d7918430064bab910bdaa6c64e927"><td class="memTemplParams" colspan="2">template&lt;typename T , int D&gt; </td></tr>
<tr class="memitem:a6f0d7918430064bab910bdaa6c64e927"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_stride, const constant float &amp;scale, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a6f0d7918430064bab910bdaa6c64e927"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a4bf36f16e16c1c62d9b243573568e5ae" id="r_a4bf36f16e16c1c62d9b243573568e5ae"><td class="memTemplParams" colspan="2">template&lt;typename T , int D&gt; </td></tr>
<tr class="memitem:a4bf36f16e16c1c62d9b243573568e5ae"><td class="memTemplItemLeft" align="right" valign="top">void&#160;</td><td class="memTemplItemRight" valign="bottom"><a class="el" href="#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector</a> (const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_stride, const constant size_t &amp;v_stride, const constant float &amp;scale, uint3 tid, uint simd_gid, uint simd_lid)</td></tr>
<tr class="separator:a4bf36f16e16c1c62d9b243573568e5ae"><td class="memSeparator" colspan="2">&#160;</td></tr>
</table>
<h2 class="groupheader">Function Documentation</h2>
<a id="a6f0d7918430064bab910bdaa6c64e927" name="a6f0d7918430064bab910bdaa6c64e927"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a6f0d7918430064bab910bdaa6c64e927">&#9670;&#160;</a></span>sdpa_vector()</h2>
<a id="a4bf36f16e16c1c62d9b243573568e5ae" name="a4bf36f16e16c1c62d9b243573568e5ae"></a>
<h2 class="memtitle"><span class="permalink"><a href="#a4bf36f16e16c1c62d9b243573568e5ae">&#9670;&#160;</a></span>sdpa_vector()</h2>
<div class="memitem">
<div class="memproto">
@ -147,6 +147,11 @@ template&lt;typename T , int D&gt; </div>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>k_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
<td class="paramtype">const constant size_t &amp;</td> <td class="paramname"><span class="paramname"><em>v_stride</em></span>, </td>
</tr>
<tr>
<td class="paramkey"></td>
<td></td>
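The `sdpa_vector` kernel in this diff keeps a running maximum and a running sum of exponentials while it walks over the keys, rescaling the output accumulator whenever the maximum changes. A single-threaded C++ sketch of that online softmax update is given here for orientation. It assumes one query, contiguous keys and values of shape `N x D`, and no simdgroup splitting; the name `attend_online` and its signature are hypothetical, not the library's API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU reference: attention output for one query over N keys and
// values of dimension D, using the running max / running sum update.
// Assumes N >= 1 and finite scores.
std::vector<float> attend_online(
    const std::vector<float>& q,
    const std::vector<float>& keys,
    const std::vector<float>& values,
    std::size_t N,
    std::size_t D,
    float scale) {
  std::vector<float> o(D, 0.0f);
  float max_score = -std::numeric_limits<float>::infinity();
  float sum_exp_score = 0.0f;

  for (std::size_t n = 0; n < N; n++) {
    // Scaled dot product of the query with the n-th key.
    float score = 0.0f;
    for (std::size_t d = 0; d < D; d++) {
      score += scale * q[d] * keys[n * D + d];
    }

    // Rescale what has been accumulated so far to the new maximum.
    const float new_max = std::max(max_score, score);
    const float factor = std::exp(max_score - new_max);
    const float exp_score = std::exp(score - new_max);

    max_score = new_max;
    sum_exp_score = sum_exp_score * factor + exp_score;
    for (std::size_t d = 0; d < D; d++) {
      o[d] = o[d] * factor + exp_score * values[n * D + d];
    }
  }

  // Normalize once at the end.
  for (std::size_t d = 0; d < D; d++) {
    o[d] /= sum_exp_score;
  }
  return o;
}
```

On the first iteration `max_score` is negative infinity, so `factor` evaluates to zero and the empty accumulators are simply overwritten. The kernel additionally merges the per-simdgroup partial results through threadgroup memory, which this sketch leaves out.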


@ -99,7 +99,7 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00006" name="l00006"></a><span class="lineno"> 6</span> </div>
<div class="line"><a id="l00007" name="l00007"></a><span class="lineno"> 7</span><span class="keyword">template</span> &lt;<span class="keyword">typename</span> T, <span class="keywordtype">int</span> D&gt;</div>
<div class="foldopen" id="foldopen00008" data-start="{" data-end="}">
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"><a class="line" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927"> 8</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector</a>(</div>
<div class="line"><a id="l00008" name="l00008"></a><span class="lineno"><a class="line" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae"> 8</a></span>[[kernel]] <span class="keywordtype">void</span> <a class="code hl_function" href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector</a>(</div>
<div class="line"><a id="l00009" name="l00009"></a><span class="lineno"> 9</span> <span class="keyword">const</span> device T* queries [[buffer(0)]],</div>
<div class="line"><a id="l00010" name="l00010"></a><span class="lineno"> 10</span> <span class="keyword">const</span> device T* keys [[buffer(1)]],</div>
<div class="line"><a id="l00011" name="l00011"></a><span class="lineno"> 11</span> <span class="keyword">const</span> device T* values [[buffer(2)]],</div>
@ -107,113 +107,114 @@ $(function(){ initResizable(false); });
<div class="line"><a id="l00013" name="l00013"></a><span class="lineno"> 13</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; gqa_factor,</div>
<div class="line"><a id="l00014" name="l00014"></a><span class="lineno"> 14</span> <span class="keyword">const</span> constant <span class="keywordtype">int</span>&amp; N,</div>
<div class="line"><a id="l00015" name="l00015"></a><span class="lineno"> 15</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; k_stride,</div>
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> constant <span class="keywordtype">float</span>&amp; scale,</div>
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> uint3 tid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> uint simd_gid [[simdgroup_index_in_threadgroup]],</div>
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> uint simd_lid [[thread_index_in_simdgroup]]) {</div>
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BN = 32;</div>
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BD = 32;</div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> elem_per_thread = D / BD;</div>
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> </div>
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> <span class="keyword">const</span> <span class="keywordtype">int</span> stride = BN * D;</div>
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> </div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> <span class="keyword">typedef</span> <span class="keywordtype">float</span> U;</div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> </div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> thread U q[elem_per_thread];</div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> thread U k[elem_per_thread];</div>
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> thread U o[elem_per_thread];</div>
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> </div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> threadgroup U outputs[BN * BD];</div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> threadgroup U max_scores[BN];</div>
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> threadgroup U sum_exp_scores[BN];</div>
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> </div>
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> <span class="comment">// Adjust positions</span></div>
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="keyword">const</span> <span class="keywordtype">int</span> head_idx = tid.y;</div>
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keyword">const</span> <span class="keywordtype">int</span> kv_head_idx = head_idx / gqa_factor;</div>
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> queries += head_idx * D + simd_lid * elem_per_thread;</div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> values += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> out += head_idx * D + simd_gid * elem_per_thread;</div>
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> </div>
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> <span class="comment">// Read the query and 0 the output accumulator</span></div>
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> q[i] = <span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(scale) * queries[i];</div>
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> }</div>
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> o[i] = 0;</div>
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> }</div>
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> </div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> U max_score = -INFINITY;</div>
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> U sum_exp_score = 0;</div>
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> </div>
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> <span class="comment">// For each key</span></div>
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = simd_gid; i &lt; N; i += BN) {</div>
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="comment">// Read the key</span></div>
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> k[i] = keys[i];</div>
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> }</div>
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> </div>
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> <span class="comment">// Compute the i-th score</span></div>
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> U score = 0;</div>
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> score += q[i] * k[i];</div>
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> }</div>
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(score);</div>
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> </div>
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> <span class="comment">// Update the accumulators</span></div>
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a853c80479ab2264d9c4587c7bcac767b">max</a>(max_score, score);</div>
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> U exp_score = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(score - new_max);</div>
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> </div>
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> max_score = new_max;</div>
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> sum_exp_score = sum_exp_score * factor + exp_score;</div>
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> </div>
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> <span class="comment">// Update the output accumulator</span></div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> o[i] = o[i] * factor + exp_score * values[i];</div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> }</div>
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> </div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> <span class="comment">// Move the pointers to the next kv</span></div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> keys += stride;</div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> values += stride;</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> }</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> </div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> <span class="comment">// Each thread has a partial part of the output so we need to combine them.</span></div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> </div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> <span class="comment">// First let&#39;s communicate the max and sum_exp</span></div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> max_scores[simd_gid] = max_score;</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> sum_exp_scores[simd_gid] = sum_exp_score;</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> }</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> max_score = max_scores[simd_lid];</div>
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a048cad0aca52cb737ebf103e76bd1c49">simd_max</a>(max_score);</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> sum_exp_score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(sum_exp_scores[simd_lid] * factor);</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> </div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> <span class="comment">// Now we need to aggregate all the outputs</span></div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> outputs[simd_lid * BD + simd_gid] = o[i];</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> o[i] = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> }</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> </div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> <span class="comment">// And write the output</span></div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> out[i] = <span class="keyword">static_cast&lt;</span>T<span class="keyword">&gt;</span>(o[i]);</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> }</div>
<div class="line"><a id="l00016" name="l00016"></a><span class="lineno"> 16</span> <span class="keyword">const</span> constant <span class="keywordtype">size_t</span>&amp; v_stride,</div>
<div class="line"><a id="l00017" name="l00017"></a><span class="lineno"> 17</span> <span class="keyword">const</span> constant <span class="keywordtype">float</span>&amp; scale,</div>
<div class="line"><a id="l00018" name="l00018"></a><span class="lineno"> 18</span> uint3 tid [[threadgroup_position_in_grid]],</div>
<div class="line"><a id="l00019" name="l00019"></a><span class="lineno"> 19</span> uint simd_gid [[simdgroup_index_in_threadgroup]],</div>
<div class="line"><a id="l00020" name="l00020"></a><span class="lineno"> 20</span> uint simd_lid [[thread_index_in_simdgroup]]) {</div>
<div class="line"><a id="l00021" name="l00021"></a><span class="lineno"> 21</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BN = 32;</div>
<div class="line"><a id="l00022" name="l00022"></a><span class="lineno"> 22</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> BD = 32;</div>
<div class="line"><a id="l00023" name="l00023"></a><span class="lineno"> 23</span> <span class="keyword">constexpr</span> <span class="keywordtype">int</span> elem_per_thread = D / BD;</div>
<div class="line"><a id="l00024" name="l00024"></a><span class="lineno"> 24</span> </div>
<div class="line"><a id="l00025" name="l00025"></a><span class="lineno"> 25</span> <span class="keyword">const</span> <span class="keywordtype">int</span> stride = BN * D;</div>
<div class="line"><a id="l00026" name="l00026"></a><span class="lineno"> 26</span> </div>
<div class="line"><a id="l00027" name="l00027"></a><span class="lineno"> 27</span> <span class="keyword">typedef</span> <span class="keywordtype">float</span> U;</div>
<div class="line"><a id="l00028" name="l00028"></a><span class="lineno"> 28</span> </div>
<div class="line"><a id="l00029" name="l00029"></a><span class="lineno"> 29</span> thread U q[elem_per_thread];</div>
<div class="line"><a id="l00030" name="l00030"></a><span class="lineno"> 30</span> thread U k[elem_per_thread];</div>
<div class="line"><a id="l00031" name="l00031"></a><span class="lineno"> 31</span> thread U o[elem_per_thread];</div>
<div class="line"><a id="l00032" name="l00032"></a><span class="lineno"> 32</span> </div>
<div class="line"><a id="l00033" name="l00033"></a><span class="lineno"> 33</span> threadgroup U outputs[BN * BD];</div>
<div class="line"><a id="l00034" name="l00034"></a><span class="lineno"> 34</span> threadgroup U max_scores[BN];</div>
<div class="line"><a id="l00035" name="l00035"></a><span class="lineno"> 35</span> threadgroup U sum_exp_scores[BN];</div>
<div class="line"><a id="l00036" name="l00036"></a><span class="lineno"> 36</span> </div>
<div class="line"><a id="l00037" name="l00037"></a><span class="lineno"> 37</span> <span class="comment">// Adjust positions</span></div>
<div class="line"><a id="l00038" name="l00038"></a><span class="lineno"> 38</span> <span class="keyword">const</span> <span class="keywordtype">int</span> head_idx = tid.y;</div>
<div class="line"><a id="l00039" name="l00039"></a><span class="lineno"> 39</span> <span class="keyword">const</span> <span class="keywordtype">int</span> kv_head_idx = head_idx / gqa_factor;</div>
<div class="line"><a id="l00040" name="l00040"></a><span class="lineno"> 40</span> queries += head_idx * D + simd_lid * elem_per_thread;</div>
<div class="line"><a id="l00041" name="l00041"></a><span class="lineno"> 41</span> keys += kv_head_idx * k_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
<div class="line"><a id="l00042" name="l00042"></a><span class="lineno"> 42</span> values += kv_head_idx * v_stride + simd_gid * D + simd_lid * elem_per_thread;</div>
<div class="line"><a id="l00043" name="l00043"></a><span class="lineno"> 43</span> out += head_idx * D + simd_gid * elem_per_thread;</div>
<div class="line"><a id="l00044" name="l00044"></a><span class="lineno"> 44</span> </div>
<div class="line"><a id="l00045" name="l00045"></a><span class="lineno"> 45</span> <span class="comment">// Read the query and 0 the output accumulator</span></div>
<div class="line"><a id="l00046" name="l00046"></a><span class="lineno"> 46</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00047" name="l00047"></a><span class="lineno"> 47</span> q[i] = <span class="keyword">static_cast&lt;</span>U<span class="keyword">&gt;</span>(scale) * queries[i];</div>
<div class="line"><a id="l00048" name="l00048"></a><span class="lineno"> 48</span> }</div>
<div class="line"><a id="l00049" name="l00049"></a><span class="lineno"> 49</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00050" name="l00050"></a><span class="lineno"> 50</span> o[i] = 0;</div>
<div class="line"><a id="l00051" name="l00051"></a><span class="lineno"> 51</span> }</div>
<div class="line"><a id="l00052" name="l00052"></a><span class="lineno"> 52</span> </div>
<div class="line"><a id="l00053" name="l00053"></a><span class="lineno"> 53</span> U max_score = -INFINITY;</div>
<div class="line"><a id="l00054" name="l00054"></a><span class="lineno"> 54</span> U sum_exp_score = 0;</div>
<div class="line"><a id="l00055" name="l00055"></a><span class="lineno"> 55</span> </div>
<div class="line"><a id="l00056" name="l00056"></a><span class="lineno"> 56</span> <span class="comment">// For each key</span></div>
<div class="line"><a id="l00057" name="l00057"></a><span class="lineno"> 57</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = simd_gid; i &lt; N; i += BN) {</div>
<div class="line"><a id="l00058" name="l00058"></a><span class="lineno"> 58</span> <span class="comment">// Read the key</span></div>
<div class="line"><a id="l00059" name="l00059"></a><span class="lineno"> 59</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00060" name="l00060"></a><span class="lineno"> 60</span> k[i] = keys[i];</div>
<div class="line"><a id="l00061" name="l00061"></a><span class="lineno"> 61</span> }</div>
<div class="line"><a id="l00062" name="l00062"></a><span class="lineno"> 62</span> </div>
<div class="line"><a id="l00063" name="l00063"></a><span class="lineno"> 63</span> <span class="comment">// Compute the i-th score</span></div>
<div class="line"><a id="l00064" name="l00064"></a><span class="lineno"> 64</span> U score = 0;</div>
<div class="line"><a id="l00065" name="l00065"></a><span class="lineno"> 65</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00066" name="l00066"></a><span class="lineno"> 66</span> score += q[i] * k[i];</div>
<div class="line"><a id="l00067" name="l00067"></a><span class="lineno"> 67</span> }</div>
<div class="line"><a id="l00068" name="l00068"></a><span class="lineno"> 68</span> score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(score);</div>
<div class="line"><a id="l00069" name="l00069"></a><span class="lineno"> 69</span> </div>
<div class="line"><a id="l00070" name="l00070"></a><span class="lineno"> 70</span> <span class="comment">// Update the accumulators</span></div>
<div class="line"><a id="l00071" name="l00071"></a><span class="lineno"> 71</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a853c80479ab2264d9c4587c7bcac767b">max</a>(max_score, score);</div>
<div class="line"><a id="l00072" name="l00072"></a><span class="lineno"> 72</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
<div class="line"><a id="l00073" name="l00073"></a><span class="lineno"> 73</span> U exp_score = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(score - new_max);</div>
<div class="line"><a id="l00074" name="l00074"></a><span class="lineno"> 74</span> </div>
<div class="line"><a id="l00075" name="l00075"></a><span class="lineno"> 75</span> max_score = new_max;</div>
<div class="line"><a id="l00076" name="l00076"></a><span class="lineno"> 76</span> sum_exp_score = sum_exp_score * factor + exp_score;</div>
<div class="line"><a id="l00077" name="l00077"></a><span class="lineno"> 77</span> </div>
<div class="line"><a id="l00078" name="l00078"></a><span class="lineno"> 78</span> <span class="comment">// Update the output accumulator</span></div>
<div class="line"><a id="l00079" name="l00079"></a><span class="lineno"> 79</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00080" name="l00080"></a><span class="lineno"> 80</span> o[i] = o[i] * factor + exp_score * values[i];</div>
<div class="line"><a id="l00081" name="l00081"></a><span class="lineno"> 81</span> }</div>
<div class="line"><a id="l00082" name="l00082"></a><span class="lineno"> 82</span> </div>
<div class="line"><a id="l00083" name="l00083"></a><span class="lineno"> 83</span> <span class="comment">// Move the pointers to the next kv</span></div>
<div class="line"><a id="l00084" name="l00084"></a><span class="lineno"> 84</span> keys += stride;</div>
<div class="line"><a id="l00085" name="l00085"></a><span class="lineno"> 85</span> values += stride;</div>
<div class="line"><a id="l00086" name="l00086"></a><span class="lineno"> 86</span> }</div>
<div class="line"><a id="l00087" name="l00087"></a><span class="lineno"> 87</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00088" name="l00088"></a><span class="lineno"> 88</span> </div>
<div class="line"><a id="l00089" name="l00089"></a><span class="lineno"> 89</span> <span class="comment">// Each thread has a partial part of the output so we need to combine them.</span></div>
<div class="line"><a id="l00090" name="l00090"></a><span class="lineno"> 90</span> </div>
<div class="line"><a id="l00091" name="l00091"></a><span class="lineno"> 91</span> <span class="comment">// First let&#39;s communicate the max and sum_exp</span></div>
<div class="line"><a id="l00092" name="l00092"></a><span class="lineno"> 92</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
<div class="line"><a id="l00093" name="l00093"></a><span class="lineno"> 93</span> max_scores[simd_gid] = max_score;</div>
<div class="line"><a id="l00094" name="l00094"></a><span class="lineno"> 94</span> sum_exp_scores[simd_gid] = sum_exp_score;</div>
<div class="line"><a id="l00095" name="l00095"></a><span class="lineno"> 95</span> }</div>
<div class="line"><a id="l00096" name="l00096"></a><span class="lineno"> 96</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00097" name="l00097"></a><span class="lineno"> 97</span> max_score = max_scores[simd_lid];</div>
<div class="line"><a id="l00098" name="l00098"></a><span class="lineno"> 98</span> U new_max = <a class="code hl_function" href="namespacemetal.html#a048cad0aca52cb737ebf103e76bd1c49">simd_max</a>(max_score);</div>
<div class="line"><a id="l00099" name="l00099"></a><span class="lineno"> 99</span> U factor = <a class="code hl_function" href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">fast::exp</a>(max_score - new_max);</div>
<div class="line"><a id="l00100" name="l00100"></a><span class="lineno"> 100</span> sum_exp_score = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(sum_exp_scores[simd_lid] * factor);</div>
<div class="line"><a id="l00101" name="l00101"></a><span class="lineno"> 101</span> </div>
<div class="line"><a id="l00102" name="l00102"></a><span class="lineno"> 102</span> <span class="comment">// Now we need to aggregate all the outputs</span></div>
<div class="line"><a id="l00103" name="l00103"></a><span class="lineno"> 103</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00104" name="l00104"></a><span class="lineno"> 104</span> outputs[simd_lid * BD + simd_gid] = o[i];</div>
<div class="line"><a id="l00105" name="l00105"></a><span class="lineno"> 105</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00106" name="l00106"></a><span class="lineno"> 106</span> o[i] = <a class="code hl_function" href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">simd_sum</a>(outputs[simd_gid * BD + simd_lid] * factor) / sum_exp_score;</div>
<div class="line"><a id="l00107" name="l00107"></a><span class="lineno"> 107</span> threadgroup_barrier(mem_flags::mem_threadgroup);</div>
<div class="line"><a id="l00108" name="l00108"></a><span class="lineno"> 108</span> }</div>
<div class="line"><a id="l00109" name="l00109"></a><span class="lineno"> 109</span> </div>
<div class="line"><a id="l00110" name="l00110"></a><span class="lineno"> 110</span> <span class="comment">// And write the output</span></div>
<div class="line"><a id="l00111" name="l00111"></a><span class="lineno"> 111</span> <span class="keywordflow">if</span> (simd_lid == 0) {</div>
<div class="line"><a id="l00112" name="l00112"></a><span class="lineno"> 112</span> <span class="keywordflow">for</span> (<span class="keywordtype">int</span> i = 0; i &lt; elem_per_thread; i++) {</div>
<div class="line"><a id="l00113" name="l00113"></a><span class="lineno"> 113</span> out[i] = <span class="keyword">static_cast&lt;</span>T<span class="keyword">&gt;</span>(o[i]);</div>
<div class="line"><a id="l00114" name="l00114"></a><span class="lineno"> 114</span> }</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span>}</div>
<div class="line"><a id="l00115" name="l00115"></a><span class="lineno"> 115</span> }</div>
<div class="line"><a id="l00116" name="l00116"></a><span class="lineno"> 116</span>}</div>
</div>
<div class="ttc" id="anamespacemetal_1_1fast_html_ad3dbd387b63373c29e3449609f763ede"><div class="ttname"><a href="namespacemetal_1_1fast.html#ad3dbd387b63373c29e3449609f763ede">metal::fast::exp</a></div><div class="ttdeci">METAL_FUNC bfloat16_t exp(bfloat16_t x)</div><div class="ttdef"><b>Definition</b> bf16_math.h:242</div></div>
<div class="ttc" id="anamespacemetal_html"><div class="ttname"><a href="namespacemetal.html">metal</a></div><div class="ttdef"><b>Definition</b> bf16.h:265</div></div>
<div class="ttc" id="anamespacemetal_html_a048cad0aca52cb737ebf103e76bd1c49"><div class="ttname"><a href="namespacemetal.html#a048cad0aca52cb737ebf103e76bd1c49">metal::simd_max</a></div><div class="ttdeci">METAL_FUNC bfloat16_t simd_max(bfloat16_t data)</div><div class="ttdef"><b>Definition</b> bf16_math.h:392</div></div>
<div class="ttc" id="anamespacemetal_html_a85181e37a00cb4a4217f1bb25389bce5"><div class="ttname"><a href="namespacemetal.html#a85181e37a00cb4a4217f1bb25389bce5">metal::simd_sum</a></div><div class="ttdeci">METAL_FUNC bfloat16_t simd_sum(bfloat16_t data)</div><div class="ttdef"><b>Definition</b> bf16_math.h:392</div></div>
<div class="ttc" id="anamespacemetal_html_a853c80479ab2264d9c4587c7bcac767b"><div class="ttname"><a href="namespacemetal.html#a853c80479ab2264d9c4587c7bcac767b">metal::max</a></div><div class="ttdeci">METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)</div><div class="ttdef"><b>Definition</b> bf16_math.h:234</div></div>
<div class="ttc" id="asdpa__vector_8h_html_a6f0d7918430064bab910bdaa6c64e927"><div class="ttname"><a href="sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927">sdpa_vector</a></div><div class="ttdeci">void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_stride, const constant float &amp;scale, uint3 tid, uint simd_gid, uint simd_lid)</div><div class="ttdef"><b>Definition</b> sdpa_vector.h:8</div></div>
<div class="ttc" id="asdpa__vector_8h_html_a4bf36f16e16c1c62d9b243573568e5ae"><div class="ttname"><a href="sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae">sdpa_vector</a></div><div class="ttdeci">void sdpa_vector(const device T *queries, const device T *keys, const device T *values, device T *out, const constant int &amp;gqa_factor, const constant int &amp;N, const constant size_t &amp;k_stride, const constant size_t &amp;v_stride, const constant float &amp;scale, uint3 tid, uint simd_gid, uint simd_lid)</div><div class="ttdef"><b>Definition</b> sdpa_vector.h:8</div></div>
</div><!-- fragment --></div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>


@ -38,7 +38,7 @@ var searchData=
['all_35',['all',['../group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68',1,'mlx::core::all(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga3689e12e8f42dadb4cbe2b07dc4099f4',1,'mlx::core::all(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#gac0919c6ba53aea35a7683dea7e9a9a59',1,'mlx::core::all(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#gae2d5fcc5b62d673cca76c08b7b4afbbc',1,'mlx::core::all(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['all_5fgather_36',['all_gather',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04',1,'mlx::core::distributed::detail::all_gather()'],['../namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7',1,'mlx::core::distributed::all_gather()']]],
['all_5freduce_37',['all_reduce',['../reduce__all_8h.html#a99ef48ae72b3e715c5f4d7ea07cd213d',1,'reduce_all.h']]],
['all_5freduce_5fdispatch_38',['all_reduce_dispatch',['../namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8',1,'mlx::core']]],
['all_5freduce_5fdispatch_38',['all_reduce_dispatch',['../namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098',1,'mlx::core']]],
['all_5fsum_39',['all_sum',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960',1,'mlx::core::distributed::detail::all_sum()'],['../namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980',1,'mlx::core::distributed::all_sum()']]],
['allclose_40',['allclose',['../group__ops.html#gaf0cd4257de7542daf9faf5e605e31020',1,'mlx::core']]],
['allgather_41',['AllGather',['../classmlx_1_1core_1_1distributed_1_1_all_gather.html',1,'mlx::core::distributed::AllGather'],['../classmlx_1_1core_1_1distributed_1_1_all_gather.html#af4b10a5b61f160fb64353057c185b661',1,'mlx::core::distributed::AllGather::AllGather()']]],


@ -27,5 +27,6 @@ var searchData=
['queue_24',['queue',['../structmlx_1_1core_1_1metal_1_1_device_stream.html#a77c75a63c51ea56815a86bd882ed190d',1,'mlx::core::metal::DeviceStream']]],
['quiet_5fnan_25',['quiet_NaN',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aebeb07c01984be246bc2d1b8f8e4ac7b',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;']]],
['qvm_26',['qvm',['../quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5',1,'quantized.h']]],
['qvm_5fimpl_27',['qvm_impl',['../quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12',1,'quantized.h']]]
['qvm_5fimpl_27',['qvm_impl',['../quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a',1,'quantized.h']]],
['qvm_5fsplit_5fk_28',['qvm_split_k',['../quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8',1,'quantized.h']]]
];

View File

@ -28,7 +28,7 @@ var searchData=
['scheduler_25',['Scheduler',['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html',1,'mlx::core::scheduler::Scheduler'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a3ae42aed78a2200e9d02776fcd2316ba',1,'mlx::core::scheduler::Scheduler::Scheduler()'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a61a74e3628899e66dde600e24a750648',1,'mlx::core::scheduler::Scheduler::Scheduler(const Scheduler &amp;)=delete'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#ac3f77b7c93220dadd0b3bb2e903b7059',1,'mlx::core::scheduler::Scheduler::Scheduler(Scheduler &amp;&amp;)=delete']]],
['scheduler_26',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html#ae856e468c2f7c8f8ec672522cc13730b',1,'mlx::core::scheduler']]],
['scheduler_2eh_27',['scheduler.h',['../scheduler_8h.html',1,'']]],
['sdpa_5fvector_28',['sdpa_vector',['../sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927',1,'sdpa_vector.h']]],
['sdpa_5fvector_28',['sdpa_vector',['../sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae',1,'sdpa_vector.h']]],
['sdpa_5fvector_2eh_29',['sdpa_vector.h',['../sdpa__vector_8h.html',1,'']]],
['seed_30',['seed',['../classmlx_1_1core_1_1random_1_1_key_sequence.html#a9f19c5da2031cba50d0ff996924347d8',1,'mlx::core::random::KeySequence::seed()'],['../namespacemlx_1_1core_1_1random.html#ac4ad325b613257306df74595d3d0e23b',1,'mlx::core::random::seed()']]],
['seek_31',['seek',['../structmlx_1_1core_1_1_contiguous_iterator.html#a24719ee9e8667885d29c2ad74445520c',1,'mlx::core::ContiguousIterator::seek()'],['../classmlx_1_1core_1_1io_1_1_reader.html#acea55078bd39ccaa27a9a36f17a39cd1',1,'mlx::core::io::Reader::seek()'],['../classmlx_1_1core_1_1io_1_1_writer.html#a9c1716dda53aa36faea9c8fb1a3e34d4',1,'mlx::core::io::Writer::seek()'],['../classmlx_1_1core_1_1io_1_1_parallel_file_reader.html#a673c16b669f3cee13f387b7b0a1f39f7',1,'mlx::core::io::ParallelFileReader::seek()'],['../classmlx_1_1core_1_1io_1_1_file_writer.html#a9646f4ea048ae58719daeb588e2de433',1,'mlx::core::io::FileWriter::seek()']]],


@ -35,118 +35,120 @@ var searchData=
['cmplx_3c_20thigh_20_3e_32',['cmplx&lt; Thigh &gt;',['../structpocketfft_1_1detail_1_1cmplx.html',1,'pocketfft::detail']]],
['cndarr_33',['cndarr',['../classpocketfft_1_1detail_1_1cndarr.html',1,'pocketfft::detail::cndarr&lt; T &gt;'],['../classpocketfft_1_1detail_1_1cndarr.html#abf73f1b4ddcfb27d7f85cfa441607129',1,'pocketfft::detail::cndarr::cndarr()']]],
['col_5fcontiguous_34',['col_contiguous',['../structmlx_1_1core_1_1array_1_1_flags.html#ae24709026598d635e6b5c24a15f8a802',1,'mlx::core::array::Flags']]],
['col_5freduce_5flooped_35',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
['col_5freduce_5fsmall_36',['col_reduce_small',['../reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce',1,'reduce_col.h']]],
['collapse_5fcontiguous_5fdims_37',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; int64_t &gt; &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; size_t &gt; &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; array &gt; &amp;xs, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &amp;&amp;... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; int64_t &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; size_t &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &amp;a, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())']]],
['commandencoder_38',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html',1,'mlx::core::metal::CommandEncoder'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &amp;)=delete']]],
['commit_5fcommand_5fbuffer_39',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
['commonallocator_40',['CommonAllocator',['../classmlx_1_1core_1_1allocator_1_1_common_allocator.html',1,'mlx::core::allocator']]],
['communication_5fstream_41',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
['compile_42',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile()']]],
['compile_2eh_43',['compile.h',['../compile_8h.html',1,'']]],
['compile_5favailable_5ffor_5fdevice_44',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
['compile_5fclear_5fcache_45',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
['compile_5ferase_46',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
['compile_5fimpl_2eh_47',['compile_impl.h',['../compile__impl_8h.html',1,'']]],
['compiled_48',['Compiled',['../classmlx_1_1core_1_1_compiled.html',1,'mlx::core::Compiled'],['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled::Compiled()']]],
['compiled_2eh_49',['compiled.h',['../compiled_8h.html',1,'']]],
['compiled_5fallocate_5foutputs_50',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
['compiled_5fcheck_5fcontiguity_51',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
['compiled_5fpreamble_2eh_52',['compiled_preamble.h',['../compiled__preamble_8h.html',1,'']]],
['compilemode_53',['CompileMode',['../namespacemlx_1_1core.html#adb15ff2b1ca5207fd4f6e631e2c3bcb4',1,'mlx::core']]],
['complex_2eh_54',['complex.h',['../backend_2metal_2kernels_2complex_8h.html',1,'(Global Namespace)'],['../types_2complex_8h.html',1,'(Global Namespace)']]],
['complex128_5ft_55',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html',1,'mlx::core::complex128_t'],['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex&lt; double &gt; v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
['complex64_56',['complex64',['../structmlx_1_1core_1_1_dtype.html#ade845ef5dcebead13a37fe696436e1daa8c022579455bcd2c681f007e84f4e2cf',1,'mlx::core::Dtype::complex64'],['../namespacemlx_1_1core.html#af99db87e0078bfcdb383f5689bc874d4',1,'mlx::core::complex64']]],
  ['complex64_5ft_57',['complex64_t',['../structcomplex64__t.html',1,'complex64_t'],['../structmlx_1_1core_1_1complex64__t.html',1,'mlx::core::complex64_t'],['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) constant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex&lt; float &gt; v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
['complex_5fbinop_58',['complex_binop',['../types_2complex_8h.html#a9c7995d495359894e1b30c0f1678d6bd',1,'complex.h']]],
['complex_5fbinop_5fhelper_59',['complex_binop_helper',['../types_2complex_8h.html#ac6890f9852de12339b09b65757ebc8c4',1,'complex.h']]],
['complex_5fmul_60',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
['complex_5fmul_5fconj_61',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
['complexfloating_62',['complexfloating',['../structmlx_1_1core_1_1_dtype.html#ac091c39cbd6686ef69aa1e5a2425aa2dafb203630099d501ff7c255a574bc4812',1,'mlx::core::Dtype::complexfloating'],['../namespacemlx_1_1core.html#a70b8e88c9df750af984757105af33423',1,'mlx::core::complexfloating']]],
['compute_5fstrided_5findices_63',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
['concatenate_64',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html',1,'mlx::core::Concatenate'],['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate::Concatenate()']]],
['concatenate_65',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, StreamOrDevice s={})']]],
['concatenate_5fgpu_66',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
['concurrent_5fqueue_67',['concurrent_queue',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
['concurrent_5fqueue_3c_20std_3a_3afunction_3c_20void_28_29_3e_20_3e_68',['concurrent_queue&lt; std::function&lt; void()&gt; &gt;',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
['concurrentcontext_69',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html',1,'mlx::core::metal::CommandEncoder::ConcurrentContext'],['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext::ConcurrentContext()']]],
['cond_70',['cond',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a4ffd524d6a5bedd1a303b63bdde6701c',1,'mlx::core::scheduler::StreamThread']]],
['conj_71',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
['conjugate_72',['Conjugate',['../struct_conjugate.html',1,'Conjugate'],['../classmlx_1_1core_1_1_conjugate.html',1,'mlx::core::Conjugate'],['../structmlx_1_1core_1_1detail_1_1_conjugate.html',1,'mlx::core::detail::Conjugate'],['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate::Conjugate()']]],
['conjugate_73',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
['contiguous_74',['contiguous',['../structmlx_1_1core_1_1array_1_1_flags.html#afd0ab11e7a486a2a8e50ee84b971ac8a',1,'mlx::core::array::Flags']]],
['contiguous_5fscan_75',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
['contiguousallreduce_76',['ContiguousAllReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ae4e34c7154eb8dc47aa8503209730424',1,'mlx::core']]],
['contiguousiterator_77',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html',1,'mlx::core::ContiguousIterator&lt; StrideT &gt;'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &amp;a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; StrideT &gt; &amp;strides, int dims)']]],
['contiguousreduce_78',['ContiguousReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ad2547f25dffe8d8936dbec25601cfc84',1,'mlx::core']]],
['contiguousstridedreduce_79',['ContiguousStridedReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ab48dac7508a2c790de1bdc33f29177ed',1,'mlx::core']]],
['conv_80',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
['conv_2eh_81',['conv.h',['../conv_8h.html',1,'']]],
['conv1d_82',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
['conv2d_83',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
['conv2dgeneralbaseinfo_84',['Conv2DGeneralBaseInfo',['../structmlx_1_1steel_1_1_conv2_d_general_base_info.html',1,'mlx::steel']]],
['conv2dgeneraljumpparams_85',['Conv2DGeneralJumpParams',['../structmlx_1_1steel_1_1_conv2_d_general_jump_params.html',1,'mlx::steel']]],
['conv2dinputblockloadergeneral_86',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html',1,'mlx::steel::Conv2DInputBlockLoaderGeneral&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral::Conv2DInputBlockLoaderGeneral()']]],
['conv2dinputblockloaderlargefilter_87',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter::Conv2DInputBlockLoaderLargeFilter()']]],
['conv2dinputblockloadersmallchannels_88',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels&lt; T, BM, BN, BK, tgp_size, n_channels, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels::Conv2DInputBlockLoaderSmallChannels()']]],
['conv2dinputblockloadersmallfilter_89',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter::Conv2DInputBlockLoaderSmallFilter()']]],
['conv2dweightblockloader_90',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html',1,'mlx::steel::Conv2DWeightBlockLoader&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader::Conv2DWeightBlockLoader()']]],
['conv2dweightblockloadergeneral_91',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral::Conv2DWeightBlockLoaderGeneral()']]],
['conv2dweightblockloadersmallchannels_92',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels&lt; T, BM, BN, BK, tgp_size, n_channels, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels::Conv2DWeightBlockLoaderSmallChannels()']]],
['conv3d_93',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
['conv_5fgeneral_94',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding_lo={}, std::vector&lt; int &gt; padding_hi={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &amp;input, const array &amp;weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
['conv_5ftranspose1d_95',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
['conv_5ftranspose2d_96',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
['conv_5ftranspose3d_97',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
['convolution_98',['Convolution',['../classmlx_1_1core_1_1_convolution.html',1,'mlx::core::Convolution'],['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution::Convolution()']]],
['copy_99',['Copy',['../classmlx_1_1core_1_1_copy.html',1,'mlx::core::Copy'],['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy::Copy()']]],
['copy_100',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
['copy_2eh_101',['copy.h',['../common_2copy_8h.html',1,'(Global Namespace)'],['../metal_2copy_8h.html',1,'(Global Namespace)'],['../metal_2kernels_2copy_8h.html',1,'(Global Namespace)']]],
['copy_5fg_102',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
['copy_5fg_5fnd1_103',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
['copy_5fg_5fnd2_104',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
['copy_5fg_5fnd3_105',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
['copy_5fgg_106',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
['copy_5fgg_5fnd1_107',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
['copy_5fgg_5fnd2_108',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
['copy_5fgg_5fnd3_109',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
['copy_5fgpu_110',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype)']]],
['copy_5fgpu_5finplace_111',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int64_t &gt; &amp;istride, int64_t ioffset, CopyType ctype, const Stream &amp;s)']]],
['copy_5fhartley_112',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5finplace_113',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
['copy_5finput_114',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; cmplx&lt; T &gt; &gt; &amp;src, cmplx&lt; vtype_t&lt; T &gt; &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, vtype_t&lt; T &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, T *dst)']]],
['copy_5foutput_115',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const cmplx&lt; vtype_t&lt; T &gt; &gt; *src, ndarr&lt; cmplx&lt; T &gt; &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5fs_116',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
['copy_5fs2_117',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
['copy_5fshared_5fbuffer_118',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &amp;other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &amp;other)']]],
['copy_5fv_119',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
['copy_5fv2_120',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
['copytype_121',['CopyType',['../namespacemlx_1_1core.html#abd84ff6c5245e4e170b2ef5247594337',1,'mlx::core']]],
['core_20array_20operations_122',['Core array operations',['../group__ops.html',1,'']]],
['cos_123',['Cos',['../struct_cos.html',1,'Cos'],['../classmlx_1_1core_1_1_cos.html',1,'mlx::core::Cos'],['../structmlx_1_1core_1_1detail_1_1_cos.html',1,'mlx::core::detail::Cos'],['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos::Cos()']]],
['cos_124',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
['cosh_125',['Cosh',['../struct_cosh.html',1,'Cosh'],['../classmlx_1_1core_1_1_cosh.html',1,'mlx::core::Cosh'],['../structmlx_1_1core_1_1detail_1_1_cosh.html',1,'mlx::core::detail::Cosh'],['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh::Cosh()']]],
['cosh_126',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
['cosine_127',['cosine',['../structpocketfft_1_1detail_1_1_exec_dcst.html#a185023fc1e386cc8f233b79c49c1fd8a',1,'pocketfft::detail::ExecDcst']]],
['cospi_128',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
['cost_5fguess_129',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
['count_5fdown_130',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
['cpu_131',['cpu',['../structmlx_1_1core_1_1_device.html#a69ee81924251dec96f1945c9d91506fd',1,'mlx::core::Device::cpu'],['../structmlx_1_1core_1_1_device.html#ac45b3de9b3458d8f31005136cde20fdbad9747e2da342bdb995f6389533ad1a3d',1,'mlx::core::Device::cpu']]],
['cross_132',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
['ctile_133',['Ctile',['../structmlx_1_1steel_1_1_block_m_m_a.html#a81838da5d81e62d372d581be599c5a88',1,'mlx::steel::BlockMMA']]],
['cummax_134',['CumMax',['../struct_cum_max.html',1,'']]],
['cummax_135',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
['cummin_136',['CumMin',['../struct_cum_min.html',1,'']]],
['cummin_137',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
['cumprod_138',['CumProd',['../struct_cum_prod.html',1,'']]],
['cumprod_139',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
['cumprod_3c_20bool_20_3e_140',['CumProd&lt; bool &gt;',['../struct_cum_prod_3_01bool_01_4.html',1,'']]],
['cumsum_141',['CumSum',['../struct_cum_sum.html',1,'']]],
['cumsum_142',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
['custom_143',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html',1,'mlx::core::fast::Custom'],['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom::Custom()']]],
['custom_5ffunction_144',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
['custom_5fvjp_145',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
['customkernel_146',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html',1,'mlx::core::fast::CustomKernel'],['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel::CustomKernel()']]],
['customkernelshapeinfo_147',['CustomKernelShapeInfo',['../structmlx_1_1core_1_1fast_1_1_custom_kernel_shape_info.html',1,'mlx::core::fast']]],
['customtransforms_148',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html',1,'mlx::core::CustomTransforms'],['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms::CustomTransforms()']]]
['col_5freduce_5f2pass_35',['col_reduce_2pass',['../reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d',1,'reduce_col.h']]],
['col_5freduce_5flongcolumn_36',['col_reduce_longcolumn',['../reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb',1,'reduce_col.h']]],
['col_5freduce_5flooped_37',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
['col_5freduce_5fsmall_38',['col_reduce_small',['../reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5',1,'reduce_col.h']]],
['collapse_5fcontiguous_5fdims_39',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; int64_t &gt; &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; size_t &gt; &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; array &gt; &amp;xs, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &amp;&amp;... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; int64_t &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; size_t &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &amp;a, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())']]],
['commandencoder_40',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html',1,'mlx::core::metal::CommandEncoder'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &amp;)=delete']]],
['commit_5fcommand_5fbuffer_41',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
['commonallocator_42',['CommonAllocator',['../classmlx_1_1core_1_1allocator_1_1_common_allocator.html',1,'mlx::core::allocator']]],
['communication_5fstream_43',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
['compile_44',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile()']]],
['compile_2eh_45',['compile.h',['../compile_8h.html',1,'']]],
['compile_5favailable_5ffor_5fdevice_46',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
['compile_5fclear_5fcache_47',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
['compile_5ferase_48',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
['compile_5fimpl_2eh_49',['compile_impl.h',['../compile__impl_8h.html',1,'']]],
['compiled_50',['Compiled',['../classmlx_1_1core_1_1_compiled.html',1,'mlx::core::Compiled'],['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled::Compiled()']]],
['compiled_2eh_51',['compiled.h',['../compiled_8h.html',1,'']]],
['compiled_5fallocate_5foutputs_52',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
['compiled_5fcheck_5fcontiguity_53',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
['compiled_5fpreamble_2eh_54',['compiled_preamble.h',['../compiled__preamble_8h.html',1,'']]],
['compilemode_55',['CompileMode',['../namespacemlx_1_1core.html#adb15ff2b1ca5207fd4f6e631e2c3bcb4',1,'mlx::core']]],
['complex_2eh_56',['complex.h',['../backend_2metal_2kernels_2complex_8h.html',1,'(Global Namespace)'],['../types_2complex_8h.html',1,'(Global Namespace)']]],
['complex128_5ft_57',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html',1,'mlx::core::complex128_t'],['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex&lt; double &gt; v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
['complex64_58',['complex64',['../structmlx_1_1core_1_1_dtype.html#ade845ef5dcebead13a37fe696436e1daa8c022579455bcd2c681f007e84f4e2cf',1,'mlx::core::Dtype::complex64'],['../namespacemlx_1_1core.html#af99db87e0078bfcdb383f5689bc874d4',1,'mlx::core::complex64']]],
['complex64_5ft_59',['complex64_t',['../structcomplex64__t.html',1,'complex64_t'],['../structmlx_1_1core_1_1complex64__t.html',1,'mlx::core::complex64_t'],['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex&lt; float &gt; v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
['complex_5fbinop_60',['complex_binop',['../types_2complex_8h.html#a9c7995d495359894e1b30c0f1678d6bd',1,'complex.h']]],
['complex_5fbinop_5fhelper_61',['complex_binop_helper',['../types_2complex_8h.html#ac6890f9852de12339b09b65757ebc8c4',1,'complex.h']]],
['complex_5fmul_62',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
['complex_5fmul_5fconj_63',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
['complexfloating_64',['complexfloating',['../structmlx_1_1core_1_1_dtype.html#ac091c39cbd6686ef69aa1e5a2425aa2dafb203630099d501ff7c255a574bc4812',1,'mlx::core::Dtype::complexfloating'],['../namespacemlx_1_1core.html#a70b8e88c9df750af984757105af33423',1,'mlx::core::complexfloating']]],
['compute_5fstrided_5findices_65',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
['concatenate_66',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html',1,'mlx::core::Concatenate'],['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate::Concatenate()']]],
['concatenate_67',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, StreamOrDevice s={})']]],
['concatenate_5fgpu_68',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
['concurrent_5fqueue_69',['concurrent_queue',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
['concurrent_5fqueue_3c_20std_3a_3afunction_3c_20void_28_29_3e_20_3e_70',['concurrent_queue&lt; std::function&lt; void()&gt; &gt;',['../classpocketfft_1_1detail_1_1threading_1_1concurrent__queue.html',1,'pocketfft::detail::threading']]],
['concurrentcontext_71',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html',1,'mlx::core::metal::CommandEncoder::ConcurrentContext'],['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext::ConcurrentContext()']]],
['cond_72',['cond',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a4ffd524d6a5bedd1a303b63bdde6701c',1,'mlx::core::scheduler::StreamThread']]],
['conj_73',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
['conjugate_74',['Conjugate',['../struct_conjugate.html',1,'Conjugate'],['../classmlx_1_1core_1_1_conjugate.html',1,'mlx::core::Conjugate'],['../structmlx_1_1core_1_1detail_1_1_conjugate.html',1,'mlx::core::detail::Conjugate'],['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate::Conjugate()']]],
['conjugate_75',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
['contiguous_76',['contiguous',['../structmlx_1_1core_1_1array_1_1_flags.html#afd0ab11e7a486a2a8e50ee84b971ac8a',1,'mlx::core::array::Flags']]],
['contiguous_5fscan_77',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
['contiguousallreduce_78',['ContiguousAllReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ae4e34c7154eb8dc47aa8503209730424',1,'mlx::core']]],
['contiguousiterator_79',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html',1,'mlx::core::ContiguousIterator&lt; StrideT &gt;'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &amp;a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; StrideT &gt; &amp;strides, int dims)']]],
['contiguousreduce_80',['ContiguousReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ad2547f25dffe8d8936dbec25601cfc84',1,'mlx::core']]],
['contiguousstridedreduce_81',['ContiguousStridedReduce',['../namespacemlx_1_1core.html#a12412984a1cabfe1189942c898f8fe65ab48dac7508a2c790de1bdc33f29177ed',1,'mlx::core']]],
['conv_82',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
['conv_2eh_83',['conv.h',['../conv_8h.html',1,'']]],
['conv1d_84',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
['conv2d_85',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
['conv2dgeneralbaseinfo_86',['Conv2DGeneralBaseInfo',['../structmlx_1_1steel_1_1_conv2_d_general_base_info.html',1,'mlx::steel']]],
['conv2dgeneraljumpparams_87',['Conv2DGeneralJumpParams',['../structmlx_1_1steel_1_1_conv2_d_general_jump_params.html',1,'mlx::steel']]],
['conv2dinputblockloadergeneral_88',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html',1,'mlx::steel::Conv2DInputBlockLoaderGeneral&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral::Conv2DInputBlockLoaderGeneral()']]],
['conv2dinputblockloaderlargefilter_89',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter::Conv2DInputBlockLoaderLargeFilter()']]],
['conv2dinputblockloadersmallchannels_90',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels&lt; T, BM, BN, BK, tgp_size, n_channels, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels::Conv2DInputBlockLoaderSmallChannels()']]],
['conv2dinputblockloadersmallfilter_91',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter::Conv2DInputBlockLoaderSmallFilter()']]],
['conv2dweightblockloader_92',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html',1,'mlx::steel::Conv2DWeightBlockLoader&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader::Conv2DWeightBlockLoader()']]],
['conv2dweightblockloadergeneral_93',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral&lt; T, BM, BN, BK, tgp_size, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral::Conv2DWeightBlockLoaderGeneral()']]],
['conv2dweightblockloadersmallchannels_94',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels&lt; T, BM, BN, BK, tgp_size, n_channels, tgp_padding &gt;'],['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels::Conv2DWeightBlockLoaderSmallChannels()']]],
['conv3d_95',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
['conv_5fgeneral_96',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding_lo={}, std::vector&lt; int &gt; padding_hi={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &amp;input, const array &amp;weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
['conv_5ftranspose1d_97',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
['conv_5ftranspose2d_98',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
['conv_5ftranspose3d_99',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
['convolution_100',['Convolution',['../classmlx_1_1core_1_1_convolution.html',1,'mlx::core::Convolution'],['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution::Convolution()']]],
['copy_101',['Copy',['../classmlx_1_1core_1_1_copy.html',1,'mlx::core::Copy'],['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy::Copy()']]],
['copy_102',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
['copy_2eh_103',['copy.h',['../common_2copy_8h.html',1,'(Global Namespace)'],['../metal_2copy_8h.html',1,'(Global Namespace)'],['../metal_2kernels_2copy_8h.html',1,'(Global Namespace)']]],
['copy_5fg_104',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
['copy_5fg_5fnd1_105',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
['copy_5fg_5fnd2_106',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
['copy_5fg_5fnd3_107',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
['copy_5fgg_108',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
['copy_5fgg_5fnd1_109',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
['copy_5fgg_5fnd2_110',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
['copy_5fgg_5fnd3_111',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
['copy_5fgpu_112',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype)']]],
['copy_5fgpu_5finplace_113',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int64_t &gt; &amp;istride, int64_t ioffset, CopyType ctype, const Stream &amp;s)']]],
['copy_5fhartley_114',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5finplace_115',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
['copy_5finput_116',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; cmplx&lt; T &gt; &gt; &amp;src, cmplx&lt; vtype_t&lt; T &gt; &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, vtype_t&lt; T &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, T *dst)']]],
['copy_5foutput_117',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const cmplx&lt; vtype_t&lt; T &gt; &gt; *src, ndarr&lt; cmplx&lt; T &gt; &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5fs_118',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
['copy_5fs2_119',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
['copy_5fshared_5fbuffer_120',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &amp;other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &amp;other)']]],
['copy_5fv_121',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
['copy_5fv2_122',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
['copytype_123',['CopyType',['../namespacemlx_1_1core.html#abd84ff6c5245e4e170b2ef5247594337',1,'mlx::core']]],
['core_20array_20operations_124',['Core array operations',['../group__ops.html',1,'']]],
['cos_125',['Cos',['../struct_cos.html',1,'Cos'],['../classmlx_1_1core_1_1_cos.html',1,'mlx::core::Cos'],['../structmlx_1_1core_1_1detail_1_1_cos.html',1,'mlx::core::detail::Cos'],['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos::Cos()']]],
['cos_126',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
['cosh_127',['Cosh',['../struct_cosh.html',1,'Cosh'],['../classmlx_1_1core_1_1_cosh.html',1,'mlx::core::Cosh'],['../structmlx_1_1core_1_1detail_1_1_cosh.html',1,'mlx::core::detail::Cosh'],['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh::Cosh()']]],
['cosh_128',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
['cosine_129',['cosine',['../structpocketfft_1_1detail_1_1_exec_dcst.html#a185023fc1e386cc8f233b79c49c1fd8a',1,'pocketfft::detail::ExecDcst']]],
['cospi_130',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
['cost_5fguess_131',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
['count_5fdown_132',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
['cpu_133',['cpu',['../structmlx_1_1core_1_1_device.html#a69ee81924251dec96f1945c9d91506fd',1,'mlx::core::Device::cpu'],['../structmlx_1_1core_1_1_device.html#ac45b3de9b3458d8f31005136cde20fdbad9747e2da342bdb995f6389533ad1a3d',1,'mlx::core::Device::cpu']]],
['cross_134',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
['ctile_135',['Ctile',['../structmlx_1_1steel_1_1_block_m_m_a.html#a81838da5d81e62d372d581be599c5a88',1,'mlx::steel::BlockMMA']]],
['cummax_136',['CumMax',['../struct_cum_max.html',1,'']]],
['cummax_137',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
['cummin_138',['CumMin',['../struct_cum_min.html',1,'']]],
['cummin_139',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
['cumprod_140',['CumProd',['../struct_cum_prod.html',1,'']]],
['cumprod_141',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
['cumprod_3c_20bool_20_3e_142',['CumProd&lt; bool &gt;',['../struct_cum_prod_3_01bool_01_4.html',1,'']]],
['cumsum_143',['CumSum',['../struct_cum_sum.html',1,'']]],
['cumsum_144',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
['custom_145',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html',1,'mlx::core::fast::Custom'],['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom::Custom()']]],
['custom_5ffunction_146',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
['custom_5fvjp_147',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
['customkernel_148',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html',1,'mlx::core::fast::CustomKernel'],['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel::CustomKernel()']]],
['customkernelshapeinfo_149',['CustomKernelShapeInfo',['../structmlx_1_1core_1_1fast_1_1_custom_kernel_shape_info.html',1,'mlx::core::fast']]],
['customtransforms_150',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html',1,'mlx::core::CustomTransforms'],['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms::CustomTransforms()']]]
];

@ -67,7 +67,7 @@ var searchData=
['get_5fpool_64',['get_pool',['../namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a',1,'pocketfft::detail::threading']]],
['get_5fprimitive_5fstring_65',['get_primitive_string',['../namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60',1,'mlx::core']]],
['get_5fquantized_5fkernel_66',['get_quantized_kernel',['../namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e',1,'mlx::core']]],
['get_5freduce_5finit_5fkernel_67',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89',1,'mlx::core']]],
['get_5freduce_5finit_5fkernel_67',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647',1,'mlx::core']]],
['get_5freduce_5fkernel_68',['get_reduce_kernel',['../namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b',1,'mlx::core']]],
['get_5freduction_5fplan_69',['get_reduction_plan',['../namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba',1,'mlx::core']]],
['get_5fscan_5fkernel_70',['get_scan_kernel',['../namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f',1,'mlx::core']]],

@ -29,84 +29,85 @@ var searchData=
['max_5fthreads_26',['max_threads',['../namespacepocketfft_1_1detail_1_1threading.html#a2d5c0729f0b66cf061918baea4337d70',1,'pocketfft::detail::threading']]],
['maximum_27',['Maximum',['../struct_maximum.html',1,'Maximum'],['../structmlx_1_1core_1_1detail_1_1_maximum.html',1,'mlx::core::detail::Maximum'],['../classmlx_1_1core_1_1_maximum.html',1,'mlx::core::Maximum'],['../classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816',1,'mlx::core::Maximum::Maximum()']]],
['maximum_28',['maximum',['../group__ops.html#ga7ade2ea305e2e4219c3609443fb5db8d',1,'mlx::core']]],
['mb_5fblock_5fmerge_29',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
['mb_5fblock_5fpartition_30',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
['mb_5fblock_5fsort_31',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
['mean_32',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['median3_33',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
['merge_5fpartition_34',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
['merge_5fstep_35',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
['meshgrid_36',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
['metal_37',['metal',['../namespacemetal.html',1,'']]],
['metal_2eh_38',['metal.h',['../metal_8h.html',1,'']]],
['metal_3a_3afast_39',['fast',['../namespacemetal_1_1fast.html',1,'metal']]],
['metal_3a_3aprecise_40',['precise',['../namespacemetal_1_1precise.html',1,'metal']]],
['metal_5fimpl_2eh_41',['metal_impl.h',['../metal__impl_8h.html',1,'']]],
['metal_5fkernel_42',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
['metalallocator_43',['MetalAllocator',['../classmlx_1_1core_1_1metal_1_1_metal_allocator.html',1,'mlx::core::metal']]],
['metalkernelfunction_44',['MetalKernelFunction',['../namespacemlx_1_1core_1_1fast.html#a0e8c2c4ea7a946568c8fe5b4810417e0',1,'mlx::core::fast']]],
['min_45',['Min',['../struct_min.html',1,'Min&lt; U &gt;'],['../classmlx_1_1core_1_1distributed_1_1_all_reduce.html#abb4560980e5d01aed14175ce8f6fc924a4f685dcd48e6614d6bb2ccda4f2686ef',1,'mlx::core::distributed::AllReduce::Min'],['../classmlx_1_1core_1_1_reduce.html#a0848518b16ae6d4043d6be247bdf31c9a0d3d1f5c94725bdc42fa692e2c074418',1,'mlx::core::Reduce::Min'],['../classmlx_1_1core_1_1_scan.html#a47bf2ec54ead4b8f00f9f188518630f1a7d2ee8f14f2e70a9d47170fecc6da898',1,'mlx::core::Scan::Min'],['../classmlx_1_1core_1_1_scatter.html#a614d19af11dc30644b2b4941033b613cad914e4c3475ce9858f2de4bf35dcfdbf',1,'mlx::core::Scatter::Min']]],
['min_46',['min',['../struct_limits.html#a6e81584ba65a4dc6ff9366b458e3a20e',1,'Limits::min'],['../struct_limits_3_01uint8__t_01_4.html#a408bd5a337e7292f06e63da81193629a',1,'Limits&lt; uint8_t &gt;::min'],['../struct_limits_3_01uint16__t_01_4.html#ae173984c3be8b6750f27daed581805fe',1,'Limits&lt; uint16_t &gt;::min'],['../struct_limits_3_01uint32__t_01_4.html#ab0c3975e02053b234c7b606ababa66e1',1,'Limits&lt; uint32_t &gt;::min'],['../struct_limits_3_01uint64__t_01_4.html#a80627f39e951398283942cefa48f4dd0',1,'Limits&lt; uint64_t &gt;::min'],['../struct_limits_3_01int8__t_01_4.html#a7a809307d2bba80382f0645d277eaa4b',1,'Limits&lt; int8_t &gt;::min'],['../struct_limits_3_01int16__t_01_4.html#adca7139647801e223c35b0abc7da5240',1,'Limits&lt; int16_t &gt;::min'],['../struct_limits_3_01int32__t_01_4.html#af336a1b22a8ed6a83a4cfb5bf8869771',1,'Limits&lt; int32_t &gt;::min'],['../struct_limits_3_01int64__t_01_4.html#a1c90fb96af515badaccaa835b08f7428',1,'Limits&lt; int64_t &gt;::min'],['../struct_limits_3_01half_01_4.html#aca7b036c257878bf1b80912fb5d4516d',1,'Limits&lt; half &gt;::min'],['../struct_limits_3_01float_01_4.html#a3225e334d372ee86128c89a440d8648f',1,'Limits&lt; float &gt;::min'],['../struct_limits_3_01bfloat16__t_01_4.html#a2fd1811b9f615b2b897904bc27d1cb49',1,'Limits&lt; bfloat16_t &gt;::min'],['../struct_limits_3_01bool_01_4.html#a139f787b57536d455490b8ef801d37cc',1,'Limits&lt; bool &gt;::min'],['../struct_limits_3_01complex64__t_01_4.html#aa67b04aa7abcd67f7af0808737ab8e14',1,'Limits&lt; complex64_t &gt;::min'],['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl&lt; bfloat16_t 
&gt;::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['min3_47',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
['min_5fexponent_48',['min_exponent',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a13829f8c7a7c0efdc8946eff5d3c9470',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;']]],
['min_5fexponent10_49',['min_exponent10',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aeaed172780720e06b8731cef3177e277',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;']]],
['minimum_50',['Minimum',['../struct_minimum.html',1,'Minimum'],['../structmlx_1_1core_1_1detail_1_1_minimum.html',1,'mlx::core::detail::Minimum'],['../classmlx_1_1core_1_1_minimum.html',1,'mlx::core::Minimum'],['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum::Minimum()']]],
['minimum_51',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
['mlx_52',['mlx',['../namespacemlx.html',1,'']]],
['mlx_2eh_53',['mlx.h',['../mlx_8h.html',1,'']]],
['mlx_3a_3acore_54',['core',['../namespacemlx_1_1core.html',1,'mlx']]],
['mlx_3a_3acore_3a_3aallocator_55',['allocator',['../namespacemlx_1_1core_1_1allocator.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3adetail_56',['detail',['../namespacemlx_1_1core_1_1detail.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3adistributed_57',['distributed',['../namespacemlx_1_1core_1_1distributed.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3adistributed_3a_3adetail_58',['detail',['../namespacemlx_1_1core_1_1distributed_1_1detail.html',1,'mlx::core::distributed']]],
['mlx_3a_3acore_3a_3afast_59',['fast',['../namespacemlx_1_1core_1_1fast.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3afft_60',['fft',['../namespacemlx_1_1core_1_1fft.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3aio_61',['io',['../namespacemlx_1_1core_1_1io.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3alinalg_62',['linalg',['../namespacemlx_1_1core_1_1linalg.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3ametal_63',['metal',['../namespacemlx_1_1core_1_1metal.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3arandom_64',['random',['../namespacemlx_1_1core_1_1random.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3ascheduler_65',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html',1,'mlx::core']]],
['mlx_3a_3asteel_66',['steel',['../namespacemlx_1_1steel.html',1,'mlx']]],
['mlx_5fatomic_67',['mlx_atomic',['../structmlx__atomic.html',1,'']]],
['mlx_5fatomic_3c_20t_2c_20enable_5fif_5ft_3c_20is_5fmetal_5fatomic_3c_20t_20_3e_20_3e_20_3e_68',['mlx_atomic&lt; T, enable_if_t&lt; is_metal_atomic&lt; T &gt; &gt; &gt;',['../structmlx__atomic_3_01_t_00_01enable__if__t_3_01is__metal__atomic_3_01_t_01_4_01_4_01_4.html',1,'']]],
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_69',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread T *expected, T val, size_t offset):&#160;atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread uint *expected, uint val, size_t offset):&#160;atomic.h']]],
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_70',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fand_5fexplicit_71',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_72',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_73',['mlx_atomic_fetch_max_explicit&lt; float &gt;',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_74',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_75',['mlx_atomic_fetch_min_explicit&lt; float &gt;',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_76',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5for_5fexplicit_77',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
['mlx_5fatomic_5fload_5fexplicit_78',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
['mlx_5fatomic_5fstore_5fexplicit_79',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
['mlx_5flapack_5ffunc_80',['MLX_LAPACK_FUNC',['../lapack_8h.html#ae22db9704827bf013a0a61f21a47464b',1,'lapack.h']]],
['mlx_5fmtl_5fconst_81',['MLX_MTL_CONST',['../kernels_2gemv__masked_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST:&#160;gemv_masked.h'],['../quantized_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST:&#160;quantized.h'],['../sort_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST:&#160;sort.h']]],
['mlx_5fmtl_5floop_5funroll_82',['MLX_MTL_LOOP_UNROLL',['../sort_8h.html#ad34b622323cebef136669fedd7229515',1,'sort.h']]],
['mlx_5fmtl_5fpragma_5funroll_83',['MLX_MTL_PRAGMA_UNROLL',['../kernels_2gemv__masked_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL:&#160;gemv_masked.h'],['../backend_2metal_2kernels_2utils_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL:&#160;utils.h']]],
['mlxconvparams_84',['MLXConvParams',['../struct_m_l_x_conv_params.html',1,'']]],
['mlxconvparams_3c_202_20_3e_85',['MLXConvParams&lt; 2 &gt;',['../struct_m_l_x_conv_params.html',1,'']]],
['mlxfastattentionparams_86',['MLXFastAttentionParams',['../struct_m_l_x_fast_attention_params.html',1,'']]],
['mlxscaleddotproductattentionparams_87',['MLXScaledDotProductAttentionParams',['../struct_m_l_x_scaled_dot_product_attention_params.html',1,'']]],
['mma_88',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread frag_type &amp;D, thread frag_type &amp;A, thread frag_type &amp;B, thread frag_type &amp;C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread mat_type &amp;D, thread mat_type &amp;A, thread mat_type &amp;B, thread mat_type &amp;C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
['mma_2eh_89',['mma.h',['../mma_8h.html',1,'']]],
['mma_5ft_90',['mma_t',['../structmlx_1_1steel_1_1_g_e_m_m_kernel.html#add8c6a31011a4895667c2a94a5af3782',1,'mlx::steel::GEMMKernel']]],
['mmafrag_5facc_5ft_91',['MMAFrag_acc_t',['../structmlx_1_1steel_1_1_block_m_m_a.html#ae2c42cb6d0dde785859164c195f4d13c',1,'mlx::steel::BlockMMA']]],
['mmafrag_5ft_92',['MMAFrag_t',['../structmlx_1_1steel_1_1_m_m_a_tile.html#abe33de70e34300745bad9aa822fd0382',1,'mlx::steel::MMATile']]],
['mmatile_93',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel::MMATile&lt; T, kTileRows_, kTileCols_, MMAFrag_ &gt;'],['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile::MMATile()']]],
['mmatile_3c_20float_2c_201_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_94',['MMATile&lt; float, 1, TN, mlx::steel::BaseMMAFrag &gt;',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
['mmatile_3c_20float_2c_20tm_2c_201_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_95',['MMATile&lt; float, TM, 1, mlx::steel::BaseMMAFrag &gt;',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
['mmatile_3c_20float_2c_20tm_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_96',['MMATile&lt; float, TM, TN, mlx::steel::BaseMMAFrag &gt;',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
['move_5fshared_5fbuffer_97',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
['moveaxis_98',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
['mpinplace_99',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
['mtl_5fconst_100',['MTL_CONST',['../defines_8h.html#a767ed9f2604de22b259cee02c4ce1d22',1,'defines.h']]],
['mtl_5fdevice_101',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
['mtl_5fresidency_5fset_102',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
['mtlfclist_103',['MTLFCList',['../namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54',1,'mlx::core::metal']]],
['mtx_104',['mtx',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a70410c9e612f871663929f1e8441a976',1,'mlx::core::scheduler::StreamThread']]],
['multi_5fiter_105',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html',1,'pocketfft::detail::multi_iter&lt; N &gt;'],['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter::multi_iter()']]],
['multiply_106',['Multiply',['../structmlx_1_1core_1_1detail_1_1_multiply.html',1,'mlx::core::detail::Multiply'],['../classmlx_1_1core_1_1_multiply.html',1,'mlx::core::Multiply'],['../struct_multiply.html',1,'Multiply'],['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply::Multiply()']]],
['multiply_107',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
['multivariate_5fnormal_108',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
['maybeinsertbarrier_29',['maybeInsertBarrier',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991',1,'mlx::core::metal::CommandEncoder']]],
['mb_5fblock_5fmerge_30',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
['mb_5fblock_5fpartition_31',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
['mb_5fblock_5fsort_32',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
['mean_33',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['median3_34',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
['merge_5fpartition_35',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
['merge_5fstep_36',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
['meshgrid_37',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
['metal_38',['metal',['../namespacemetal.html',1,'']]],
['metal_2eh_39',['metal.h',['../metal_8h.html',1,'']]],
['metal_3a_3afast_40',['fast',['../namespacemetal_1_1fast.html',1,'metal']]],
['metal_3a_3aprecise_41',['precise',['../namespacemetal_1_1precise.html',1,'metal']]],
['metal_5fimpl_2eh_42',['metal_impl.h',['../metal__impl_8h.html',1,'']]],
['metal_5fkernel_43',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
['metalallocator_44',['MetalAllocator',['../classmlx_1_1core_1_1metal_1_1_metal_allocator.html',1,'mlx::core::metal']]],
['metalkernelfunction_45',['MetalKernelFunction',['../namespacemlx_1_1core_1_1fast.html#a0e8c2c4ea7a946568c8fe5b4810417e0',1,'mlx::core::fast']]],
['min_46',['Min',['../struct_min.html',1,'Min&lt; U &gt;'],['../classmlx_1_1core_1_1distributed_1_1_all_reduce.html#abb4560980e5d01aed14175ce8f6fc924a4f685dcd48e6614d6bb2ccda4f2686ef',1,'mlx::core::distributed::AllReduce::Min'],['../classmlx_1_1core_1_1_reduce.html#a0848518b16ae6d4043d6be247bdf31c9a0d3d1f5c94725bdc42fa692e2c074418',1,'mlx::core::Reduce::Min'],['../classmlx_1_1core_1_1_scan.html#a47bf2ec54ead4b8f00f9f188518630f1a7d2ee8f14f2e70a9d47170fecc6da898',1,'mlx::core::Scan::Min'],['../classmlx_1_1core_1_1_scatter.html#a614d19af11dc30644b2b4941033b613cad914e4c3475ce9858f2de4bf35dcfdbf',1,'mlx::core::Scatter::Min']]],
['min_47',['min',['../struct_limits.html#a6e81584ba65a4dc6ff9366b458e3a20e',1,'Limits::min'],['../struct_limits_3_01uint8__t_01_4.html#a408bd5a337e7292f06e63da81193629a',1,'Limits&lt; uint8_t &gt;::min'],['../struct_limits_3_01uint16__t_01_4.html#ae173984c3be8b6750f27daed581805fe',1,'Limits&lt; uint16_t &gt;::min'],['../struct_limits_3_01uint32__t_01_4.html#ab0c3975e02053b234c7b606ababa66e1',1,'Limits&lt; uint32_t &gt;::min'],['../struct_limits_3_01uint64__t_01_4.html#a80627f39e951398283942cefa48f4dd0',1,'Limits&lt; uint64_t &gt;::min'],['../struct_limits_3_01int8__t_01_4.html#a7a809307d2bba80382f0645d277eaa4b',1,'Limits&lt; int8_t &gt;::min'],['../struct_limits_3_01int16__t_01_4.html#adca7139647801e223c35b0abc7da5240',1,'Limits&lt; int16_t &gt;::min'],['../struct_limits_3_01int32__t_01_4.html#af336a1b22a8ed6a83a4cfb5bf8869771',1,'Limits&lt; int32_t &gt;::min'],['../struct_limits_3_01int64__t_01_4.html#a1c90fb96af515badaccaa835b08f7428',1,'Limits&lt; int64_t &gt;::min'],['../struct_limits_3_01half_01_4.html#aca7b036c257878bf1b80912fb5d4516d',1,'Limits&lt; half &gt;::min'],['../struct_limits_3_01float_01_4.html#a3225e334d372ee86128c89a440d8648f',1,'Limits&lt; float &gt;::min'],['../struct_limits_3_01bfloat16__t_01_4.html#a2fd1811b9f615b2b897904bc27d1cb49',1,'Limits&lt; bfloat16_t &gt;::min'],['../struct_limits_3_01bool_01_4.html#a139f787b57536d455490b8ef801d37cc',1,'Limits&lt; bool &gt;::min'],['../struct_limits_3_01complex64__t_01_4.html#aa67b04aa7abcd67f7af0808737ab8e14',1,'Limits&lt; complex64_t &gt;::min'],['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['min3_48',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
['min_5fexponent_49',['min_exponent',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#a13829f8c7a7c0efdc8946eff5d3c9470',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;']]],
['min_5fexponent10_50',['min_exponent10',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aeaed172780720e06b8731cef3177e277',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;']]],
['minimum_51',['Minimum',['../struct_minimum.html',1,'Minimum'],['../structmlx_1_1core_1_1detail_1_1_minimum.html',1,'mlx::core::detail::Minimum'],['../classmlx_1_1core_1_1_minimum.html',1,'mlx::core::Minimum'],['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum::Minimum()']]],
['minimum_52',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
['mlx_53',['mlx',['../namespacemlx.html',1,'']]],
['mlx_2eh_54',['mlx.h',['../mlx_8h.html',1,'']]],
['mlx_3a_3acore_55',['core',['../namespacemlx_1_1core.html',1,'mlx']]],
['mlx_3a_3acore_3a_3aallocator_56',['allocator',['../namespacemlx_1_1core_1_1allocator.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3adetail_57',['detail',['../namespacemlx_1_1core_1_1detail.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3adistributed_58',['distributed',['../namespacemlx_1_1core_1_1distributed.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3adistributed_3a_3adetail_59',['detail',['../namespacemlx_1_1core_1_1distributed_1_1detail.html',1,'mlx::core::distributed']]],
['mlx_3a_3acore_3a_3afast_60',['fast',['../namespacemlx_1_1core_1_1fast.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3afft_61',['fft',['../namespacemlx_1_1core_1_1fft.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3aio_62',['io',['../namespacemlx_1_1core_1_1io.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3alinalg_63',['linalg',['../namespacemlx_1_1core_1_1linalg.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3ametal_64',['metal',['../namespacemlx_1_1core_1_1metal.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3arandom_65',['random',['../namespacemlx_1_1core_1_1random.html',1,'mlx::core']]],
['mlx_3a_3acore_3a_3ascheduler_66',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html',1,'mlx::core']]],
['mlx_3a_3asteel_67',['steel',['../namespacemlx_1_1steel.html',1,'mlx']]],
['mlx_5fatomic_68',['mlx_atomic',['../structmlx__atomic.html',1,'']]],
['mlx_5fatomic_3c_20t_2c_20enable_5fif_5ft_3c_20is_5fmetal_5fatomic_3c_20t_20_3e_20_3e_20_3e_69',['mlx_atomic&lt; T, enable_if_t&lt; is_metal_atomic&lt; T &gt; &gt; &gt;',['../structmlx__atomic_3_01_t_00_01enable__if__t_3_01is__metal__atomic_3_01_t_01_4_01_4_01_4.html',1,'']]],
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_70',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread T *expected, T val, size_t offset):&#160;atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread uint *expected, uint val, size_t offset):&#160;atomic.h']]],
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_71',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fand_5fexplicit_72',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_73',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_74',['mlx_atomic_fetch_max_explicit&lt; float &gt;',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_75',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_76',['mlx_atomic_fetch_min_explicit&lt; float &gt;',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_77',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5for_5fexplicit_78',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
['mlx_5fatomic_5fload_5fexplicit_79',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
['mlx_5fatomic_5fstore_5fexplicit_80',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
['mlx_5flapack_5ffunc_81',['MLX_LAPACK_FUNC',['../lapack_8h.html#ae22db9704827bf013a0a61f21a47464b',1,'lapack.h']]],
['mlx_5fmtl_5fconst_82',['MLX_MTL_CONST',['../kernels_2gemv__masked_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST:&#160;gemv_masked.h'],['../quantized_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST:&#160;quantized.h'],['../sort_8h.html#a0386011c52d03e60885a31e6fbd903dd',1,'MLX_MTL_CONST:&#160;sort.h']]],
['mlx_5fmtl_5floop_5funroll_83',['MLX_MTL_LOOP_UNROLL',['../sort_8h.html#ad34b622323cebef136669fedd7229515',1,'sort.h']]],
['mlx_5fmtl_5fpragma_5funroll_84',['MLX_MTL_PRAGMA_UNROLL',['../kernels_2gemv__masked_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL:&#160;gemv_masked.h'],['../backend_2metal_2kernels_2utils_8h.html#a069b682d7d21827461544817d722bfd3',1,'MLX_MTL_PRAGMA_UNROLL:&#160;utils.h']]],
['mlxconvparams_85',['MLXConvParams',['../struct_m_l_x_conv_params.html',1,'']]],
['mlxconvparams_3c_202_20_3e_86',['MLXConvParams&lt; 2 &gt;',['../struct_m_l_x_conv_params.html',1,'']]],
['mlxfastattentionparams_87',['MLXFastAttentionParams',['../struct_m_l_x_fast_attention_params.html',1,'']]],
['mlxscaleddotproductattentionparams_88',['MLXScaledDotProductAttentionParams',['../struct_m_l_x_scaled_dot_product_attention_params.html',1,'']]],
['mma_89',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread frag_type &amp;D, thread frag_type &amp;A, thread frag_type &amp;B, thread frag_type &amp;C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread mat_type &amp;D, thread mat_type &amp;A, thread mat_type &amp;B, thread mat_type &amp;C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
['mma_2eh_90',['mma.h',['../mma_8h.html',1,'']]],
['mma_5ft_91',['mma_t',['../structmlx_1_1steel_1_1_g_e_m_m_kernel.html#add8c6a31011a4895667c2a94a5af3782',1,'mlx::steel::GEMMKernel']]],
['mmafrag_5facc_5ft_92',['MMAFrag_acc_t',['../structmlx_1_1steel_1_1_block_m_m_a.html#ae2c42cb6d0dde785859164c195f4d13c',1,'mlx::steel::BlockMMA']]],
['mmafrag_5ft_93',['MMAFrag_t',['../structmlx_1_1steel_1_1_m_m_a_tile.html#abe33de70e34300745bad9aa822fd0382',1,'mlx::steel::MMATile']]],
['mmatile_94',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel::MMATile&lt; T, kTileRows_, kTileCols_, MMAFrag_ &gt;'],['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile::MMATile()']]],
['mmatile_3c_20float_2c_201_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_95',['MMATile&lt; float, 1, TN, mlx::steel::BaseMMAFrag &gt;',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
['mmatile_3c_20float_2c_20tm_2c_201_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_96',['MMATile&lt; float, TM, 1, mlx::steel::BaseMMAFrag &gt;',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
['mmatile_3c_20float_2c_20tm_2c_20tn_2c_20mlx_3a_3asteel_3a_3abasemmafrag_20_3e_97',['MMATile&lt; float, TM, TN, mlx::steel::BaseMMAFrag &gt;',['../structmlx_1_1steel_1_1_m_m_a_tile.html',1,'mlx::steel']]],
['move_5fshared_5fbuffer_98',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
['moveaxis_99',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
['mpinplace_100',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
['mtl_5fconst_101',['MTL_CONST',['../defines_8h.html#a767ed9f2604de22b259cee02c4ce1d22',1,'defines.h']]],
['mtl_5fdevice_102',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
['mtl_5fresidency_5fset_103',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
['mtlfclist_104',['MTLFCList',['../namespacemlx_1_1core_1_1metal.html#a616e09a1ef321d527770721cef264c54',1,'mlx::core::metal']]],
['mtx_105',['mtx',['../structmlx_1_1core_1_1scheduler_1_1_stream_thread.html#a70410c9e612f871663929f1e8441a976',1,'mlx::core::scheduler::StreamThread']]],
['multi_5fiter_106',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html',1,'pocketfft::detail::multi_iter&lt; N &gt;'],['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter::multi_iter()']]],
['multiply_107',['Multiply',['../structmlx_1_1core_1_1detail_1_1_multiply.html',1,'mlx::core::detail::Multiply'],['../classmlx_1_1core_1_1_multiply.html',1,'mlx::core::Multiply'],['../struct_multiply.html',1,'Multiply'],['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply::Multiply()']]],
['multiply_108',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
['multivariate_5fnormal_109',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
];

@ -22,7 +22,7 @@ var searchData=
['all_19',['all',['../group__ops.html#ga3b1b90ef1275ca17655b6d7f25d3ee68',1,'mlx::core::all(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga3689e12e8f42dadb4cbe2b07dc4099f4',1,'mlx::core::all(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#gac0919c6ba53aea35a7683dea7e9a9a59',1,'mlx::core::all(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#gae2d5fcc5b62d673cca76c08b7b4afbbc',1,'mlx::core::all(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['all_5fgather_20',['all_gather',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aeb5a1726358213bc75756506f7b54d04',1,'mlx::core::distributed::detail::all_gather()'],['../namespacemlx_1_1core_1_1distributed.html#a82ef5e8cc7ac62cd228e51b1c1b77cb7',1,'mlx::core::distributed::all_gather()']]],
['all_5freduce_21',['all_reduce',['../reduce__all_8h.html#a99ef48ae72b3e715c5f4d7ea07cd213d',1,'reduce_all.h']]],
['all_5freduce_5fdispatch_22',['all_reduce_dispatch',['../namespacemlx_1_1core.html#af7b7ca7c6aa87558d9f98cee5c7a99a8',1,'mlx::core']]],
['all_5freduce_5fdispatch_22',['all_reduce_dispatch',['../namespacemlx_1_1core.html#a3ab0fd997d9a35782106ff083a72e098',1,'mlx::core']]],
['all_5fsum_23',['all_sum',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#aa1d225b25f7b6426c48c5e35860ee960',1,'mlx::core::distributed::detail::all_sum()'],['../namespacemlx_1_1core_1_1distributed.html#a67ccb1a5445fc6f5db49dd36a15e5980',1,'mlx::core::distributed::all_sum()']]],
['allclose_24',['allclose',['../group__ops.html#gaf0cd4257de7542daf9faf5e605e31020',1,'mlx::core']]],
['allgather_25',['AllGather',['../classmlx_1_1core_1_1distributed_1_1_all_gather.html#af4b10a5b61f160fb64353057c185b661',1,'mlx::core::distributed::AllGather']]],

@ -22,5 +22,6 @@ var searchData=
['quantizedmatmul_19',['QuantizedMatmul',['../classmlx_1_1core_1_1_quantized_matmul.html#a5bd164d038d9dc21919f7e0bfdeaa25c',1,'mlx::core::QuantizedMatmul']]],
['quiet_5fnan_20',['quiet_NaN',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#aebeb07c01984be246bc2d1b8f8e4ac7b',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;']]],
['qvm_21',['qvm',['../quantized_8h.html#ad84f7d5ab9e32dbbe3ca759ae5d5d5c5',1,'quantized.h']]],
['qvm_5fimpl_22',['qvm_impl',['../quantized_8h.html#a4a8c8db7d5d480733726fd6d1a645e12',1,'quantized.h']]]
['qvm_5fimpl_22',['qvm_impl',['../quantized_8h.html#a1546533c5b925b2fbb3bec870ec7487a',1,'quantized.h']]],
['qvm_5fsplit_5fk_23',['qvm_split_k',['../quantized_8h.html#ab8243818512d6078d23e6ffb65fd7bb8',1,'quantized.h']]]
];

@ -17,7 +17,7 @@ var searchData=
['scatter_5fprod_14',['scatter_prod',['../group__ops.html#ga3708b5bcb61e2c63d213c4ce6ad0ffc0',1,'mlx::core::scatter_prod(const array &amp;a, const std::vector&lt; array &gt; &amp;indices, const array &amp;updates, const std::vector&lt; int &gt; &amp;axes, StreamOrDevice s={})'],['../group__ops.html#gaf83c53c453faa9083ba27e4b97539339',1,'mlx::core::scatter_prod(const array &amp;a, const array &amp;indices, const array &amp;updates, int axis, StreamOrDevice s={})']]],
['scheduler_15',['Scheduler',['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a3ae42aed78a2200e9d02776fcd2316ba',1,'mlx::core::scheduler::Scheduler::Scheduler()'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#a61a74e3628899e66dde600e24a750648',1,'mlx::core::scheduler::Scheduler::Scheduler(const Scheduler &amp;)=delete'],['../classmlx_1_1core_1_1scheduler_1_1_scheduler.html#ac3f77b7c93220dadd0b3bb2e903b7059',1,'mlx::core::scheduler::Scheduler::Scheduler(Scheduler &amp;&amp;)=delete']]],
['scheduler_16',['scheduler',['../namespacemlx_1_1core_1_1scheduler.html#ae856e468c2f7c8f8ec672522cc13730b',1,'mlx::core::scheduler']]],
['sdpa_5fvector_17',['sdpa_vector',['../sdpa__vector_8h.html#a6f0d7918430064bab910bdaa6c64e927',1,'sdpa_vector.h']]],
['sdpa_5fvector_17',['sdpa_vector',['../sdpa__vector_8h.html#a4bf36f16e16c1c62d9b243573568e5ae',1,'sdpa_vector.h']]],
['seed_18',['seed',['../classmlx_1_1core_1_1random_1_1_key_sequence.html#a9f19c5da2031cba50d0ff996924347d8',1,'mlx::core::random::KeySequence::seed()'],['../namespacemlx_1_1core_1_1random.html#ac4ad325b613257306df74595d3d0e23b',1,'mlx::core::random::seed()']]],
['seek_19',['seek',['../structmlx_1_1core_1_1_contiguous_iterator.html#a24719ee9e8667885d29c2ad74445520c',1,'mlx::core::ContiguousIterator::seek()'],['../classmlx_1_1core_1_1io_1_1_reader.html#acea55078bd39ccaa27a9a36f17a39cd1',1,'mlx::core::io::Reader::seek()'],['../classmlx_1_1core_1_1io_1_1_writer.html#a9c1716dda53aa36faea9c8fb1a3e34d4',1,'mlx::core::io::Writer::seek()'],['../classmlx_1_1core_1_1io_1_1_parallel_file_reader.html#a673c16b669f3cee13f387b7b0a1f39f7',1,'mlx::core::io::ParallelFileReader::seek()'],['../classmlx_1_1core_1_1io_1_1_file_writer.html#a9646f4ea048ae58719daeb588e2de433',1,'mlx::core::io::FileWriter::seek()']]],
['select_20',['Select',['../classmlx_1_1core_1_1_select.html#a6f833fe55dd68ad3726bbf9a8f75eec9',1,'mlx::core::Select']]],

@ -18,85 +18,87 @@ var searchData=
['clip_15',['clip',['../group__ops.html#ga157cd7c23f9b306fee2e1eb2b9bf1dd8',1,'mlx::core']]],
['cmplx_16',['cmplx',['../structpocketfft_1_1detail_1_1cmplx.html#a5b1ce506f1023f5254025ac81b831a2c',1,'pocketfft::detail::cmplx::cmplx()'],['../structpocketfft_1_1detail_1_1cmplx.html#a05491b4f1f22ca0bc49012f6a1c1710a',1,'pocketfft::detail::cmplx::cmplx(T r_, T i_)']]],
['cndarr_17',['cndarr',['../classpocketfft_1_1detail_1_1cndarr.html#abf73f1b4ddcfb27d7f85cfa441607129',1,'pocketfft::detail::cndarr']]],
['col_5freduce_5flooped_18',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
['col_5freduce_5fsmall_19',['col_reduce_small',['../reduce__col_8h.html#adf7aeb18cd1d5042cf6d9b46b582d8ce',1,'reduce_col.h']]],
['collapse_5fcontiguous_5fdims_20',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; int64_t &gt; &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; size_t &gt; &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; array &gt; &amp;xs, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &amp;&amp;... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; int64_t &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; size_t &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &amp;a, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())']]],
['commandencoder_21',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &amp;)=delete']]],
['commit_5fcommand_5fbuffer_22',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
['communication_5fstream_23',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
['compile_24',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile(const std::function&lt; std::vector&lt; array &gt;(const std::vector&lt; array &gt; &amp;)&gt; &amp;fun, std::uintptr_t fun_id, bool shapeless=false, std::vector&lt; uint64_t &gt; constants={})']]],
['compile_5favailable_5ffor_5fdevice_25',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
['compile_5fclear_5fcache_26',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
['compile_5ferase_27',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
['compiled_28',['Compiled',['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled']]],
['compiled_5fallocate_5foutputs_29',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
['compiled_5fcheck_5fcontiguity_30',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
['complex128_5ft_31',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex&lt; double &gt; v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
['complex64_5ft_32',['complex64_t',['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex&lt; float &gt; v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
['complex_5fmul_33',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
['complex_5fmul_5fconj_34',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
['compute_5fstrided_5findices_35',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
['concatenate_36',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate']]],
['concatenate_37',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, StreamOrDevice s={})']]],
['concatenate_5fgpu_38',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
['concurrentcontext_39',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext']]],
['conj_40',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
['conjugate_41',['Conjugate',['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate']]],
['conjugate_42',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
['contiguous_5fscan_43',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
['contiguousiterator_44',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &amp;a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; StrideT &gt; &amp;strides, int dims)']]],
['conv_45',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
['conv1d_46',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
['conv2d_47',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
['conv2dinputblockloadergeneral_48',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral']]],
['conv2dinputblockloaderlargefilter_49',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter']]],
['conv2dinputblockloadersmallchannels_50',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels']]],
['conv2dinputblockloadersmallfilter_51',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter']]],
['conv2dweightblockloader_52',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader']]],
['conv2dweightblockloadergeneral_53',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral']]],
['conv2dweightblockloadersmallchannels_54',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels']]],
['conv3d_55',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
['conv_5fgeneral_56',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding_lo={}, std::vector&lt; int &gt; padding_hi={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &amp;input, const array &amp;weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
['conv_5ftranspose1d_57',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
['conv_5ftranspose2d_58',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
['conv_5ftranspose3d_59',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
['convolution_60',['Convolution',['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution']]],
['copy_61',['Copy',['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy']]],
['copy_62',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
['copy_5fg_63',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
['copy_5fg_5fnd1_64',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
['copy_5fg_5fnd2_65',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
['copy_5fg_5fnd3_66',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
['copy_5fgg_67',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
['copy_5fgg_5fnd1_68',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
['copy_5fgg_5fnd2_69',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
['copy_5fgg_5fnd3_70',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
['copy_5fgpu_71',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype)']]],
['copy_5fgpu_5finplace_72',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int64_t &gt; &amp;istride, int64_t ioffset, CopyType ctype, const Stream &amp;s)']]],
['copy_5fhartley_73',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5finplace_74',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
['copy_5finput_75',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; cmplx&lt; T &gt; &gt; &amp;src, cmplx&lt; vtype_t&lt; T &gt; &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, vtype_t&lt; T &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, T *dst)']]],
['copy_5foutput_76',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const cmplx&lt; vtype_t&lt; T &gt; &gt; *src, ndarr&lt; cmplx&lt; T &gt; &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5fs_77',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
['copy_5fs2_78',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
['copy_5fshared_5fbuffer_79',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &amp;other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &amp;other)']]],
['copy_5fv_80',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
['copy_5fv2_81',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
['cos_82',['Cos',['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos']]],
['cos_83',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
['cosh_84',['Cosh',['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh']]],
['cosh_85',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
['cospi_86',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
['cost_5fguess_87',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
['count_5fdown_88',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
['cross_89',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
['cummax_90',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
['cummin_91',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
['cumprod_92',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
['cumsum_93',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
['custom_94',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom']]],
['custom_5ffunction_95',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
['custom_5fvjp_96',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
['customkernel_97',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel']]],
['customtransforms_98',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms']]]
['col_5freduce_5f2pass_18',['col_reduce_2pass',['../reduce__col_8h.html#a0e92fc74eeaa8ee2ceb83bafc6eb1d7d',1,'reduce_col.h']]],
['col_5freduce_5flongcolumn_19',['col_reduce_longcolumn',['../reduce__col_8h.html#a5b4f4c4c247ad341ff8d31dcbbbce0eb',1,'reduce_col.h']]],
['col_5freduce_5flooped_20',['col_reduce_looped',['../reduce__col_8h.html#a11bfc6112ae2386ac03f5ea7b7d93385',1,'reduce_col.h']]],
['col_5freduce_5fsmall_21',['col_reduce_small',['../reduce__col_8h.html#a7c378443a2b6f4d9210db8a21a9ac4f5',1,'reduce_col.h']]],
['collapse_5fcontiguous_5fdims_22',['collapse_contiguous_dims',['../namespacemlx_1_1core.html#a38fe6ec5220d13d96c7dad7556d2b613',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; int64_t &gt; &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#af2895f9b0083efd8221275eb8cadccbe',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; std::vector&lt; size_t &gt; &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a90e2b6edc0fe82230cb93f5ea39febb4',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; array &gt; &amp;xs, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#ac813412cce77fc1340dcfefc6e099276',1,'mlx::core::collapse_contiguous_dims(Arrays &amp;&amp;... xs)'],['../namespacemlx_1_1core.html#aab3cc7f3808934ae0727b920eba231bd',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; int64_t &gt; &amp;strides, int64_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a1e0cbcf109d32794ffc8efc7302ba9b0',1,'mlx::core::collapse_contiguous_dims(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; size_t &gt; &amp;strides, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())'],['../namespacemlx_1_1core.html#a4ee50bfb240512d0c0ce151dfe2c74ef',1,'mlx::core::collapse_contiguous_dims(const array &amp;a, size_t size_cap=std::numeric_limits&lt; int32_t &gt;::max())']]],
['commandencoder_23',['CommandEncoder',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#a2334774486f447213ee997e55c2e52a3',1,'mlx::core::metal::CommandEncoder::CommandEncoder(MTL::CommandBuffer *cbuf)'],['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ac68ca977b5bde5434284ce7979647f14',1,'mlx::core::metal::CommandEncoder::CommandEncoder(const CommandEncoder &amp;)=delete']]],
['commit_5fcommand_5fbuffer_24',['commit_command_buffer',['../classmlx_1_1core_1_1metal_1_1_device.html#a95248f1387824067fd4fed23ace5ac0c',1,'mlx::core::metal::Device']]],
['communication_5fstream_25',['communication_stream',['../namespacemlx_1_1core_1_1distributed_1_1detail.html#ac3612edf0e0e18c1e4ba0ce7c6e35cd6',1,'mlx::core::distributed::detail']]],
['compile_26',['compile',['../namespacemlx_1_1core.html#a3ac798e65e59fe10b7fb5c522efce782',1,'mlx::core::compile()'],['../namespacemlx_1_1core_1_1detail.html#ac3b7b09892ff7290d5f3ef26cb444329',1,'mlx::core::detail::compile(const std::function&lt; std::vector&lt; array &gt;(const std::vector&lt; array &gt; &amp;)&gt; &amp;fun, std::uintptr_t fun_id, bool shapeless=false, std::vector&lt; uint64_t &gt; constants={})']]],
['compile_5favailable_5ffor_5fdevice_27',['compile_available_for_device',['../namespacemlx_1_1core_1_1detail.html#aeeff2ba6ec3d9d4ed090de6d2681dbc2',1,'mlx::core::detail']]],
['compile_5fclear_5fcache_28',['compile_clear_cache',['../namespacemlx_1_1core_1_1detail.html#a3fb927c209b946aefebb195993fbe4cf',1,'mlx::core::detail']]],
['compile_5ferase_29',['compile_erase',['../namespacemlx_1_1core_1_1detail.html#a69eb76a14f845ca000f1ccb2edda0175',1,'mlx::core::detail']]],
['compiled_30',['Compiled',['../classmlx_1_1core_1_1_compiled.html#a2d8cefff835c419a48a077d306b8e051',1,'mlx::core::Compiled']]],
['compiled_5fallocate_5foutputs_31',['compiled_allocate_outputs',['../namespacemlx_1_1core.html#ab8c3c4fc05745f586de922c8266f4fce',1,'mlx::core']]],
['compiled_5fcheck_5fcontiguity_32',['compiled_check_contiguity',['../namespacemlx_1_1core.html#a3b900ab319948c5a01a3ecd30a709027',1,'mlx::core']]],
['complex128_5ft_33',['complex128_t',['../structmlx_1_1core_1_1complex128__t.html#aa15d0b805f8790f7c7b76fc7b9d677e0',1,'mlx::core::complex128_t::complex128_t(double v, double u)'],['../structmlx_1_1core_1_1complex128__t.html#abf2842253b874f9f13f39ea68a89e5b6',1,'mlx::core::complex128_t::complex128_t(std::complex&lt; double &gt; v)'],['../structmlx_1_1core_1_1complex128__t.html#a526fba96d7e815360cb4226af085a1bf',1,'mlx::core::complex128_t::complex128_t(T x)']]],
['complex64_5ft_34',['complex64_t',['../structcomplex64__t.html#adbd392a5e92d31997380ad0a38be4be8',1,'complex64_t::complex64_t(float real, float imag)'],['../structcomplex64__t.html#a29782289bb90d6294099667b86509cd3',1,'complex64_t::complex64_t()'],['../structcomplex64__t.html#a905b048d70eb8d748a62454268242291',1,'complex64_t::complex64_t() threadgroup'],['../structcomplex64__t.html#a33a2452eb33b5ed53655773539c357a5',1,'complex64_t::complex64_t(T x) thread'],['../structcomplex64__t.html#a89b65ace8588b7bf215355f705eb23d9',1,'complex64_t::complex64_t(T x) threadgroup'],['../structcomplex64__t.html#ac81b486f642fb3b26c5d659917bdbcd0',1,'complex64_t::complex64_t(T x) device'],['../structcomplex64__t.html#a0a27a41206400f1e62b60ceb56960c93',1,'complex64_t::complex64_t(T x) const ant'],['../structmlx_1_1core_1_1complex64__t.html#a697cc973ae27d63c8e00d830e780bd8c',1,'mlx::core::complex64_t::complex64_t(float v, float u)'],['../structmlx_1_1core_1_1complex64__t.html#ae065e39938f9c4374b4116f4c67d4d09',1,'mlx::core::complex64_t::complex64_t(std::complex&lt; float &gt; v)'],['../structmlx_1_1core_1_1complex64__t.html#a2232cbbe591a9d2bc228cb23fac38b50',1,'mlx::core::complex64_t::complex64_t(T x)']]],
['complex_5fmul_35',['complex_mul',['../radix_8h.html#a5bfc53b531214c9ce277bebc18aa67d6',1,'radix.h']]],
['complex_5fmul_5fconj_36',['complex_mul_conj',['../radix_8h.html#a0e2dfd3d1dda09f47ccc64eec35629f3',1,'radix.h']]],
['compute_5fstrided_5findices_37',['compute_strided_indices',['../struct_read_writer.html#a7c903fbb8b85a856ba5564d7df537cdf',1,'ReadWriter']]],
['concatenate_38',['Concatenate',['../classmlx_1_1core_1_1_concatenate.html#acff07853de2d31faeec7c4ca40ce0888',1,'mlx::core::Concatenate']]],
['concatenate_39',['concatenate',['../group__ops.html#gabdc36fa65697d0361c8d67495de77129',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, int axis, StreamOrDevice s={})'],['../group__ops.html#gaa95c34ca3a8877f2c50cb60e7fa312b8',1,'mlx::core::concatenate(const std::vector&lt; array &gt; &amp;arrays, StreamOrDevice s={})']]],
['concatenate_5fgpu_40',['concatenate_gpu',['../namespacemlx_1_1core.html#a050299d0d366ca5c9d09d1004dcc3e7d',1,'mlx::core']]],
['concurrentcontext_41',['ConcurrentContext',['../structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html#aee044d7729739c96e845823f9ecc5174',1,'mlx::core::metal::CommandEncoder::ConcurrentContext']]],
['conj_42',['conj',['../namespacepocketfft_1_1detail.html#a66d79051d502046a9b9f103e744dbad3',1,'pocketfft::detail']]],
['conjugate_43',['Conjugate',['../classmlx_1_1core_1_1_conjugate.html#a627f9e6a8729fb3ffb3ca3228d007c87',1,'mlx::core::Conjugate']]],
['conjugate_44',['conjugate',['../group__ops.html#ga5b596906bf8cdc8d97ed6ddc9aeb4c23',1,'mlx::core']]],
['contiguous_5fscan_45',['contiguous_scan',['../scan_8h.html#a60d279b9add7d56639bb209408f09d79',1,'scan.h']]],
['contiguousiterator_46',['ContiguousIterator',['../structmlx_1_1core_1_1_contiguous_iterator.html#a68794af4a442d3d8ac4647817af8e1f6',1,'mlx::core::ContiguousIterator::ContiguousIterator()'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a6cb378408b6f546eeb6ade1a4faafe3c',1,'mlx::core::ContiguousIterator::ContiguousIterator(const array &amp;a)'],['../structmlx_1_1core_1_1_contiguous_iterator.html#a16bdacb53f65b7284068cd49d4cba292',1,'mlx::core::ContiguousIterator::ContiguousIterator(const std::vector&lt; int &gt; &amp;shape, const std::vector&lt; StrideT &gt; &amp;strides, int dims)']]],
['conv_47',['conv',['../namespacemlx_1_1core_1_1metal.html#ab1704e853394c725668c06752ebb5c24',1,'mlx::core::metal']]],
['conv1d_48',['conv1d',['../group__ops.html#ga30d47e08093c03a3676f235f9f559411',1,'mlx::core']]],
['conv2d_49',['conv2d',['../group__ops.html#ga73b02833229678786e7f302d458d5a83',1,'mlx::core']]],
['conv2dinputblockloadergeneral_50',['Conv2DInputBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_general.html#a1d83af561a483432bf8dcb42e734b23b',1,'mlx::steel::Conv2DInputBlockLoaderGeneral']]],
['conv2dinputblockloaderlargefilter_51',['Conv2DInputBlockLoaderLargeFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_large_filter.html#a8755116a535539744e4947bc69f9c50f',1,'mlx::steel::Conv2DInputBlockLoaderLargeFilter']]],
['conv2dinputblockloadersmallchannels_52',['Conv2DInputBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_channels.html#ab9fd3fdeab94470dde3326f1dd5c455a',1,'mlx::steel::Conv2DInputBlockLoaderSmallChannels']]],
['conv2dinputblockloadersmallfilter_53',['Conv2DInputBlockLoaderSmallFilter',['../structmlx_1_1steel_1_1_conv2_d_input_block_loader_small_filter.html#a0a2cbf57c51cd928722e3f06aafcf933',1,'mlx::steel::Conv2DInputBlockLoaderSmallFilter']]],
['conv2dweightblockloader_54',['Conv2DWeightBlockLoader',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader.html#a9a7dca3512b64cffb6eac305d795831c',1,'mlx::steel::Conv2DWeightBlockLoader']]],
['conv2dweightblockloadergeneral_55',['Conv2DWeightBlockLoaderGeneral',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_general.html#ad0550fabbdc9297559381a5b488e9af1',1,'mlx::steel::Conv2DWeightBlockLoaderGeneral']]],
['conv2dweightblockloadersmallchannels_56',['Conv2DWeightBlockLoaderSmallChannels',['../structmlx_1_1steel_1_1_conv2_d_weight_block_loader_small_channels.html#ae1806ea1c19713819dee83a38ab35fa6',1,'mlx::steel::Conv2DWeightBlockLoaderSmallChannels']]],
['conv3d_57',['conv3d',['../group__ops.html#ga6e9907d2f14dc4803e4306b3dbc4b3ca',1,'mlx::core']]],
['conv_5fgeneral_58',['conv_general',['../group__ops.html#ga2236e5dfc7e52e28abf6c21675d0a51e',1,'mlx::core::conv_general(array input, array weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding_lo={}, std::vector&lt; int &gt; padding_hi={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})'],['../group__ops.html#gab59f89942cd1efaadffe9e8762e3c99d',1,'mlx::core::conv_general(const array &amp;input, const array &amp;weight, std::vector&lt; int &gt; stride={}, std::vector&lt; int &gt; padding={}, std::vector&lt; int &gt; kernel_dilation={}, std::vector&lt; int &gt; input_dilation={}, int groups=1, bool flip=false, StreamOrDevice s={})']]],
['conv_5ftranspose1d_59',['conv_transpose1d',['../group__ops.html#gaa30bf1adcd78d1c2595d07b215731714',1,'mlx::core']]],
['conv_5ftranspose2d_60',['conv_transpose2d',['../group__ops.html#gaebb59971cb9bc45005dc1d398e4f0a3d',1,'mlx::core']]],
['conv_5ftranspose3d_61',['conv_transpose3d',['../group__ops.html#ga8db814da631d9cd32a8d6563bf4ac530',1,'mlx::core']]],
['convolution_62',['Convolution',['../classmlx_1_1core_1_1_convolution.html#a6f1de77b719bb13217b0d8c64cabb8ef',1,'mlx::core::Convolution']]],
['copy_63',['Copy',['../classmlx_1_1core_1_1_copy.html#a6243e044af119105ffaaed7d405cd584',1,'mlx::core::Copy']]],
['copy_64',['copy',['../namespacemlx_1_1core.html#a479648542a2bea151b947b18f0e79dd2',1,'mlx::core::copy()'],['../namespacemlx_1_1core_1_1metal.html#aa215e631e2680f04a591b88d91571719',1,'mlx::core::metal::copy()'],['../group__ops.html#gae306e93af12f774bd80bad6c231b09d6',1,'mlx::core::copy()']]],
['copy_5fg_65',['copy_g',['../metal_2kernels_2copy_8h.html#a778ce2dbfbaa23b24bd5efbe68448c36',1,'copy.h']]],
['copy_5fg_5fnd1_66',['copy_g_nd1',['../metal_2kernels_2copy_8h.html#aba4530a7db6a61ca36f50e4f5e58fb77',1,'copy.h']]],
['copy_5fg_5fnd2_67',['copy_g_nd2',['../metal_2kernels_2copy_8h.html#aee678c7c31119f3e609685589f37490c',1,'copy.h']]],
['copy_5fg_5fnd3_68',['copy_g_nd3',['../metal_2kernels_2copy_8h.html#a821f8f3f3891159a295c66fc25aed1ff',1,'copy.h']]],
['copy_5fgg_69',['copy_gg',['../metal_2kernels_2copy_8h.html#a1e39c2683eeaf05955e7619fbd34aea5',1,'copy.h']]],
['copy_5fgg_5fnd1_70',['copy_gg_nd1',['../metal_2kernels_2copy_8h.html#a3278d9c999718bee3ccbe2922f501bf1',1,'copy.h']]],
['copy_5fgg_5fnd2_71',['copy_gg_nd2',['../metal_2kernels_2copy_8h.html#a3e2d3cc7f34f56170409b6735f51a950',1,'copy.h']]],
['copy_5fgg_5fnd3_72',['copy_gg_nd3',['../metal_2kernels_2copy_8h.html#a59f43b5bffed936d7559ceb06a10aabd',1,'copy.h']]],
['copy_5fgpu_73',['copy_gpu',['../namespacemlx_1_1core.html#addaa46a13ac2deb1d9ce621338320e0e',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a6a6f4e46c8fc44fdc74c50ace02bcf38',1,'mlx::core::copy_gpu(const array &amp;src, array &amp;out, CopyType ctype)']]],
['copy_5fgpu_5finplace_74',['copy_gpu_inplace',['../namespacemlx_1_1core.html#a69e30f5d30a6d72ac0ffe4886f24b7ba',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#a8e1ccb0ed9387b0a789311d9f8964803',1,'mlx::core::copy_gpu_inplace(const array &amp;src, array &amp;out, CopyType ctype, const Stream &amp;s)'],['../namespacemlx_1_1core.html#ae55b801b09ccf55cba96278163a9b1ef',1,'mlx::core::copy_gpu_inplace(const array &amp;in, array &amp;out, const std::vector&lt; int64_t &gt; &amp;istride, int64_t ioffset, CopyType ctype, const Stream &amp;s)']]],
['copy_5fhartley_75',['copy_hartley',['../namespacepocketfft_1_1detail.html#abac3fcc8ce83800d228774f64c28d4c3',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#ae7b44d2773d9d06a9787aff01d66b3ed',1,'pocketfft::detail::copy_hartley(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5finplace_76',['copy_inplace',['../namespacemlx_1_1core.html#a98495894a796b2cc6d022e7a03432c64',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, CopyType ctype)'],['../namespacemlx_1_1core.html#aad636e2d0b2f882cadd1b438f4daa9ed',1,'mlx::core::copy_inplace(const array &amp;src, array &amp;dst, const std::vector&lt; int &gt; &amp;data_shape, const std::vector&lt; stride_t &gt; &amp;i_strides, const std::vector&lt; stride_t &gt; &amp;o_strides, int64_t i_offset, int64_t o_offset, CopyType ctype)']]],
['copy_5finput_77',['copy_input',['../namespacepocketfft_1_1detail.html#aff05be3064743c1143b19318ab12ad4a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; cmplx&lt; T &gt; &gt; &amp;src, cmplx&lt; vtype_t&lt; T &gt; &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a30fc708f9d8f9cfa74194925c7863c0a',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, vtype_t&lt; T &gt; *dst)'],['../namespacepocketfft_1_1detail.html#a3387bd35f237870e42b8461769e6aec4',1,'pocketfft::detail::copy_input(const multi_iter&lt; vlen &gt; &amp;it, const cndarr&lt; T &gt; &amp;src, T *dst)']]],
['copy_5foutput_78',['copy_output',['../namespacepocketfft_1_1detail.html#a1523a037300a8da05db210b802d9cb0e',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const cmplx&lt; vtype_t&lt; T &gt; &gt; *src, ndarr&lt; cmplx&lt; T &gt; &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a21980853aca4d92ed06e3dcffe7ef660',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const vtype_t&lt; T &gt; *src, ndarr&lt; T &gt; &amp;dst)'],['../namespacepocketfft_1_1detail.html#a310481c334e46674710ba794ad7403c0',1,'pocketfft::detail::copy_output(const multi_iter&lt; vlen &gt; &amp;it, const T *src, ndarr&lt; T &gt; &amp;dst)']]],
['copy_5fs_79',['copy_s',['../metal_2kernels_2copy_8h.html#aef09f9b9475345b1bba121d037d222ea',1,'copy.h']]],
['copy_5fs2_80',['copy_s2',['../metal_2kernels_2copy_8h.html#a8023e9335cc5334847a8d315042be3a3',1,'copy.h']]],
['copy_5fshared_5fbuffer_81',['copy_shared_buffer',['../classmlx_1_1core_1_1array.html#a28df7a333d90a311c49bc4bce7a1ad6d',1,'mlx::core::array::copy_shared_buffer(const array &amp;other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a92974c656c35a972ad241f80584bbd29',1,'mlx::core::array::copy_shared_buffer(const array &amp;other)']]],
['copy_5fv_82',['copy_v',['../metal_2kernels_2copy_8h.html#ae26a13e0c8e6c15f7b10078e65970659',1,'copy.h']]],
['copy_5fv2_83',['copy_v2',['../metal_2kernels_2copy_8h.html#aee14a5326f53d9b30b0b38e27d180ef3',1,'copy.h']]],
['cos_84',['Cos',['../classmlx_1_1core_1_1_cos.html#a2acb9fcf0901462189c476756fd99995',1,'mlx::core::Cos']]],
['cos_85',['cos',['../namespacepocketfft_1_1detail.html#a499c1e8b7d79a5272af024f46c63ff9d',1,'pocketfft::detail::cos()'],['../namespacemetal.html#a2fa4778a6fe2fa43253ea724e5a608a3',1,'metal::cos()'],['../namespacemetal_1_1fast.html#a75b6bb32fa3870eda46a7bfc9f481f88',1,'metal::fast::cos()'],['../namespacemetal_1_1precise.html#ac4941f62e7d8ab9d7cabbd967aa9f220',1,'metal::precise::cos()'],['../group__ops.html#ga39dfdf72b556012aa35ff27a94116e74',1,'mlx::core::cos()']]],
['cosh_86',['Cosh',['../classmlx_1_1core_1_1_cosh.html#a44e8ac2e09a55ec32e9dc6641eedc8f1',1,'mlx::core::Cosh']]],
['cosh_87',['cosh',['../namespacemetal.html#a8a68a88cc110830d057dbd71431b93c0',1,'metal::cosh()'],['../namespacemetal_1_1fast.html#a31544ad9de28012a4ddda86e3966a77e',1,'metal::fast::cosh()'],['../namespacemetal_1_1precise.html#a72d86d508300a9b58f4ccbbe70da4fbc',1,'metal::precise::cosh()'],['../group__ops.html#ga2181b71cda88007a3092be4795ff0715',1,'mlx::core::cosh()']]],
['cospi_88',['cospi',['../namespacemetal.html#a5c2f37939ad705ddea4409d3bedb8ce1',1,'metal::cospi()'],['../namespacemetal_1_1fast.html#a9906b41f75319b384ffb570cc94d67ce',1,'metal::fast::cospi()'],['../namespacemetal_1_1precise.html#a2392b78bd196efdbbac65901c4ab20e7',1,'metal::precise::cospi()']]],
['cost_5fguess_89',['cost_guess',['../structpocketfft_1_1detail_1_1util.html#ad3d874bc3fb0048df2270779a15d4bd0',1,'pocketfft::detail::util']]],
['count_5fdown_90',['count_down',['../classpocketfft_1_1detail_1_1threading_1_1latch.html#a81d6597189b40410e35f3cd653fd1342',1,'pocketfft::detail::threading::latch']]],
['cross_91',['cross',['../namespacemlx_1_1core_1_1linalg.html#abcda3fbda45183c21e7f27aa0dde64e6',1,'mlx::core::linalg']]],
['cummax_92',['cummax',['../group__ops.html#gaee37cac8476e8f8d666bcded5bc59143',1,'mlx::core']]],
['cummin_93',['cummin',['../group__ops.html#ga19c1bf6929fe8d66b9cd408946aea6a8',1,'mlx::core']]],
['cumprod_94',['cumprod',['../group__ops.html#ga0d71dfbc14ef3ed564b0c5ee26af680f',1,'mlx::core']]],
['cumsum_95',['cumsum',['../group__ops.html#gaddc825a5c173e195ab0fda83ad630420',1,'mlx::core']]],
['custom_96',['Custom',['../classmlx_1_1core_1_1fast_1_1_custom.html#a4186fea23f7156c38960426821fca313',1,'mlx::core::fast::Custom']]],
['custom_5ffunction_97',['custom_function',['../namespacemlx_1_1core.html#a8d3ca5fbaecdb995660c24cde5aeebaf',1,'mlx::core']]],
['custom_5fvjp_98',['custom_vjp',['../namespacemlx_1_1core.html#a9290596250fa308df4c69b44483bb8aa',1,'mlx::core']]],
['customkernel_99',['CustomKernel',['../classmlx_1_1core_1_1fast_1_1_custom_kernel.html#a954893e07f0d36715b4e1e414b6f2153',1,'mlx::core::fast::CustomKernel']]],
['customtransforms_100',['CustomTransforms',['../classmlx_1_1core_1_1_custom_transforms.html#ab52abadb9c6f6db83d087c7b751be488',1,'mlx::core::CustomTransforms']]]
];

View File

@ -44,7 +44,7 @@ var searchData=
['get_5fpool_41',['get_pool',['../namespacepocketfft_1_1detail_1_1threading.html#a7ec2b3f99232bd0f15f7b022c59d139a',1,'pocketfft::detail::threading']]],
['get_5fprimitive_5fstring_42',['get_primitive_string',['../namespacemlx_1_1core.html#ad4be35b310a252edd80d9cf04f094a60',1,'mlx::core']]],
['get_5fquantized_5fkernel_43',['get_quantized_kernel',['../namespacemlx_1_1core.html#aa3faeae5378bfaafe3ce3432a051e43e',1,'mlx::core']]],
['get_5freduce_5finit_5fkernel_44',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a51c4bb09230348bd0252e22bfdc9bc89',1,'mlx::core']]],
['get_5freduce_5finit_5fkernel_44',['get_reduce_init_kernel',['../namespacemlx_1_1core.html#a3bd386cb6db09f636963ce66ceaf8647',1,'mlx::core']]],
['get_5freduce_5fkernel_45',['get_reduce_kernel',['../namespacemlx_1_1core.html#a7aa91fcfe8b9caa42d60a957f11bfe6b',1,'mlx::core']]],
['get_5freduction_5fplan_46',['get_reduction_plan',['../namespacemlx_1_1core.html#ac97b5a6f009ca3d99854ce9512c20dba',1,'mlx::core']]],
['get_5fscan_5fkernel_47',['get_scan_kernel',['../namespacemlx_1_1core.html#aeefaff208444d3fa61ecc0946fe1de5f',1,'mlx::core']]],

View File

@ -14,39 +14,40 @@ var searchData=
['max3_11',['max3',['../namespacemetal.html#a00f9c0ad66d969794614f56912eed9c9',1,'metal::max3()'],['../namespacemetal_1_1fast.html#a6fc2cf18ffa8149561864c86dba0f803',1,'metal::fast::max3()'],['../namespacemetal_1_1precise.html#ac490e8614ebd2c9343af1ae6c0d4e82c',1,'metal::precise::max3()']]],
['maximum_12',['Maximum',['../classmlx_1_1core_1_1_maximum.html#a28389307e385efe1b2955b86b115e816',1,'mlx::core::Maximum']]],
['maximum_13',['maximum',['../group__ops.html#ga7ade2ea305e2e4219c3609443fb5db8d',1,'mlx::core']]],
['mb_5fblock_5fmerge_14',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
['mb_5fblock_5fpartition_15',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
['mb_5fblock_5fsort_16',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
['mean_17',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['median3_18',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
['merge_5fpartition_19',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
['merge_5fstep_20',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
['meshgrid_21',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
['metal_5fkernel_22',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
['min_23',['min',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['min3_24',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
['minimum_25',['Minimum',['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum']]],
['minimum_26',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_27',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread T *expected, T val, size_t offset):&#160;atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread uint *expected, uint val, size_t offset):&#160;atomic.h']]],
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_28',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fand_5fexplicit_29',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_30',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_31',['mlx_atomic_fetch_max_explicit&lt; float &gt;',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_32',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_33',['mlx_atomic_fetch_min_explicit&lt; float &gt;',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_34',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5for_5fexplicit_35',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
['mlx_5fatomic_5fload_5fexplicit_36',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
['mlx_5fatomic_5fstore_5fexplicit_37',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
['mma_38',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread frag_type &amp;D, thread frag_type &amp;A, thread frag_type &amp;B, thread frag_type &amp;C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread mat_type &amp;D, thread mat_type &amp;A, thread mat_type &amp;B, thread mat_type &amp;C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
['mmatile_39',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile']]],
['move_5fshared_5fbuffer_40',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
['moveaxis_41',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
['mpinplace_42',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
['mtl_5fdevice_43',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
['mtl_5fresidency_5fset_44',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
['multi_5fiter_45',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter']]],
['multiply_46',['Multiply',['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply']]],
['multiply_47',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
['multivariate_5fnormal_48',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
['maybeinsertbarrier_14',['maybeInsertBarrier',['../structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991',1,'mlx::core::metal::CommandEncoder']]],
['mb_5fblock_5fmerge_15',['mb_block_merge',['../sort_8h.html#ab381cd57f344bc7304ab580bfdc78807',1,'sort.h']]],
['mb_5fblock_5fpartition_16',['mb_block_partition',['../sort_8h.html#a32cbe4163b8b0f5cb2c97b256119a4b2',1,'sort.h']]],
['mb_5fblock_5fsort_17',['mb_block_sort',['../sort_8h.html#aa48ff1aff1e9dc1301b6781aa0721d6b',1,'sort.h']]],
['mean_18',['mean',['../group__ops.html#gade46e768fd46b8b640eb16f26abeecef',1,'mlx::core::mean(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga52b59fdd8e8430538e564f5bbcfa31e6',1,'mlx::core::mean(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga066161f3d3e395a1d76c638cb680d444',1,'mlx::core::mean(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga45fba73eab0e3b6e128ed3ce2f43a5da',1,'mlx::core::mean(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['median3_19',['median3',['../namespacemetal.html#aa3ff49457ce3c93fc1c0897fd1525157',1,'metal::median3()'],['../namespacemetal_1_1fast.html#a742b55f1e4369921ee7f60d70185bfbc',1,'metal::fast::median3()'],['../namespacemetal_1_1precise.html#a14555ff99c4388493fec48e070144ae2',1,'metal::precise::median3()']]],
['merge_5fpartition_20',['merge_partition',['../struct_block_merge_sort.html#ab2300cbecb23f3433bad888924c831ca',1,'BlockMergeSort::merge_partition()'],['../struct_kernel_multi_block_merge_sort.html#ab15895b4233aba0e279cc44a07a201fe',1,'KernelMultiBlockMergeSort::merge_partition()']]],
['merge_5fstep_21',['merge_step',['../struct_block_merge_sort.html#ab65f190edf1851b37c39ad49ce99a43c',1,'BlockMergeSort']]],
['meshgrid_22',['meshgrid',['../group__ops.html#ga577c911618575314de63d1060656a26e',1,'mlx::core']]],
['metal_5fkernel_23',['metal_kernel',['../namespacemlx_1_1core_1_1fast.html#ab16436b465dc10ce472193d541d8426e',1,'mlx::core::fast']]],
['min_24',['min',['../structmetal_1_1__numeric__limits__impl_3_01bfloat16__t_01_4.html#adaed80031f5ca0ff69d30ec4c5d0c98f',1,'metal::_numeric_limits_impl&lt; bfloat16_t &gt;::min()'],['../namespacemetal.html#a6653b28c9473087141eddce39878d4d3',1,'metal::min()'],['../namespacemetal_1_1fast.html#a3e958e56a4712687c381a0b64d123e61',1,'metal::fast::min()'],['../namespacemetal_1_1precise.html#afed0da2f7df3505b5dffa2389c3cb36e',1,'metal::precise::min()'],['../group__ops.html#gab27599802617a4c8f9964ab5f4ffee12',1,'mlx::core::min(const array &amp;a, bool keepdims, StreamOrDevice s={})'],['../group__ops.html#ga0140b91e9cdfc3fef0da8e332f65a9e8',1,'mlx::core::min(const array &amp;a, StreamOrDevice s={})'],['../group__ops.html#ga6efb83cd46436678c8f8c4af15cc00f5',1,'mlx::core::min(const array &amp;a, const std::vector&lt; int &gt; &amp;axes, bool keepdims=false, StreamOrDevice s={})'],['../group__ops.html#ga36fa315eef677f4143868f552cd26d03',1,'mlx::core::min(const array &amp;a, int axis, bool keepdims=false, StreamOrDevice s={})']]],
['min3_25',['min3',['../namespacemetal.html#a005510c8c0f964ce2b8aad3ba76a7a3f',1,'metal::min3()'],['../namespacemetal_1_1fast.html#a606a4c1b34ce05ea89ca5af81724036f',1,'metal::fast::min3()'],['../namespacemetal_1_1precise.html#a4d37ce31c3549ca4772a4ee29798e231',1,'metal::precise::min3()']]],
['minimum_26',['Minimum',['../classmlx_1_1core_1_1_minimum.html#ab0f2ce17108df44b82cff68886b0f6f5',1,'mlx::core::Minimum']]],
['minimum_27',['minimum',['../group__ops.html#ga49ba00c090f81f331c91b0c97040bce0',1,'mlx::core']]],
['mlx_5fatomic_5fcompare_5fexchange_5fweak_5fexplicit_28',['mlx_atomic_compare_exchange_weak_explicit',['../atomic_8h.html#ad7f32327ff66354cfa2f0cfdac79316f',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread T *expected, T val, size_t offset):&#160;atomic.h'],['../atomic_8h.html#aa8f47b2e9b95d4b00ad51f08b070deb5',1,'mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic&lt; T &gt; *object, thread uint *expected, uint val, size_t offset):&#160;atomic.h']]],
['mlx_5fatomic_5ffetch_5fadd_5fexplicit_29',['mlx_atomic_fetch_add_explicit',['../atomic_8h.html#aad448d9e06e001700b65ca8317216a3b',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fand_5fexplicit_30',['mlx_atomic_fetch_and_explicit',['../atomic_8h.html#a253e3c870c0ddc7c28ab2f6ca2c3eae5',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_31',['mlx_atomic_fetch_max_explicit',['../atomic_8h.html#ac480f2b459a8ad9095cee353e152d00c',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmax_5fexplicit_3c_20float_20_3e_32',['mlx_atomic_fetch_max_explicit&lt; float &gt;',['../atomic_8h.html#a1dce2abfa16417122c4d2bf261129ae4',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_33',['mlx_atomic_fetch_min_explicit',['../atomic_8h.html#a2ec33dca0039bd944d73d1c2b378cc19',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmin_5fexplicit_3c_20float_20_3e_34',['mlx_atomic_fetch_min_explicit&lt; float &gt;',['../atomic_8h.html#ab7d1dc49f319f239b7ee0b7c72976dd0',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5fmul_5fexplicit_35',['mlx_atomic_fetch_mul_explicit',['../atomic_8h.html#adfdbea60436f14f1af9ce36e2a0a77a3',1,'atomic.h']]],
['mlx_5fatomic_5ffetch_5for_5fexplicit_36',['mlx_atomic_fetch_or_explicit',['../atomic_8h.html#ab7391f197001471e4788312bdb6ab37a',1,'atomic.h']]],
['mlx_5fatomic_5fload_5fexplicit_37',['mlx_atomic_load_explicit',['../atomic_8h.html#a253a4e8c2c5768a069e2791b627dfc99',1,'atomic.h']]],
['mlx_5fatomic_5fstore_5fexplicit_38',['mlx_atomic_store_explicit',['../atomic_8h.html#a0ae453140b0819a4c02f265334de98c0',1,'atomic.h']]],
['mma_39',['mma',['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a8028512f5a3d2b6acaf966be529627a3',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread frag_type &amp;D, thread frag_type &amp;A, thread frag_type &amp;B, thread frag_type &amp;C)'],['../structmlx_1_1steel_1_1_base_m_m_a_frag_3_01_t_00_018_00_018_01_4.html#a1868f57d57c8adedab2c58492ec76946',1,'mlx::steel::BaseMMAFrag&lt; T, 8, 8 &gt;::mma(thread mat_type &amp;D, thread mat_type &amp;A, thread mat_type &amp;B, thread mat_type &amp;C)'],['../structmlx_1_1steel_1_1_block_m_m_a.html#a6a2c2a6d5e767d52c41b42a9d36086b0',1,'mlx::steel::BlockMMA::mma()']]],
['mmatile_40',['MMATile',['../structmlx_1_1steel_1_1_m_m_a_tile.html#aa3fb310dd08ec23c334511f7b316d1b6',1,'mlx::steel::MMATile']]],
['move_5fshared_5fbuffer_41',['move_shared_buffer',['../classmlx_1_1core_1_1array.html#acce00db63e0f3d80f797b02397ade836',1,'mlx::core::array::move_shared_buffer(array other, const std::vector&lt; size_t &gt; &amp;strides, Flags flags, size_t data_size, size_t offset=0)'],['../classmlx_1_1core_1_1array.html#a38d7ad605f8282e5e49d0c09e0555c78',1,'mlx::core::array::move_shared_buffer(array other)']]],
['moveaxis_42',['moveaxis',['../group__ops.html#ga24067d10a842db2c9d509ea48135a2c3',1,'mlx::core']]],
['mpinplace_43',['MPINPLACE',['../namespacepocketfft_1_1detail.html#af5eedf3cdfc83c0a30807092c39a9ce2',1,'pocketfft::detail']]],
['mtl_5fdevice_44',['mtl_device',['../classmlx_1_1core_1_1metal_1_1_device.html#a31dba377f2be44a746db10d1b9367653',1,'mlx::core::metal::Device']]],
['mtl_5fresidency_5fset_45',['mtl_residency_set',['../classmlx_1_1core_1_1metal_1_1_residency_set.html#ac4bfe5ef5e2eaebc458a1ed1953d15e9',1,'mlx::core::metal::ResidencySet']]],
['multi_5fiter_46',['multi_iter',['../classpocketfft_1_1detail_1_1multi__iter.html#a9be43bb18840202da6d17988fccc64b9',1,'pocketfft::detail::multi_iter']]],
['multiply_47',['Multiply',['../classmlx_1_1core_1_1_multiply.html#aca5c50f900321f3eb4d6fbcbc225c00c',1,'mlx::core::Multiply']]],
['multiply_48',['multiply',['../group__ops.html#gaf57392e641640b5d06e4c99518391c38',1,'mlx::core']]],
['multivariate_5fnormal_49',['multivariate_normal',['../namespacemlx_1_1core_1_1random.html#a8c37da3c1c0c561cad7499d6d9db81fb',1,'mlx::core::random']]]
];

File diff suppressed because one or more lines are too long

View File

@ -99,13 +99,14 @@ $(function(){ initResizable(false); });
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a74bcd8e35f80f5a62db48c4a2bb0173e">dispatchThreadgroups</a>(MTL::Size grid_dims, MTL::Size group_dims)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a1e41477f2f489e38499f7830a91c9810">dispatchThreads</a>(MTL::Size grid_dims, MTL::Size group_dims)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a27ded7e54bc1712063c874646b445509">inputs</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966">operator-&gt;</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a3f42a1362b4a513fa89e7b3dcc570a8e">operator=</a>(const CommandEncoder &amp;)=delete</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4">set_input_array</a>(const array &amp;a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">set_output_array</a>(array &amp;a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ad538ae88f90560063f9ba502e2795991">maybeInsertBarrier</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aac45ab0630ea32cf7d15c7ba3e229966">operator-&gt;</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a3f42a1362b4a513fa89e7b3dcc570a8e">operator=</a>(const CommandEncoder &amp;)=delete</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#aefa48740fdee884f02e2d379bca4e78f">outputs</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#ab69ff0d7f14b9b59db4df0608193dce4">set_input_array</a>(const array &amp;a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a6a2e28e542eaa2886041bddd51ff6522">set_output_array</a>(array &amp;a, int idx, int64_t offset=0)</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
<tr class="odd"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"><span class="mlabel">inline</span></td></tr>
<tr class="even"><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a>()</td><td class="entry"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder.html">mlx::core::metal::CommandEncoder</a></td><td class="entry"></td></tr>
</table></div><!-- contents -->
<!-- start footer part -->
<hr class="footer"/><address class="footer"><small>

View File

@ -121,6 +121,8 @@ Public Member Functions</h2></td></tr>
<tr class="separator:a74bcd8e35f80f5a62db48c4a2bb0173e"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a1e41477f2f489e38499f7830a91c9810" id="r_a1e41477f2f489e38499f7830a91c9810"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a1e41477f2f489e38499f7830a91c9810">dispatchThreads</a> (MTL::Size grid_dims, MTL::Size group_dims)</td></tr>
<tr class="separator:a1e41477f2f489e38499f7830a91c9810"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:ad538ae88f90560063f9ba502e2795991" id="r_ad538ae88f90560063f9ba502e2795991"><td class="memItemLeft" align="right" valign="top">void&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#ad538ae88f90560063f9ba502e2795991">maybeInsertBarrier</a> ()</td></tr>
<tr class="separator:ad538ae88f90560063f9ba502e2795991"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a48b548a0b15f9d1279c938a1c6167034" id="r_a48b548a0b15f9d1279c938a1c6167034"><td class="memItemLeft" align="right" valign="top"><a class="el" href="structmlx_1_1core_1_1metal_1_1_command_encoder_1_1_concurrent_context.html">ConcurrentContext</a>&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a48b548a0b15f9d1279c938a1c6167034">start_concurrent</a> ()</td></tr>
<tr class="separator:a48b548a0b15f9d1279c938a1c6167034"><td class="memSeparator" colspan="2">&#160;</td></tr>
<tr class="memitem:a9b6dd221ccd2d939d544004cb6279198" id="r_a9b6dd221ccd2d939d544004cb6279198"><td class="memItemLeft" align="right" valign="top">&#160;</td><td class="memItemRight" valign="bottom"><a class="el" href="#a9b6dd221ccd2d939d544004cb6279198">~CommandEncoder</a> ()</td></tr>
@ -256,6 +258,23 @@ Public Member Functions</h2></td></tr>
</table>
</div><div class="memdoc">
</div>
</div>
<a id="ad538ae88f90560063f9ba502e2795991" name="ad538ae88f90560063f9ba502e2795991"></a>
<h2 class="memtitle"><span class="permalink"><a href="#ad538ae88f90560063f9ba502e2795991">&#9670;&#160;</a></span>maybeInsertBarrier()</h2>
<div class="memitem">
<div class="memproto">
<table class="memname">
<tr>
<td class="memname">void mlx::core::metal::CommandEncoder::maybeInsertBarrier </td>
<td>(</td>
<td class="paramname"><span class="paramname"><em></em></span></td><td>)</td>
<td></td>
</tr>
</table>
</div><div class="memdoc">
</div>
</div>
<a id="aac45ab0630ea32cf7d15c7ba3e229966" name="aac45ab0630ea32cf7d15c7ba3e229966"></a>

View File

@ -986,13 +986,13 @@ We will prioritize including it.</p>
<span class="n">ys</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">100</span><span class="p">,</span> <span class="mi">4096</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">naive_add</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">ys</span><span class="p">):</span>
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">])]</span>
<span class="k">return</span> <span class="p">[</span><span class="n">xs</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">ys</span><span class="p">[:,</span> <span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">xs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span>
</pre></div>
</div>
<p>Instead you can use <a class="reference internal" href="../python/_autosummary/mlx.core.vmap.html#mlx.core.vmap" title="mlx.core.vmap"><code class="xref py py-func docutils literal notranslate"><span class="pre">vmap()</span></code></a> to automatically vectorize the addition:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Vectorize over the first dimension of x and the</span>
<span class="c1"># second dimension of y</span>
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span>
<span class="n">vmap_add</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">vmap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span> <span class="o">+</span> <span class="n">y</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
</pre></div>
</div>
<p>The <code class="docutils literal notranslate"><span class="pre">in_axes</span></code> parameter can be used to specify which dimensions of the

View File

@ -922,7 +922,7 @@ undefined behavior.</p></li>
from the GPU. Performing bounds checking for array indices before launching the
kernel would be extremely inefficient.</p>
<p>Indexing with boolean masks is something that MLX may support in the future. In
general, MLX has limited support for operations for which outputs
general, MLX has limited support for operations for which output
<em>shapes</em> are dependent on input <em>data</em>. Other examples of these types of
operations which MLX does not yet support include <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.nonzero.html#numpy.nonzero" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.nonzero()</span></code></a> and the
single input version of <a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.where.html#numpy.where" title="(in NumPy v2.1)"><code class="xref py py-func docutils literal notranslate"><span class="pre">numpy.where()</span></code></a>.</p>

View File

@ -952,7 +952,7 @@ stochastic gradient descent). A natural and usually efficient place to use
</div>
<p>An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you <code class="docutils literal notranslate"><span class="pre">print</span></code> an array, convert it to an
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access its memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
<a class="reference external" href="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray" title="(in NumPy v2.1)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">numpy.ndarray</span></code></a>, or otherwise access its memory via <a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#memoryview" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">memoryview</span></code></a>,
the graph will be evaluated. Saving arrays via <a class="reference internal" href="../python/_autosummary/mlx.core.save.html#mlx.core.save" title="mlx.core.save"><code class="xref py py-func docutils literal notranslate"><span class="pre">save()</span></code></a> (or any other MLX
saving functions) will also evaluate the array.</p>
<p>Calling <a class="reference internal" href="../python/_autosummary/mlx.core.array.item.html#mlx.core.array.item" title="mlx.core.array.item"><code class="xref py py-func docutils literal notranslate"><span class="pre">array.item()</span></code></a> on a scalar array will also evaluate it. In the